# test_sqlalchemy.py (5.4 KB)
  1. from pathlib import Path
  2. from unittest import mock
  3. import pytest
  4. from sqlalchemy import select
  5. from sqlalchemy.exc import OperationalError
  6. from sqlalchemy.orm import (
  7. DeclarativeBase,
  8. Mapped,
  9. MappedAsDataclass,
  10. declared_attr,
  11. mapped_column,
  12. )
  13. import reflex.constants
  14. import reflex.model
  15. from reflex.model import Model, ModelRegistry, sqla_session
  16. @pytest.mark.filterwarnings(
  17. "ignore:This declarative base already contains a class with the same class name",
  18. )
  19. def test_automigration(
  20. tmp_working_dir: Path,
  21. monkeypatch: pytest.MonkeyPatch,
  22. model_registry: type[ModelRegistry],
  23. ):
  24. """Test alembic automigration with add and drop table and column.
  25. Args:
  26. tmp_working_dir: directory where database and migrations are stored
  27. monkeypatch: pytest fixture to overwrite attributes
  28. model_registry: clean reflex ModelRegistry
  29. """
  30. alembic_ini = tmp_working_dir / "alembic.ini"
  31. versions = tmp_working_dir / "alembic" / "versions"
  32. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  33. config_mock = mock.Mock()
  34. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  35. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  36. assert alembic_ini.exists() is False
  37. assert versions.exists() is False
  38. Model.alembic_init()
  39. assert alembic_ini.exists()
  40. assert versions.exists()
  41. class Base(DeclarativeBase):
  42. @declared_attr.directive
  43. def __tablename__(cls) -> str:
  44. return cls.__name__.lower()
  45. assert model_registry.register(Base)
  46. class ModelBase(Base, MappedAsDataclass):
  47. __abstract__ = True
  48. id: Mapped[int | None] = mapped_column(primary_key=True, default=None)
  49. # initial table
  50. class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
  51. t1: Mapped[str] = mapped_column(default="")
  52. with Model.get_db_engine().connect() as connection:
  53. assert Model.alembic_autogenerate(
  54. connection=connection, message="Initial Revision"
  55. )
  56. assert Model.migrate()
  57. version_scripts = list(versions.glob("*.py"))
  58. assert len(version_scripts) == 1
  59. assert version_scripts[0].name.endswith("initial_revision.py")
  60. with sqla_session() as session:
  61. session.add(AlembicThing(t1="foo"))
  62. session.commit()
  63. model_registry.get_metadata().clear()
  64. # Create column t2, mark t1 as optional with default
  65. class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
  66. t1: Mapped[str | None] = mapped_column(default="default")
  67. t2: Mapped[str] = mapped_column(default="bar")
  68. assert Model.migrate(autogenerate=True)
  69. assert len(list(versions.glob("*.py"))) == 2
  70. with sqla_session() as session:
  71. session.add(AlembicThing(t2="baz"))
  72. session.commit()
  73. result = session.scalars(select(AlembicThing)).all()
  74. assert len(result) == 2
  75. assert result[0].t1 == "foo"
  76. assert result[0].t2 == "bar"
  77. assert result[1].t1 == "default"
  78. assert result[1].t2 == "baz"
  79. model_registry.get_metadata().clear()
  80. # Drop column t1
  81. class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
  82. t2: Mapped[str] = mapped_column(default="bar")
  83. assert Model.migrate(autogenerate=True)
  84. assert len(list(versions.glob("*.py"))) == 3
  85. with sqla_session() as session:
  86. result = session.scalars(select(AlembicThing)).all()
  87. assert len(result) == 2
  88. assert result[0].t2 == "bar"
  89. assert result[1].t2 == "baz"
  90. # Add table
  91. class AlembicSecond(ModelBase):
  92. a: Mapped[int] = mapped_column(default=42)
  93. b: Mapped[float] = mapped_column(default=4.2)
  94. assert Model.migrate(autogenerate=True)
  95. assert len(list(versions.glob("*.py"))) == 4
  96. with reflex.model.session() as session:
  97. session.add(AlembicSecond(id=None))
  98. session.commit()
  99. result = session.scalars(select(AlembicSecond)).all()
  100. assert len(result) == 1
  101. assert result[0].a == 42
  102. assert result[0].b == 4.2
  103. # No-op
  104. # assert Model.migrate(autogenerate=True) #noqa: ERA001
  105. # assert len(list(versions.glob("*.py"))) == 4 #noqa: ERA001
  106. # drop table (AlembicSecond)
  107. model_registry.get_metadata().clear()
  108. class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
  109. t2: Mapped[str] = mapped_column(default="bar")
  110. assert Model.migrate(autogenerate=True)
  111. assert len(list(versions.glob("*.py"))) == 5
  112. with reflex.model.session() as session:
  113. with pytest.raises(OperationalError) as errctx:
  114. _ = session.scalars(select(AlembicSecond)).all()
  115. assert errctx.match(r"no such table: alembicsecond")
  116. # first table should still exist
  117. result = session.scalars(select(AlembicThing)).all()
  118. assert len(result) == 2
  119. assert result[0].t2 == "bar"
  120. assert result[1].t2 == "baz"
  121. model_registry.get_metadata().clear()
  122. class AlembicThing(ModelBase):
  123. # changing column type not supported by default
  124. t2: Mapped[int] = mapped_column(default=42)
  125. assert Model.migrate(autogenerate=True)
  126. assert len(list(versions.glob("*.py"))) == 5
  127. # clear all metadata to avoid influencing subsequent tests
  128. model_registry.get_metadata().clear()
  129. # drop remaining tables
  130. assert Model.migrate(autogenerate=True)
  131. assert len(list(versions.glob("*.py"))) == 6