|
@@ -1,4 +1,5 @@
|
|
-from typing import Optional
|
|
|
|
|
|
+from pathlib import Path
|
|
|
|
+from typing import Optional, Type
|
|
from unittest import mock
|
|
from unittest import mock
|
|
|
|
|
|
import pytest
|
|
import pytest
|
|
@@ -7,7 +8,7 @@ import sqlmodel
|
|
|
|
|
|
import reflex.constants
|
|
import reflex.constants
|
|
import reflex.model
|
|
import reflex.model
|
|
-from reflex.model import Model
|
|
|
|
|
|
+from reflex.model import Model, ModelRegistry
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
@pytest.fixture
|
|
@@ -39,7 +40,7 @@ def model_custom_primary() -> Model:
|
|
return ChildModel(name="name")
|
|
return ChildModel(name="name")
|
|
|
|
|
|
|
|
|
|
-def test_default_primary_key(model_default_primary):
|
|
|
|
|
|
+def test_default_primary_key(model_default_primary: Model):
|
|
"""Test that if a primary key is not defined a default is added.
|
|
"""Test that if a primary key is not defined a default is added.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
|
|
assert "id" in model_default_primary.__class__.__fields__
|
|
assert "id" in model_default_primary.__class__.__fields__
|
|
|
|
|
|
|
|
|
|
-def test_custom_primary_key(model_custom_primary):
|
|
|
|
|
|
+def test_custom_primary_key(model_custom_primary: Model):
|
|
"""Test that if a primary key is defined no default key is added.
|
|
"""Test that if a primary key is defined no default key is added.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
|
|
@pytest.mark.filterwarnings(
|
|
@pytest.mark.filterwarnings(
|
|
"ignore:This declarative base already contains a class with the same class name",
|
|
"ignore:This declarative base already contains a class with the same class name",
|
|
)
|
|
)
|
|
-def test_automigration(tmp_working_dir, monkeypatch):
|
|
|
|
|
|
+def test_automigration(
|
|
|
|
+ tmp_working_dir: Path,
|
|
|
|
+ monkeypatch: pytest.MonkeyPatch,
|
|
|
|
+ model_registry: Type[ModelRegistry],
|
|
|
|
+):
|
|
"""Test alembic automigration with add and drop table and column.
|
|
"""Test alembic automigration with add and drop table and column.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
tmp_working_dir: directory where database and migrations are stored
|
|
tmp_working_dir: directory where database and migrations are stored
|
|
monkeypatch: pytest fixture to overwrite attributes
|
|
monkeypatch: pytest fixture to overwrite attributes
|
|
|
|
+ model_registry: clean reflex ModelRegistry
|
|
"""
|
|
"""
|
|
alembic_ini = tmp_working_dir / "alembic.ini"
|
|
alembic_ini = tmp_working_dir / "alembic.ini"
|
|
versions = tmp_working_dir / "alembic" / "versions"
|
|
versions = tmp_working_dir / "alembic" / "versions"
|
|
@@ -84,8 +90,10 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
t1: str
|
|
t1: str
|
|
|
|
|
|
with Model.get_db_engine().connect() as connection:
|
|
with Model.get_db_engine().connect() as connection:
|
|
- Model.alembic_autogenerate(connection=connection, message="Initial Revision")
|
|
|
|
- Model.migrate()
|
|
|
|
|
|
+ assert Model.alembic_autogenerate(
|
|
|
|
+ connection=connection, message="Initial Revision"
|
|
|
|
+ )
|
|
|
|
+ assert Model.migrate()
|
|
version_scripts = list(versions.glob("*.py"))
|
|
version_scripts = list(versions.glob("*.py"))
|
|
assert len(version_scripts) == 1
|
|
assert len(version_scripts) == 1
|
|
assert version_scripts[0].name.endswith("initial_revision.py")
|
|
assert version_scripts[0].name.endswith("initial_revision.py")
|
|
@@ -94,14 +102,14 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
session.add(AlembicThing(id=None, t1="foo"))
|
|
session.add(AlembicThing(id=None, t1="foo"))
|
|
session.commit()
|
|
session.commit()
|
|
|
|
|
|
- sqlmodel.SQLModel.metadata.clear()
|
|
|
|
|
|
+ model_registry.get_metadata().clear()
|
|
|
|
|
|
# Create column t2, mark t1 as optional with default
|
|
# Create column t2, mark t1 as optional with default
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
t1: Optional[str] = "default"
|
|
t1: Optional[str] = "default"
|
|
t2: str = "bar"
|
|
t2: str = "bar"
|
|
|
|
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 2
|
|
assert len(list(versions.glob("*.py"))) == 2
|
|
|
|
|
|
with reflex.model.session() as session:
|
|
with reflex.model.session() as session:
|
|
@@ -114,13 +122,13 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
assert result[1].t1 == "default"
|
|
assert result[1].t1 == "default"
|
|
assert result[1].t2 == "baz"
|
|
assert result[1].t2 == "baz"
|
|
|
|
|
|
- sqlmodel.SQLModel.metadata.clear()
|
|
|
|
|
|
+ model_registry.get_metadata().clear()
|
|
|
|
|
|
# Drop column t1
|
|
# Drop column t1
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
t2: str = "bar"
|
|
t2: str = "bar"
|
|
|
|
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 3
|
|
assert len(list(versions.glob("*.py"))) == 3
|
|
|
|
|
|
with reflex.model.session() as session:
|
|
with reflex.model.session() as session:
|
|
@@ -134,7 +142,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
a: int = 42
|
|
a: int = 42
|
|
b: float = 4.2
|
|
b: float = 4.2
|
|
|
|
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 4
|
|
assert len(list(versions.glob("*.py"))) == 4
|
|
|
|
|
|
with reflex.model.session() as session:
|
|
with reflex.model.session() as session:
|
|
@@ -146,16 +154,16 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
assert result[0].b == 4.2
|
|
assert result[0].b == 4.2
|
|
|
|
|
|
# No-op
|
|
# No-op
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 4
|
|
assert len(list(versions.glob("*.py"))) == 4
|
|
|
|
|
|
# drop table (AlembicSecond)
|
|
# drop table (AlembicSecond)
|
|
- sqlmodel.SQLModel.metadata.clear()
|
|
|
|
|
|
+ model_registry.get_metadata().clear()
|
|
|
|
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
t2: str = "bar"
|
|
t2: str = "bar"
|
|
|
|
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 5
|
|
assert len(list(versions.glob("*.py"))) == 5
|
|
|
|
|
|
with reflex.model.session() as session:
|
|
with reflex.model.session() as session:
|
|
@@ -168,18 +176,18 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|
assert result[0].t2 == "bar"
|
|
assert result[0].t2 == "bar"
|
|
assert result[1].t2 == "baz"
|
|
assert result[1].t2 == "baz"
|
|
|
|
|
|
- sqlmodel.SQLModel.metadata.clear()
|
|
|
|
|
|
+ model_registry.get_metadata().clear()
|
|
|
|
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
class AlembicThing(Model, table=True): # type: ignore
|
|
# changing column type not supported by default
|
|
# changing column type not supported by default
|
|
t2: int = 42
|
|
t2: int = 42
|
|
|
|
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 5
|
|
assert len(list(versions.glob("*.py"))) == 5
|
|
|
|
|
|
# clear all metadata to avoid influencing subsequent tests
|
|
# clear all metadata to avoid influencing subsequent tests
|
|
- sqlmodel.SQLModel.metadata.clear()
|
|
|
|
|
|
+ model_registry.get_metadata().clear()
|
|
|
|
|
|
# drop remaining tables
|
|
# drop remaining tables
|
|
- Model.migrate(autogenerate=True)
|
|
|
|
|
|
+ assert Model.migrate(autogenerate=True)
|
|
assert len(list(versions.glob("*.py"))) == 6
|
|
assert len(list(versions.glob("*.py"))) == 6
|