test_model.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from typing import Optional
  2. from unittest import mock
  3. import pytest
  4. import sqlalchemy
  5. import sqlmodel
  6. import reflex.constants
  7. import reflex.model
  8. from reflex.model import Model
  9. @pytest.fixture
  10. def model_default_primary() -> Model:
  11. """Returns a model object with no defined primary key.
  12. Returns:
  13. Model: Model object.
  14. """
  15. class ChildModel(Model):
  16. name: str
  17. return ChildModel(name="name")
  18. @pytest.fixture
  19. def model_custom_primary() -> Model:
  20. """Returns a model object with a custom primary key.
  21. Returns:
  22. Model: Model object.
  23. """
  24. class ChildModel(Model):
  25. custom_id: Optional[int] = sqlmodel.Field(default=None, primary_key=True)
  26. name: str
  27. return ChildModel(name="name")
  28. def test_default_primary_key(model_default_primary):
  29. """Test that if a primary key is not defined a default is added.
  30. Args:
  31. model_default_primary: Fixture.
  32. """
  33. assert "id" in model_default_primary.__class__.__fields__
  34. def test_custom_primary_key(model_custom_primary):
  35. """Test that if a primary key is defined no default key is added.
  36. Args:
  37. model_custom_primary: Fixture.
  38. """
  39. assert "id" not in model_custom_primary.__class__.__fields__
  40. @pytest.mark.filterwarnings(
  41. "ignore:This declarative base already contains a class with the same class name",
  42. )
  43. def test_automigration(tmp_working_dir, monkeypatch):
  44. """Test alembic automigration with add and drop table and column.
  45. Args:
  46. tmp_working_dir: directory where database and migrations are stored
  47. monkeypatch: pytest fixture to overwrite attributes
  48. """
  49. alembic_ini = tmp_working_dir / "alembic.ini"
  50. versions = tmp_working_dir / "alembic" / "versions"
  51. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  52. config_mock = mock.Mock()
  53. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  54. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  55. Model.alembic_init()
  56. assert alembic_ini.exists()
  57. assert versions.exists()
  58. # initial table
  59. class AlembicThing(Model, table=True): # type: ignore
  60. t1: str
  61. with Model.get_db_engine().connect() as connection:
  62. Model.alembic_autogenerate(connection=connection, message="Initial Revision")
  63. Model.migrate()
  64. version_scripts = list(versions.glob("*.py"))
  65. assert len(version_scripts) == 1
  66. assert version_scripts[0].name.endswith("initial_revision.py")
  67. with reflex.model.session() as session:
  68. session.add(AlembicThing(id=None, t1="foo"))
  69. session.commit()
  70. sqlmodel.SQLModel.metadata.clear()
  71. # Create column t2, mark t1 as optional with default
  72. class AlembicThing(Model, table=True): # type: ignore
  73. t1: Optional[str] = "default"
  74. t2: str = "bar"
  75. Model.migrate(autogenerate=True)
  76. assert len(list(versions.glob("*.py"))) == 2
  77. with reflex.model.session() as session:
  78. session.add(AlembicThing(t2="baz"))
  79. session.commit()
  80. result = session.exec(sqlmodel.select(AlembicThing)).all()
  81. assert len(result) == 2
  82. assert result[0].t1 == "foo"
  83. assert result[0].t2 == "bar"
  84. assert result[1].t1 == "default"
  85. assert result[1].t2 == "baz"
  86. sqlmodel.SQLModel.metadata.clear()
  87. # Drop column t1
  88. class AlembicThing(Model, table=True): # type: ignore
  89. t2: str = "bar"
  90. Model.migrate(autogenerate=True)
  91. assert len(list(versions.glob("*.py"))) == 3
  92. with reflex.model.session() as session:
  93. result = session.exec(sqlmodel.select(AlembicThing)).all()
  94. assert len(result) == 2
  95. assert result[0].t2 == "bar"
  96. assert result[1].t2 == "baz"
  97. # Add table
  98. class AlembicSecond(Model, table=True): # type: ignore
  99. a: int = 42
  100. b: float = 4.2
  101. Model.migrate(autogenerate=True)
  102. assert len(list(versions.glob("*.py"))) == 4
  103. with reflex.model.session() as session:
  104. session.add(AlembicSecond(id=None))
  105. session.commit()
  106. result = session.exec(sqlmodel.select(AlembicSecond)).all()
  107. assert len(result) == 1
  108. assert result[0].a == 42
  109. assert result[0].b == 4.2
  110. # No-op
  111. Model.migrate(autogenerate=True)
  112. assert len(list(versions.glob("*.py"))) == 4
  113. # drop table (AlembicSecond)
  114. sqlmodel.SQLModel.metadata.clear()
  115. class AlembicThing(Model, table=True): # type: ignore
  116. t2: str = "bar"
  117. Model.migrate(autogenerate=True)
  118. assert len(list(versions.glob("*.py"))) == 5
  119. with reflex.model.session() as session:
  120. with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
  121. session.exec(sqlmodel.select(AlembicSecond)).all()
  122. assert errctx.match(r"no such table: alembicsecond")
  123. # first table should still exist
  124. result = session.exec(sqlmodel.select(AlembicThing)).all()
  125. assert len(result) == 2
  126. assert result[0].t2 == "bar"
  127. assert result[1].t2 == "baz"
  128. sqlmodel.SQLModel.metadata.clear()
  129. class AlembicThing(Model, table=True): # type: ignore
  130. # changing column type not supported by default
  131. t2: int = 42
  132. Model.migrate(autogenerate=True)
  133. assert len(list(versions.glob("*.py"))) == 5
  134. # clear all metadata to avoid influencing subsequent tests
  135. sqlmodel.SQLModel.metadata.clear()
  136. # drop remaining tables
  137. Model.migrate(autogenerate=True)
  138. assert len(list(versions.glob("*.py"))) == 6