# File: test_sqlalchemy.py (5.4 KB)
  1. from pathlib import Path
  2. from typing import Optional, Type
  3. from unittest import mock
  4. import pytest
  5. from sqlalchemy import select
  6. from sqlalchemy.exc import OperationalError
  7. from sqlalchemy.orm import (
  8. DeclarativeBase,
  9. Mapped,
  10. MappedAsDataclass,
  11. declared_attr,
  12. mapped_column,
  13. )
  14. import reflex.constants
  15. import reflex.model
  16. from reflex.model import Model, ModelRegistry, sqla_session
  17. @pytest.mark.filterwarnings(
  18. "ignore:This declarative base already contains a class with the same class name",
  19. )
  20. def test_automigration(
  21. tmp_working_dir: Path,
  22. monkeypatch: pytest.MonkeyPatch,
  23. model_registry: Type[ModelRegistry],
  24. ):
  25. """Test alembic automigration with add and drop table and column.
  26. Args:
  27. tmp_working_dir: directory where database and migrations are stored
  28. monkeypatch: pytest fixture to overwrite attributes
  29. model_registry: clean reflex ModelRegistry
  30. """
  31. alembic_ini = tmp_working_dir / "alembic.ini"
  32. versions = tmp_working_dir / "alembic" / "versions"
  33. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  34. config_mock = mock.Mock()
  35. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  36. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  37. assert alembic_ini.exists() is False
  38. assert versions.exists() is False
  39. Model.alembic_init()
  40. assert alembic_ini.exists()
  41. assert versions.exists()
  42. class Base(DeclarativeBase):
  43. @declared_attr.directive
  44. def __tablename__(cls) -> str:
  45. return cls.__name__.lower()
  46. assert model_registry.register(Base)
  47. class ModelBase(Base, MappedAsDataclass):
  48. __abstract__ = True
  49. id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
  50. # initial table
  51. class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
  52. t1: Mapped[str] = mapped_column(default="")
  53. with Model.get_db_engine().connect() as connection:
  54. assert Model.alembic_autogenerate(
  55. connection=connection, message="Initial Revision"
  56. )
  57. assert Model.migrate()
  58. version_scripts = list(versions.glob("*.py"))
  59. assert len(version_scripts) == 1
  60. assert version_scripts[0].name.endswith("initial_revision.py")
  61. with sqla_session() as session:
  62. session.add(AlembicThing(t1="foo"))
  63. session.commit()
  64. model_registry.get_metadata().clear()
  65. # Create column t2, mark t1 as optional with default
  66. class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
  67. t1: Mapped[Optional[str]] = mapped_column(default="default")
  68. t2: Mapped[str] = mapped_column(default="bar")
  69. assert Model.migrate(autogenerate=True)
  70. assert len(list(versions.glob("*.py"))) == 2
  71. with sqla_session() as session:
  72. session.add(AlembicThing(t2="baz"))
  73. session.commit()
  74. result = session.scalars(select(AlembicThing)).all()
  75. assert len(result) == 2
  76. assert result[0].t1 == "foo"
  77. assert result[0].t2 == "bar"
  78. assert result[1].t1 == "default"
  79. assert result[1].t2 == "baz"
  80. model_registry.get_metadata().clear()
  81. # Drop column t1
  82. class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
  83. t2: Mapped[str] = mapped_column(default="bar")
  84. assert Model.migrate(autogenerate=True)
  85. assert len(list(versions.glob("*.py"))) == 3
  86. with sqla_session() as session:
  87. result = session.scalars(select(AlembicThing)).all()
  88. assert len(result) == 2
  89. assert result[0].t2 == "bar"
  90. assert result[1].t2 == "baz"
  91. # Add table
  92. class AlembicSecond(ModelBase):
  93. a: Mapped[int] = mapped_column(default=42)
  94. b: Mapped[float] = mapped_column(default=4.2)
  95. assert Model.migrate(autogenerate=True)
  96. assert len(list(versions.glob("*.py"))) == 4
  97. with reflex.model.session() as session:
  98. session.add(AlembicSecond(id=None))
  99. session.commit()
  100. result = session.scalars(select(AlembicSecond)).all()
  101. assert len(result) == 1
  102. assert result[0].a == 42
  103. assert result[0].b == 4.2
  104. # No-op
  105. # assert Model.migrate(autogenerate=True) #noqa: ERA001
  106. # assert len(list(versions.glob("*.py"))) == 4 #noqa: ERA001
  107. # drop table (AlembicSecond)
  108. model_registry.get_metadata().clear()
  109. class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
  110. t2: Mapped[str] = mapped_column(default="bar")
  111. assert Model.migrate(autogenerate=True)
  112. assert len(list(versions.glob("*.py"))) == 5
  113. with reflex.model.session() as session:
  114. with pytest.raises(OperationalError) as errctx:
  115. _ = session.scalars(select(AlembicSecond)).all()
  116. assert errctx.match(r"no such table: alembicsecond")
  117. # first table should still exist
  118. result = session.scalars(select(AlembicThing)).all()
  119. assert len(result) == 2
  120. assert result[0].t2 == "bar"
  121. assert result[1].t2 == "baz"
  122. model_registry.get_metadata().clear()
  123. class AlembicThing(ModelBase):
  124. # changing column type not supported by default
  125. t2: Mapped[int] = mapped_column(default=42)
  126. assert Model.migrate(autogenerate=True)
  127. assert len(list(versions.glob("*.py"))) == 5
  128. # clear all metadata to avoid influencing subsequent tests
  129. model_registry.get_metadata().clear()
  130. # drop remaining tables
  131. assert Model.migrate(autogenerate=True)
  132. assert len(list(versions.glob("*.py"))) == 6