Bladeren bron

bare sqlalchemy session + tests (#3522)

* add bare sqlalchemy session, Closes #3512

* expose sqla_session at module level, add tests, improve typing

* fix table name

* add model_registry fixture, improve typing

* did not meant to push this

* add docstring to model_registry

* do not expose sqla_session in reflex namespace
benedikt-bartscher 10 maanden geleden
bovenliggende
commit
9d71bcbbb5
5 gewijzigde bestanden met toevoegingen van 221 en 22 verwijderingen
  1. 1 0
      .gitignore
  2. 14 2
      reflex/model.py
  3. 13 1
      tests/conftest.py
  4. 27 19
      tests/test_model.py
  5. 166 0
      tests/test_sqlalchemy.py

+ 1 - 0
.gitignore

@@ -12,3 +12,4 @@ venv
 requirements.txt
 requirements.txt
 .pyi_generator_last_run
 .pyi_generator_last_run
 .pyi_generator_diff
 .pyi_generator_diff
+reflex.db

+ 14 - 2
reflex/model.py

@@ -24,7 +24,7 @@ from reflex.utils import console
 from reflex.utils.compat import sqlmodel
 from reflex.utils.compat import sqlmodel
 
 
 
 
-def get_engine(url: str | None = None):
+def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     """Get the database engine.
     """Get the database engine.
 
 
     Args:
     Args:
@@ -396,7 +396,7 @@ ModelRegistry.register(Model)
 
 
 
 
 def session(url: str | None = None) -> sqlmodel.Session:
 def session(url: str | None = None) -> sqlmodel.Session:
-    """Get a session to interact with the database.
+    """Get a sqlmodel session to interact with the database.
 
 
     Args:
     Args:
         url: The database url.
         url: The database url.
@@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
         A database session.
         A database session.
     """
     """
     return sqlmodel.Session(get_engine(url))
     return sqlmodel.Session(get_engine(url))
+
+
+def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
+    """Get a bare sqlalchemy session to interact with the database.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        A database session.
+    """
+    return sqlalchemy.orm.Session(get_engine(url))

+ 13 - 1
tests/conftest.py

@@ -5,13 +5,14 @@ import os
 import platform
 import platform
 import uuid
 import uuid
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, Generator
+from typing import Dict, Generator, Type
 from unittest import mock
 from unittest import mock
 
 
 import pytest
 import pytest
 
 
 from reflex.app import App
 from reflex.app import App
 from reflex.event import EventSpec
 from reflex.event import EventSpec
+from reflex.model import ModelRegistry
 from reflex.utils import prerequisites
 from reflex.utils import prerequisites
 
 
 from .states import (
 from .states import (
@@ -247,3 +248,14 @@ def token() -> str:
         A fresh/unique token string.
         A fresh/unique token string.
     """
     """
     return str(uuid.uuid4())
     return str(uuid.uuid4())
+
+
+@pytest.fixture
+def model_registry() -> Generator[Type[ModelRegistry], None, None]:
+    """Create a model registry.
+
+    Yields:
+        A fresh model registry.
+    """
+    yield ModelRegistry
+    ModelRegistry._metadata = None

+ 27 - 19
tests/test_model.py

@@ -1,4 +1,5 @@
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Type
 from unittest import mock
 from unittest import mock
 
 
 import pytest
 import pytest
@@ -7,7 +8,7 @@ import sqlmodel
 
 
 import reflex.constants
 import reflex.constants
 import reflex.model
 import reflex.model
-from reflex.model import Model
+from reflex.model import Model, ModelRegistry
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -39,7 +40,7 @@ def model_custom_primary() -> Model:
     return ChildModel(name="name")
     return ChildModel(name="name")
 
 
 
 
-def test_default_primary_key(model_default_primary):
+def test_default_primary_key(model_default_primary: Model):
     """Test that if a primary key is not defined a default is added.
     """Test that if a primary key is not defined a default is added.
 
 
     Args:
     Args:
@@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
     assert "id" in model_default_primary.__class__.__fields__
     assert "id" in model_default_primary.__class__.__fields__
 
 
 
 
-def test_custom_primary_key(model_custom_primary):
+def test_custom_primary_key(model_custom_primary: Model):
     """Test that if a primary key is defined no default key is added.
     """Test that if a primary key is defined no default key is added.
 
 
     Args:
     Args:
@@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
 @pytest.mark.filterwarnings(
 @pytest.mark.filterwarnings(
     "ignore:This declarative base already contains a class with the same class name",
     "ignore:This declarative base already contains a class with the same class name",
 )
 )
-def test_automigration(tmp_working_dir, monkeypatch):
+def test_automigration(
+    tmp_working_dir: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    model_registry: Type[ModelRegistry],
+):
     """Test alembic automigration with add and drop table and column.
     """Test alembic automigration with add and drop table and column.
 
 
     Args:
     Args:
         tmp_working_dir: directory where database and migrations are stored
         tmp_working_dir: directory where database and migrations are stored
         monkeypatch: pytest fixture to overwrite attributes
         monkeypatch: pytest fixture to overwrite attributes
+        model_registry: clean reflex ModelRegistry
     """
     """
     alembic_ini = tmp_working_dir / "alembic.ini"
     alembic_ini = tmp_working_dir / "alembic.ini"
     versions = tmp_working_dir / "alembic" / "versions"
     versions = tmp_working_dir / "alembic" / "versions"
@@ -84,8 +90,10 @@ def test_automigration(tmp_working_dir, monkeypatch):
         t1: str
         t1: str
 
 
     with Model.get_db_engine().connect() as connection:
     with Model.get_db_engine().connect() as connection:
-        Model.alembic_autogenerate(connection=connection, message="Initial Revision")
-    Model.migrate()
+        assert Model.alembic_autogenerate(
+            connection=connection, message="Initial Revision"
+        )
+    assert Model.migrate()
     version_scripts = list(versions.glob("*.py"))
     version_scripts = list(versions.glob("*.py"))
     assert len(version_scripts) == 1
     assert len(version_scripts) == 1
     assert version_scripts[0].name.endswith("initial_revision.py")
     assert version_scripts[0].name.endswith("initial_revision.py")
@@ -94,14 +102,14 @@ def test_automigration(tmp_working_dir, monkeypatch):
         session.add(AlembicThing(id=None, t1="foo"))
         session.add(AlembicThing(id=None, t1="foo"))
         session.commit()
         session.commit()
 
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
 
     # Create column t2, mark t1 as optional with default
     # Create column t2, mark t1 as optional with default
     class AlembicThing(Model, table=True):  # type: ignore
     class AlembicThing(Model, table=True):  # type: ignore
         t1: Optional[str] = "default"
         t1: Optional[str] = "default"
         t2: str = "bar"
         t2: str = "bar"
 
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 2
     assert len(list(versions.glob("*.py"))) == 2
 
 
     with reflex.model.session() as session:
     with reflex.model.session() as session:
@@ -114,13 +122,13 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[1].t1 == "default"
         assert result[1].t1 == "default"
         assert result[1].t2 == "baz"
         assert result[1].t2 == "baz"
 
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
 
     # Drop column t1
     # Drop column t1
     class AlembicThing(Model, table=True):  # type: ignore
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
         t2: str = "bar"
 
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 3
     assert len(list(versions.glob("*.py"))) == 3
 
 
     with reflex.model.session() as session:
     with reflex.model.session() as session:
@@ -134,7 +142,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
         a: int = 42
         a: int = 42
         b: float = 4.2
         b: float = 4.2
 
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
     assert len(list(versions.glob("*.py"))) == 4
 
 
     with reflex.model.session() as session:
     with reflex.model.session() as session:
@@ -146,16 +154,16 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[0].b == 4.2
         assert result[0].b == 4.2
 
 
     # No-op
     # No-op
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
     assert len(list(versions.glob("*.py"))) == 4
 
 
     # drop table (AlembicSecond)
     # drop table (AlembicSecond)
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
 
     class AlembicThing(Model, table=True):  # type: ignore
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
         t2: str = "bar"
 
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
     assert len(list(versions.glob("*.py"))) == 5
 
 
     with reflex.model.session() as session:
     with reflex.model.session() as session:
@@ -168,18 +176,18 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[0].t2 == "bar"
         assert result[0].t2 == "bar"
         assert result[1].t2 == "baz"
         assert result[1].t2 == "baz"
 
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
 
     class AlembicThing(Model, table=True):  # type: ignore
     class AlembicThing(Model, table=True):  # type: ignore
         # changing column type not supported by default
         # changing column type not supported by default
         t2: int = 42
         t2: int = 42
 
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
     assert len(list(versions.glob("*.py"))) == 5
 
 
     # clear all metadata to avoid influencing subsequent tests
     # clear all metadata to avoid influencing subsequent tests
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
 
     # drop remaining tables
     # drop remaining tables
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 6
     assert len(list(versions.glob("*.py"))) == 6

+ 166 - 0
tests/test_sqlalchemy.py

@@ -0,0 +1,166 @@
+from pathlib import Path
+from typing import Optional, Type
+from unittest import mock
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    MappedAsDataclass,
+    declared_attr,
+    mapped_column,
+)
+
+import reflex.constants
+import reflex.model
+from reflex.model import Model, ModelRegistry, sqla_session
+
+
+@pytest.mark.filterwarnings(
+    "ignore:This declarative base already contains a class with the same class name",
+)
+def test_automigration(
+    tmp_working_dir: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    model_registry: Type[ModelRegistry],
+):
+    """Test alembic automigration with add and drop table and column.
+
+    Args:
+        tmp_working_dir: directory where database and migrations are stored
+        monkeypatch: pytest fixture to overwrite attributes
+        model_registry: clean reflex ModelRegistry
+    """
+    alembic_ini = tmp_working_dir / "alembic.ini"
+    versions = tmp_working_dir / "alembic" / "versions"
+    monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
+
+    config_mock = mock.Mock()
+    config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
+    monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
+
+    assert alembic_ini.exists() is False
+    assert versions.exists() is False
+    Model.alembic_init()
+    assert alembic_ini.exists()
+    assert versions.exists()
+
+    class Base(DeclarativeBase):
+        @declared_attr.directive
+        def __tablename__(cls) -> str:
+            return cls.__name__.lower()
+
+    assert model_registry.register(Base)
+
+    class ModelBase(Base, MappedAsDataclass):
+        __abstract__ = True
+        id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
+
+    # initial table
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t1: Mapped[str] = mapped_column(default="")
+
+    with Model.get_db_engine().connect() as connection:
+        assert Model.alembic_autogenerate(
+            connection=connection, message="Initial Revision"
+        )
+    assert Model.migrate()
+    version_scripts = list(versions.glob("*.py"))
+    assert len(version_scripts) == 1
+    assert version_scripts[0].name.endswith("initial_revision.py")
+
+    with sqla_session() as session:
+        session.add(AlembicThing(t1="foo"))
+        session.commit()
+
+    model_registry.get_metadata().clear()
+
+    # Create column t2, mark t1 as optional with default
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t1: Mapped[Optional[str]] = mapped_column(default="default")
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 2
+
+    with sqla_session() as session:
+        session.add(AlembicThing(t2="baz"))
+        session.commit()
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t1 == "foo"
+        assert result[0].t2 == "bar"
+        assert result[1].t1 == "default"
+        assert result[1].t2 == "baz"
+
+    model_registry.get_metadata().clear()
+
+    # Drop column t1
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 3
+
+    with sqla_session() as session:
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t2 == "bar"
+        assert result[1].t2 == "baz"
+
+    # Add table
+    class AlembicSecond(ModelBase):
+        a: Mapped[int] = mapped_column(default=42)
+        b: Mapped[float] = mapped_column(default=4.2)
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 4
+
+    with reflex.model.session() as session:
+        session.add(AlembicSecond(id=None))
+        session.commit()
+        result = session.scalars(select(AlembicSecond)).all()
+        assert len(result) == 1
+        assert result[0].a == 42
+        assert result[0].b == 4.2
+
+    # No-op
+    # assert Model.migrate(autogenerate=True)
+    # assert len(list(versions.glob("*.py"))) == 4
+
+    # drop table (AlembicSecond)
+    model_registry.get_metadata().clear()
+
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 5
+
+    with reflex.model.session() as session:
+        with pytest.raises(OperationalError) as errctx:
+            _ = session.scalars(select(AlembicSecond)).all()
+        assert errctx.match(r"no such table: alembicsecond")
+        # first table should still exist
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t2 == "bar"
+        assert result[1].t2 == "baz"
+
+    model_registry.get_metadata().clear()
+
+    class AlembicThing(ModelBase):
+        # changing column type not supported by default
+        t2: Mapped[int] = mapped_column(default=42)
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 5
+
+    # clear all metadata to avoid influencing subsequent tests
+    model_registry.get_metadata().clear()
+
+    # drop remaining tables
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 6