test_model.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from pathlib import Path
  2. from typing import Optional, Type
  3. from unittest import mock
  4. import pytest
  5. import sqlalchemy
  6. import sqlmodel
  7. import reflex.constants
  8. import reflex.model
  9. from reflex.model import Model, ModelRegistry
  10. @pytest.fixture
  11. def model_default_primary() -> Model:
  12. """Returns a model object with no defined primary key.
  13. Returns:
  14. Model: Model object.
  15. """
  16. class ChildModel(Model):
  17. name: str
  18. return ChildModel(name="name")
  19. @pytest.fixture
  20. def model_custom_primary() -> Model:
  21. """Returns a model object with a custom primary key.
  22. Returns:
  23. Model: Model object.
  24. """
  25. class ChildModel(Model):
  26. custom_id: Optional[int] = sqlmodel.Field(default=None, primary_key=True)
  27. name: str
  28. return ChildModel(name="name")
  29. def test_default_primary_key(model_default_primary: Model):
  30. """Test that if a primary key is not defined a default is added.
  31. Args:
  32. model_default_primary: Fixture.
  33. """
  34. assert "id" in model_default_primary.__class__.__fields__
  35. def test_custom_primary_key(model_custom_primary: Model):
  36. """Test that if a primary key is defined no default key is added.
  37. Args:
  38. model_custom_primary: Fixture.
  39. """
  40. assert "id" not in model_custom_primary.__class__.__fields__
  41. @pytest.mark.filterwarnings(
  42. "ignore:This declarative base already contains a class with the same class name",
  43. )
  44. def test_automigration(
  45. tmp_working_dir: Path,
  46. monkeypatch: pytest.MonkeyPatch,
  47. model_registry: Type[ModelRegistry],
  48. ):
  49. """Test alembic automigration with add and drop table and column.
  50. Args:
  51. tmp_working_dir: directory where database and migrations are stored
  52. monkeypatch: pytest fixture to overwrite attributes
  53. model_registry: clean reflex ModelRegistry
  54. """
  55. alembic_ini = tmp_working_dir / "alembic.ini"
  56. versions = tmp_working_dir / "alembic" / "versions"
  57. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  58. config_mock = mock.Mock()
  59. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  60. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  61. Model.alembic_init()
  62. assert alembic_ini.exists()
  63. assert versions.exists()
  64. # initial table
  65. class AlembicThing(Model, table=True): # type: ignore
  66. t1: str
  67. with Model.get_db_engine().connect() as connection:
  68. assert Model.alembic_autogenerate(
  69. connection=connection, message="Initial Revision"
  70. )
  71. assert Model.migrate()
  72. version_scripts = list(versions.glob("*.py"))
  73. assert len(version_scripts) == 1
  74. assert version_scripts[0].name.endswith("initial_revision.py")
  75. with reflex.model.session() as session:
  76. session.add(AlembicThing(id=None, t1="foo"))
  77. session.commit()
  78. model_registry.get_metadata().clear()
  79. # Create column t2, mark t1 as optional with default
  80. class AlembicThing(Model, table=True): # type: ignore
  81. t1: Optional[str] = "default"
  82. t2: str = "bar"
  83. assert Model.migrate(autogenerate=True)
  84. assert len(list(versions.glob("*.py"))) == 2
  85. with reflex.model.session() as session:
  86. session.add(AlembicThing(t2="baz"))
  87. session.commit()
  88. result = session.exec(sqlmodel.select(AlembicThing)).all()
  89. assert len(result) == 2
  90. assert result[0].t1 == "foo"
  91. assert result[0].t2 == "bar"
  92. assert result[1].t1 == "default"
  93. assert result[1].t2 == "baz"
  94. model_registry.get_metadata().clear()
  95. # Drop column t1
  96. class AlembicThing(Model, table=True): # type: ignore
  97. t2: str = "bar"
  98. assert Model.migrate(autogenerate=True)
  99. assert len(list(versions.glob("*.py"))) == 3
  100. with reflex.model.session() as session:
  101. result = session.exec(sqlmodel.select(AlembicThing)).all()
  102. assert len(result) == 2
  103. assert result[0].t2 == "bar"
  104. assert result[1].t2 == "baz"
  105. # Add table
  106. class AlembicSecond(Model, table=True): # type: ignore
  107. a: int = 42
  108. b: float = 4.2
  109. assert Model.migrate(autogenerate=True)
  110. assert len(list(versions.glob("*.py"))) == 4
  111. with reflex.model.session() as session:
  112. session.add(AlembicSecond(id=None))
  113. session.commit()
  114. result = session.exec(sqlmodel.select(AlembicSecond)).all()
  115. assert len(result) == 1
  116. assert result[0].a == 42
  117. assert result[0].b == 4.2
  118. # No-op
  119. assert Model.migrate(autogenerate=True)
  120. assert len(list(versions.glob("*.py"))) == 4
  121. # drop table (AlembicSecond)
  122. model_registry.get_metadata().clear()
  123. class AlembicThing(Model, table=True): # type: ignore
  124. t2: str = "bar"
  125. assert Model.migrate(autogenerate=True)
  126. assert len(list(versions.glob("*.py"))) == 5
  127. with reflex.model.session() as session:
  128. with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
  129. session.exec(sqlmodel.select(AlembicSecond)).all()
  130. assert errctx.match(r"no such table: alembicsecond")
  131. # first table should still exist
  132. result = session.exec(sqlmodel.select(AlembicThing)).all()
  133. assert len(result) == 2
  134. assert result[0].t2 == "bar"
  135. assert result[1].t2 == "baz"
  136. model_registry.get_metadata().clear()
  137. class AlembicThing(Model, table=True): # type: ignore
  138. # changing column type not supported by default
  139. t2: int = 42
  140. assert Model.migrate(autogenerate=True)
  141. assert len(list(versions.glob("*.py"))) == 5
  142. # clear all metadata to avoid influencing subsequent tests
  143. model_registry.get_metadata().clear()
  144. # drop remaining tables
  145. assert Model.migrate(autogenerate=True)
  146. assert len(list(versions.glob("*.py"))) == 6