1
0

test_model.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from pathlib import Path
  2. from unittest import mock
  3. import pytest
  4. import sqlalchemy
  5. import sqlmodel
  6. import reflex.constants
  7. import reflex.model
  8. from reflex.model import Model, ModelRegistry
  9. @pytest.fixture
  10. def model_default_primary() -> Model:
  11. """Returns a model object with no defined primary key.
  12. Returns:
  13. Model: Model object.
  14. """
  15. class ChildModel(Model):
  16. name: str
  17. return ChildModel(name="name")
  18. @pytest.fixture
  19. def model_custom_primary() -> Model:
  20. """Returns a model object with a custom primary key.
  21. Returns:
  22. Model: Model object.
  23. """
  24. class ChildModel(Model):
  25. custom_id: int | None = sqlmodel.Field(default=None, primary_key=True)
  26. name: str
  27. return ChildModel(name="name")
  28. def test_default_primary_key(model_default_primary: Model):
  29. """Test that if a primary key is not defined a default is added.
  30. Args:
  31. model_default_primary: Fixture.
  32. """
  33. assert "id" in type(model_default_primary).__fields__
  34. def test_custom_primary_key(model_custom_primary: Model):
  35. """Test that if a primary key is defined no default key is added.
  36. Args:
  37. model_custom_primary: Fixture.
  38. """
  39. assert "id" not in type(model_custom_primary).__fields__
  40. @pytest.mark.filterwarnings(
  41. "ignore:This declarative base already contains a class with the same class name",
  42. )
  43. def test_automigration(
  44. tmp_working_dir: Path,
  45. monkeypatch: pytest.MonkeyPatch,
  46. model_registry: type[ModelRegistry],
  47. ):
  48. """Test alembic automigration with add and drop table and column.
  49. Args:
  50. tmp_working_dir: directory where database and migrations are stored
  51. monkeypatch: pytest fixture to overwrite attributes
  52. model_registry: clean reflex ModelRegistry
  53. """
  54. alembic_ini = tmp_working_dir / "alembic.ini"
  55. versions = tmp_working_dir / "alembic" / "versions"
  56. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  57. config_mock = mock.Mock()
  58. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  59. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  60. Model.alembic_init()
  61. assert alembic_ini.exists()
  62. assert versions.exists()
  63. # initial table
  64. class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
  65. t1: str
  66. with Model.get_db_engine().connect() as connection:
  67. assert Model.alembic_autogenerate(
  68. connection=connection, message="Initial Revision"
  69. )
  70. assert Model.migrate()
  71. version_scripts = list(versions.glob("*.py"))
  72. assert len(version_scripts) == 1
  73. assert version_scripts[0].name.endswith("initial_revision.py")
  74. with reflex.model.session() as session:
  75. session.add(AlembicThing(id=None, t1="foo"))
  76. session.commit()
  77. model_registry.get_metadata().clear()
  78. # Create column t2, mark t1 as optional with default
  79. class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
  80. t1: str | None = "default"
  81. t2: str = "bar"
  82. assert Model.migrate(autogenerate=True)
  83. assert len(list(versions.glob("*.py"))) == 2
  84. with reflex.model.session() as session:
  85. session.add(AlembicThing(t2="baz"))
  86. session.commit()
  87. result = session.exec(sqlmodel.select(AlembicThing)).all()
  88. assert len(result) == 2
  89. assert result[0].t1 == "foo"
  90. assert result[0].t2 == "bar"
  91. assert result[1].t1 == "default"
  92. assert result[1].t2 == "baz"
  93. model_registry.get_metadata().clear()
  94. # Drop column t1
  95. class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
  96. t2: str = "bar"
  97. assert Model.migrate(autogenerate=True)
  98. assert len(list(versions.glob("*.py"))) == 3
  99. with reflex.model.session() as session:
  100. result = session.exec(sqlmodel.select(AlembicThing)).all()
  101. assert len(result) == 2
  102. assert result[0].t2 == "bar"
  103. assert result[1].t2 == "baz"
  104. # Add table
  105. class AlembicSecond(Model, table=True):
  106. a: int = 42
  107. b: float = 4.2
  108. assert Model.migrate(autogenerate=True)
  109. assert len(list(versions.glob("*.py"))) == 4
  110. with reflex.model.session() as session:
  111. session.add(AlembicSecond(id=None))
  112. session.commit()
  113. result = session.exec(sqlmodel.select(AlembicSecond)).all()
  114. assert len(result) == 1
  115. assert result[0].a == 42
  116. assert result[0].b == 4.2
  117. # No-op
  118. assert Model.migrate(autogenerate=True)
  119. assert len(list(versions.glob("*.py"))) == 4
  120. # drop table (AlembicSecond)
  121. model_registry.get_metadata().clear()
  122. class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
  123. t2: str = "bar"
  124. assert Model.migrate(autogenerate=True)
  125. assert len(list(versions.glob("*.py"))) == 5
  126. with reflex.model.session() as session:
  127. with pytest.raises(sqlalchemy.exc.OperationalError) as errctx:
  128. session.exec(sqlmodel.select(AlembicSecond)).all()
  129. assert errctx.match(r"no such table: alembicsecond")
  130. # first table should still exist
  131. result = session.exec(sqlmodel.select(AlembicThing)).all()
  132. assert len(result) == 2
  133. assert result[0].t2 == "bar"
  134. assert result[1].t2 == "baz"
  135. model_registry.get_metadata().clear()
  136. class AlembicThing(Model, table=True):
  137. # changing column type not supported by default
  138. t2: int = 42
  139. assert Model.migrate(autogenerate=True)
  140. assert len(list(versions.glob("*.py"))) == 5
  141. # clear all metadata to avoid influencing subsequent tests
  142. model_registry.get_metadata().clear()
  143. # drop remaining tables
  144. assert Model.migrate(autogenerate=True)
  145. assert len(list(versions.glob("*.py"))) == 6