Explorar o código

bare sqlalchemy session + tests (#3522)

* add bare sqlalchemy session, Closes #3512

* expose sqla_session at module level, add tests, improve typing

* fix table name

* add model_registry fixture, improve typing

* did not meant to push this

* add docstring to model_registry

* do not expose sqla_session in reflex namespace
benedikt-bartscher hai 10 meses
pai
achega
9d71bcbbb5
Modificáronse 5 ficheiros con 221 adicións e 22 borrados
  1. 1 0
      .gitignore
  2. 14 2
      reflex/model.py
  3. 13 1
      tests/conftest.py
  4. 27 19
      tests/test_model.py
  5. 166 0
      tests/test_sqlalchemy.py

+ 1 - 0
.gitignore

@@ -12,3 +12,4 @@ venv
 requirements.txt
 .pyi_generator_last_run
 .pyi_generator_diff
+reflex.db

+ 14 - 2
reflex/model.py

@@ -24,7 +24,7 @@ from reflex.utils import console
 from reflex.utils.compat import sqlmodel
 
 
-def get_engine(url: str | None = None):
+def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     """Get the database engine.
 
     Args:
@@ -396,7 +396,7 @@ ModelRegistry.register(Model)
 
 
 def session(url: str | None = None) -> sqlmodel.Session:
-    """Get a session to interact with the database.
+    """Get a sqlmodel session to interact with the database.
 
     Args:
         url: The database url.
@@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
         A database session.
     """
     return sqlmodel.Session(get_engine(url))
+
+
+def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
+    """Get a bare sqlalchemy session to interact with the database.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        A database session.
+    """
+    return sqlalchemy.orm.Session(get_engine(url))

+ 13 - 1
tests/conftest.py

@@ -5,13 +5,14 @@ import os
 import platform
 import uuid
 from pathlib import Path
-from typing import Dict, Generator
+from typing import Dict, Generator, Type
 from unittest import mock
 
 import pytest
 
 from reflex.app import App
 from reflex.event import EventSpec
+from reflex.model import ModelRegistry
 from reflex.utils import prerequisites
 
 from .states import (
@@ -247,3 +248,14 @@ def token() -> str:
         A fresh/unique token string.
     """
     return str(uuid.uuid4())
+
+
+@pytest.fixture
+def model_registry() -> Generator[Type[ModelRegistry], None, None]:
+    """Create a model registry.
+
+    Yields:
+        A fresh model registry.
+    """
+    yield ModelRegistry
+    ModelRegistry._metadata = None

+ 27 - 19
tests/test_model.py

@@ -1,4 +1,5 @@
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Type
 from unittest import mock
 
 import pytest
@@ -7,7 +8,7 @@ import sqlmodel
 
 import reflex.constants
 import reflex.model
-from reflex.model import Model
+from reflex.model import Model, ModelRegistry
 
 
 @pytest.fixture
@@ -39,7 +40,7 @@ def model_custom_primary() -> Model:
     return ChildModel(name="name")
 
 
-def test_default_primary_key(model_default_primary):
+def test_default_primary_key(model_default_primary: Model):
     """Test that if a primary key is not defined a default is added.
 
     Args:
@@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
     assert "id" in model_default_primary.__class__.__fields__
 
 
-def test_custom_primary_key(model_custom_primary):
+def test_custom_primary_key(model_custom_primary: Model):
     """Test that if a primary key is defined no default key is added.
 
     Args:
@@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
 @pytest.mark.filterwarnings(
     "ignore:This declarative base already contains a class with the same class name",
 )
-def test_automigration(tmp_working_dir, monkeypatch):
+def test_automigration(
+    tmp_working_dir: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    model_registry: Type[ModelRegistry],
+):
     """Test alembic automigration with add and drop table and column.
 
     Args:
         tmp_working_dir: directory where database and migrations are stored
         monkeypatch: pytest fixture to overwrite attributes
+        model_registry: clean reflex ModelRegistry
     """
     alembic_ini = tmp_working_dir / "alembic.ini"
     versions = tmp_working_dir / "alembic" / "versions"
@@ -84,8 +90,10 @@ def test_automigration(tmp_working_dir, monkeypatch):
         t1: str
 
     with Model.get_db_engine().connect() as connection:
-        Model.alembic_autogenerate(connection=connection, message="Initial Revision")
-    Model.migrate()
+        assert Model.alembic_autogenerate(
+            connection=connection, message="Initial Revision"
+        )
+    assert Model.migrate()
     version_scripts = list(versions.glob("*.py"))
     assert len(version_scripts) == 1
     assert version_scripts[0].name.endswith("initial_revision.py")
@@ -94,14 +102,14 @@ def test_automigration(tmp_working_dir, monkeypatch):
         session.add(AlembicThing(id=None, t1="foo"))
         session.commit()
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
     # Create column t2, mark t1 as optional with default
     class AlembicThing(Model, table=True):  # type: ignore
         t1: Optional[str] = "default"
         t2: str = "bar"
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 2
 
     with reflex.model.session() as session:
@@ -114,13 +122,13 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[1].t1 == "default"
         assert result[1].t2 == "baz"
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
     # Drop column t1
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 3
 
     with reflex.model.session() as session:
@@ -134,7 +142,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
         a: int = 42
         b: float = 4.2
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
 
     with reflex.model.session() as session:
@@ -146,16 +154,16 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[0].b == 4.2
 
     # No-op
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
 
     # drop table (AlembicSecond)
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
 
     with reflex.model.session() as session:
@@ -168,18 +176,18 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[0].t2 == "bar"
         assert result[1].t2 == "baz"
 
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
     class AlembicThing(Model, table=True):  # type: ignore
         # changing column type not supported by default
         t2: int = 42
 
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
 
     # clear all metadata to avoid influencing subsequent tests
-    sqlmodel.SQLModel.metadata.clear()
+    model_registry.get_metadata().clear()
 
     # drop remaining tables
-    Model.migrate(autogenerate=True)
+    assert Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 6

+ 166 - 0
tests/test_sqlalchemy.py

@@ -0,0 +1,166 @@
+from pathlib import Path
+from typing import Optional, Type
+from unittest import mock
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    MappedAsDataclass,
+    declared_attr,
+    mapped_column,
+)
+
+import reflex.constants
+import reflex.model
+from reflex.model import Model, ModelRegistry, sqla_session
+
+
+@pytest.mark.filterwarnings(
+    "ignore:This declarative base already contains a class with the same class name",
+)
+def test_automigration(
+    tmp_working_dir: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    model_registry: Type[ModelRegistry],
+):
+    """Test alembic automigration with add and drop table and column.
+
+    Args:
+        tmp_working_dir: directory where database and migrations are stored
+        monkeypatch: pytest fixture to overwrite attributes
+        model_registry: clean reflex ModelRegistry
+    """
+    alembic_ini = tmp_working_dir / "alembic.ini"
+    versions = tmp_working_dir / "alembic" / "versions"
+    monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
+
+    config_mock = mock.Mock()
+    config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
+    monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
+
+    assert alembic_ini.exists() is False
+    assert versions.exists() is False
+    Model.alembic_init()
+    assert alembic_ini.exists()
+    assert versions.exists()
+
+    class Base(DeclarativeBase):
+        @declared_attr.directive
+        def __tablename__(cls) -> str:
+            return cls.__name__.lower()
+
+    assert model_registry.register(Base)
+
+    class ModelBase(Base, MappedAsDataclass):
+        __abstract__ = True
+        id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
+
+    # initial table
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t1: Mapped[str] = mapped_column(default="")
+
+    with Model.get_db_engine().connect() as connection:
+        assert Model.alembic_autogenerate(
+            connection=connection, message="Initial Revision"
+        )
+    assert Model.migrate()
+    version_scripts = list(versions.glob("*.py"))
+    assert len(version_scripts) == 1
+    assert version_scripts[0].name.endswith("initial_revision.py")
+
+    with sqla_session() as session:
+        session.add(AlembicThing(t1="foo"))
+        session.commit()
+
+    model_registry.get_metadata().clear()
+
+    # Create column t2, mark t1 as optional with default
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t1: Mapped[Optional[str]] = mapped_column(default="default")
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 2
+
+    with sqla_session() as session:
+        session.add(AlembicThing(t2="baz"))
+        session.commit()
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t1 == "foo"
+        assert result[0].t2 == "bar"
+        assert result[1].t1 == "default"
+        assert result[1].t2 == "baz"
+
+    model_registry.get_metadata().clear()
+
+    # Drop column t1
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 3
+
+    with sqla_session() as session:
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t2 == "bar"
+        assert result[1].t2 == "baz"
+
+    # Add table
+    class AlembicSecond(ModelBase):
+        a: Mapped[int] = mapped_column(default=42)
+        b: Mapped[float] = mapped_column(default=4.2)
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 4
+
+    with reflex.model.session() as session:
+        session.add(AlembicSecond(id=None))
+        session.commit()
+        result = session.scalars(select(AlembicSecond)).all()
+        assert len(result) == 1
+        assert result[0].a == 42
+        assert result[0].b == 4.2
+
+    # No-op
+    # assert Model.migrate(autogenerate=True)
+    # assert len(list(versions.glob("*.py"))) == 4
+
+    # drop table (AlembicSecond)
+    model_registry.get_metadata().clear()
+
+    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+        t2: Mapped[str] = mapped_column(default="bar")
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 5
+
+    with reflex.model.session() as session:
+        with pytest.raises(OperationalError) as errctx:
+            _ = session.scalars(select(AlembicSecond)).all()
+        assert errctx.match(r"no such table: alembicsecond")
+        # first table should still exist
+        result = session.scalars(select(AlembicThing)).all()
+        assert len(result) == 2
+        assert result[0].t2 == "bar"
+        assert result[1].t2 == "baz"
+
+    model_registry.get_metadata().clear()
+
+    class AlembicThing(ModelBase):
+        # changing column type not supported by default
+        t2: Mapped[int] = mapped_column(default=42)
+
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 5
+
+    # clear all metadata to avoid influencing subsequent tests
+    model_registry.get_metadata().clear()
+
+    # drop remaining tables
+    assert Model.migrate(autogenerate=True)
+    assert len(list(versions.glob("*.py"))) == 6