test_model.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. from unittest import mock
  2. import pytest
  3. import sqlalchemy
  4. import sqlmodel
  5. import reflex.constants
  6. import reflex.model
  7. from reflex.model import Model
  8. @pytest.fixture
  9. def model_default_primary() -> Model:
  10. """Returns a model object with no defined primary key.
  11. Returns:
  12. Model: Model object.
  13. """
  14. class ChildModel(Model):
  15. name: str
  16. return ChildModel(name="name") # type: ignore
  17. @pytest.fixture
  18. def model_custom_primary() -> Model:
  19. """Returns a model object with a custom primary key.
  20. Returns:
  21. Model: Model object.
  22. """
  23. class ChildModel(Model):
  24. custom_id: int = sqlmodel.Field(default=None, primary_key=True)
  25. name: str
  26. return ChildModel(name="name") # type: ignore
  27. def test_default_primary_key(model_default_primary):
  28. """Test that if a primary key is not defined a default is added.
  29. Args:
  30. model_default_primary: Fixture.
  31. """
  32. assert "id" in model_default_primary.__class__.__fields__
  33. def test_custom_primary_key(model_custom_primary):
  34. """Test that if a primary key is defined no default key is added.
  35. Args:
  36. model_custom_primary: Fixture.
  37. """
  38. assert "id" not in model_custom_primary.__class__.__fields__
  39. @pytest.mark.filterwarnings(
  40. "ignore:This declarative base already contains a class with the same class name",
  41. )
  42. def test_automigration(tmp_working_dir, monkeypatch):
  43. """Test alembic automigration with add and drop table and column.
  44. Args:
  45. tmp_working_dir: directory where database and migrations are stored
  46. monkeypatch: pytest fixture to overwrite attributes
  47. """
  48. alembic_ini = tmp_working_dir / "alembic.ini"
  49. versions = tmp_working_dir / "alembic" / "versions"
  50. monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
  51. config_mock = mock.Mock()
  52. config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
  53. monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
  54. Model.alembic_init()
  55. assert alembic_ini.exists()
  56. assert versions.exists()
  57. # initial table
  58. class AlembicThing(Model, table=True): # type: ignore
  59. t1: str
  60. with Model.get_db_engine().connect() as connection:
  61. Model.alembic_autogenerate(connection=connection, message="Initial Revision")
  62. Model.migrate()
  63. version_scripts = list(versions.glob("*.py"))
  64. assert len(version_scripts) == 1
  65. assert version_scripts[0].name.endswith("initial_revision.py")
  66. with reflex.model.session() as session:
  67. session.add(AlembicThing(id=None, t1="foo"))
  68. session.commit()
  69. sqlmodel.SQLModel.metadata.clear()
  70. # Create column t2
  71. class AlembicThing(Model, table=True): # type: ignore
  72. t1: str
  73. t2: str = "bar"
  74. Model.migrate(autogenerate=True)
  75. assert len(list(versions.glob("*.py"))) == 2
  76. with reflex.model.session() as session:
  77. result = session.exec(sqlmodel.select(AlembicThing)).all()
  78. assert len(result) == 1
  79. assert result[0].t2 == "bar"
  80. sqlmodel.SQLModel.metadata.clear()
  81. # Drop column t1
  82. class AlembicThing(Model, table=True): # type: ignore
  83. t2: str = "bar"
  84. Model.migrate(autogenerate=True)
  85. assert len(list(versions.glob("*.py"))) == 3
  86. with reflex.model.session() as session:
  87. result = session.exec(sqlmodel.select(AlembicThing)).all()
  88. assert len(result) == 1
  89. assert result[0].t2 == "bar"
  90. # Add table
  91. class AlembicSecond(Model, table=True): # type: ignore
  92. a: int = 42
  93. b: float = 4.2
  94. Model.migrate(autogenerate=True)
  95. assert len(list(versions.glob("*.py"))) == 4
  96. with reflex.model.session() as session:
  97. session.add(AlembicSecond(id=None))
  98. session.commit()
  99. result = session.exec(sqlmodel.select(AlembicSecond)).all()
  100. assert len(result) == 1
  101. assert result[0].a == 42
  102. assert result[0].b == 4.2
  103. # No-op
  104. Model.migrate(autogenerate=True)
  105. assert len(list(versions.glob("*.py"))) == 4
  106. # drop table (AlembicSecond)
  107. sqlmodel.SQLModel.metadata.clear()
  108. class AlembicThing(Model, table=True): # type: ignore
  109. t2: str = "bar"
  110. Model.migrate(autogenerate=True)
  111. assert len(list(versions.glob("*.py"))) == 5
  112. with reflex.model.session() as session:
  113. with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
  114. session.exec(sqlmodel.select(AlembicSecond)).all()
  115. assert errctx.match(r"no such table: alembicsecond")
  116. # first table should still exist
  117. result = session.exec(sqlmodel.select(AlembicThing)).all()
  118. assert len(result) == 1
  119. assert result[0].t2 == "bar"
  120. sqlmodel.SQLModel.metadata.clear()
  121. class AlembicThing(Model, table=True): # type: ignore
  122. # changing column type not supported by default
  123. t2: int = 42
  124. Model.migrate(autogenerate=True)
  125. assert len(list(versions.glob("*.py"))) == 5
  126. # clear all metadata to avoid influencing subsequent tests
  127. sqlmodel.SQLModel.metadata.clear()
  128. # drop remaining tables
  129. Model.migrate(autogenerate=True)
  130. assert len(list(versions.glob("*.py"))) == 6