Parcourir la source

rx.Model: automigrate using alembic (#1321)

Masen Furer il y a 1 an
Parent
commit
5505d10989
7 fichiers modifiés avec 406 ajouts et 3 suppressions
  1. 60 1
      poetry.lock
  2. 1 0
      pyproject.toml
  3. 1 1
      reflex/app.py
  4. 3 0
      reflex/constants.py
  5. 167 1
      reflex/model.py
  6. 48 0
      tests/conftest.py
  7. 126 0
      tests/test_model.py

+ 60 - 1
poetry.lock

@@ -1,5 +1,26 @@
 # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
+[[package]]
+name = "alembic"
+version = "1.11.1"
+description = "A database migration tool for SQLAlchemy."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "alembic-1.11.1-py3-none-any.whl", hash = "sha256:dc871798a601fab38332e38d6ddb38d5e734f60034baeb8e2db5b642fccd8ab8"},
+    {file = "alembic-1.11.1.tar.gz", hash = "sha256:6a810a6b012c88b33458fceb869aef09ac75d6ace5291915ba7fae44de372c01"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = "*", markers = "python_version < \"3.9\""}
+importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
+Mako = "*"
+SQLAlchemy = ">=1.3.0"
+typing-extensions = ">=4"
+
+[package.extras]
+tz = ["python-dateutil"]
+
 [[package]]
 name = "anyio"
 version = "3.7.1"
@@ -501,6 +522,24 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
 perf = ["ipython"]
 testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
 
+[[package]]
+name = "importlib-resources"
+version = "5.12.0"
+description = "Read resources from Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "importlib_resources-5.12.0-py3-none-any.whl", hash = "sha256:7b1deeebbf351c7578e09bf2f63fa2ce8b5ffec296e0d349139d43cca061a81a"},
+    {file = "importlib_resources-5.12.0.tar.gz", hash = "sha256:4be82589bf5c1d7999aedf2a45159d10cb3ca4f19b2271f8792bc8e6da7b22f6"},
+]
+
+[package.dependencies]
+zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -529,6 +568,26 @@ MarkupSafe = ">=2.0"
 [package.extras]
 i18n = ["Babel (>=2.7)"]
 
+[[package]]
+name = "mako"
+version = "1.2.4"
+description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"},
+    {file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+MarkupSafe = ">=0.9.2"
+
+[package.extras]
+babel = ["Babel"]
+lingua = ["lingua"]
+testing = ["pytest"]
+
 [[package]]
 name = "markdown-it-py"
 version = "2.2.0"
@@ -1854,4 +1913,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.7"
-content-hash = "b07623f193778651dd3ece81370d2d9a2ba3b53855a57baf0147e67bf07aaed8"
+content-hash = "cc659e46041316bc81ce1758334c6fa9ccf9812612ef67e170b173d6d1caa2b2"

+ 1 - 0
pyproject.toml

@@ -62,6 +62,7 @@ pandas = [
 ]
 asynctest = "^0.13.0"
 pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
+alembic = "^1.11.1"
 
 [tool.poetry.scripts]
 reflex = "reflex.reflex:main"

+ 1 - 1
reflex/app.py

@@ -444,7 +444,7 @@ class App(Base):
         config = get_config()
 
         # Update models during hot reload.
-        if config.db_url is not None:
+        if config.db_url is not None and not Model.automigrate():
             Model.create_all()
 
         # Empty the .web pages directory

+ 3 - 0
reflex/constants.py

@@ -353,3 +353,6 @@ TOGGLE_COLOR_MODE = "toggleColorMode"
 # Server socket configuration variables
 CORS_ALLOWED_ORIGINS = get_value("CORS_ALLOWED_ORIGINS", ["*"], list)
 POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000
+
+# Alembic migrations
+ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")

+ 167 - 1
reflex/model.py

@@ -1,12 +1,30 @@
 """Database built into Reflex."""
 
-from typing import Optional
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Optional
 
+import sqlalchemy
 import sqlmodel
 
 from reflex.base import Base
 from reflex.config import get_config
 
+from . import constants
+
+try:
+    import alembic.autogenerate  # pyright: ignore [reportMissingImports]
+    import alembic.command  # pyright: ignore [reportMissingImports]
+    import alembic.operations.ops  # pyright: ignore [reportMissingImports]
+    import alembic.runtime.environment  # pyright: ignore [reportMissingImports]
+    import alembic.script  # pyright: ignore [reportMissingImports]
+    import alembic.util  # pyright: ignore [reportMissingImports]
+    from alembic.config import Config  # pyright: ignore [reportMissingImports]
+
+    has_alembic = True
+except ImportError:
+    has_alembic = False
+
 
 def get_engine(url: Optional[str] = None):
     """Get the database engine.
@@ -75,6 +93,154 @@ class Model(Base, sqlmodel.SQLModel):
         """
         return get_engine()
 
+    @staticmethod
+    def _alembic_config():
+        """Get the alembic configuration and script_directory.
+
+        Returns:
+            tuple of (config, script_directory)
+        """
+        config = Config(constants.ALEMBIC_CONFIG)
+        return config, alembic.script.ScriptDirectory(
+            config.get_main_option("script_location", default="version"),
+        )
+
+    @staticmethod
+    def _alembic_render_item(
+        type_: str,
+        obj: Any,
+        autogen_context: "alembic.autogenerate.api.AutogenContext",
+    ):
+        """Alembic render_item hook call.
+
+        This method is called to provide python code for the given obj,
+        but currently it is only used to add `sqlmodel` to the import list
+        when generating migration scripts.
+
+        See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
+
+        Args:
+            type_: one of "schema", "table", "column", "index",
+                "unique_constraint", or "foreign_key_constraint"
+            obj: the object being rendered
+            autogen_context: shared AutogenContext passed to each render_item call
+
+        Returns:
+            False - indicating that the default rendering should be used.
+        """
+        autogen_context.imports.add("import sqlmodel")
+        return False
+
+    @classmethod
+    def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
+        """Generate migration scripts for alembic-detectable changes.
+
+        Args:
+            connection: sqlalchemy connection to use when detecting changes
+
+        Returns:
+            True when changes have been detected.
+        """
+        config, script_directory = cls._alembic_config()
+        revision_context = alembic.autogenerate.api.RevisionContext(
+            config=config,
+            script_directory=script_directory,
+            command_args=defaultdict(
+                lambda: None,
+                autogenerate=True,
+                head="head",
+            ),
+        )
+        writer = alembic.autogenerate.rewriter.Rewriter()
+
+        @writer.rewrites(alembic.operations.ops.AddColumnOp)
+        def render_add_column_with_server_default(context, revision, op):
+            # Carry the sqlmodel default as server_default so that newly added
+            # columns get the desired default value in existing rows
+            if op.column.default is not None and op.column.server_default is None:
+                op.column.server_default = sqlalchemy.DefaultClause(
+                    sqlalchemy.sql.expression.literal(op.column.default.arg),
+                )
+            return op
+
+        def run_autogenerate(rev, context):
+            revision_context.run_autogenerate(rev, context)
+            return []
+
+        with alembic.runtime.environment.EnvironmentContext(
+            config=config,
+            script=script_directory,
+            fn=run_autogenerate,
+        ) as env:
+            env.configure(
+                connection=connection,
+                target_metadata=sqlmodel.SQLModel.metadata,
+                render_item=cls._alembic_render_item,
+                process_revision_directives=writer,  # type: ignore
+            )
+            env.run_migrations()
+        changes_detected = False
+        if revision_context.generated_revisions:
+            upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
+            if upgrade_ops is not None:
+                changes_detected = bool(upgrade_ops.ops)
+        if changes_detected:
+            for _script in revision_context.generate_scripts():
+                pass  # must iterate to actually generate the scripts
+        return changes_detected
+
+    @classmethod
+    def _alembic_upgrade(
+        cls,
+        connection: sqlalchemy.engine.Connection,
+        to_rev: str = "head",
+    ) -> None:
+        """Apply alembic migrations up to the given revision.
+
+        Args:
+            connection: sqlalchemy connection to use when performing upgrade
+            to_rev: revision to migrate towards
+        """
+        config, script_directory = cls._alembic_config()
+
+        def run_upgrade(rev, context):
+            return script_directory._upgrade_revs(to_rev, rev)
+
+        # apply updates to database
+        with alembic.runtime.environment.EnvironmentContext(
+            config=config,
+            script=script_directory,
+            fn=run_upgrade,
+        ) as env:
+            env.configure(connection=connection)
+            env.run_migrations()
+
+    @classmethod
+    def automigrate(cls) -> Optional[bool]:
+        """Generate and execute migrations for all sqlmodel Model classes.
+
+        If alembic is not installed or has not been initialized for the project,
+        then no action is performed.
+
+        If models in the app have changed in incompatible ways that alembic
+        cannot automatically generate revisions for, the app may not be able to
+        start up until migration scripts have been corrected by hand.
+
+        Returns:
+            True - indicating the process was successful
+            None - indicating the process was skipped
+        """
+        if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
+            return
+
+        with cls.get_db_engine().connect() as connection:
+            cls._alembic_upgrade(connection=connection)
+            changes_detected = cls._alembic_autogenerate(connection=connection)
+            if changes_detected:
+                cls._alembic_upgrade(connection=connection)
+            connection.commit()
+        return True
+
     @classmethod
     @property
     def select(cls):

+ 48 - 0
tests/conftest.py

@@ -1,5 +1,8 @@
 """Test fixtures."""
+import contextlib
+import os
 import platform
+from pathlib import Path
 from typing import Dict, Generator, List
 
 import pytest
@@ -479,3 +482,48 @@ def router_data(router_data_headers) -> Dict[str, str]:
         "headers": router_data_headers,
         "ip": "127.0.0.1",
     }
+
+
+# borrowed from py3.11
+class chdir(contextlib.AbstractContextManager):
+    """Non thread-safe context manager to change the current working directory."""
+
+    def __init__(self, path):
+        """Prepare contextmanager.
+
+        Args:
+            path: the path to change to
+        """
+        self.path = path
+        self._old_cwd = []
+
+    def __enter__(self):
+        """Save current directory and perform chdir."""
+        self._old_cwd.append(Path(".").resolve())
+        os.chdir(self.path)
+
+    def __exit__(self, *excinfo):
+        """Change back to previous directory on stack.
+
+        Args:
+            excinfo: sys.exc_info captured in the context block
+        """
+        os.chdir(self._old_cwd.pop())
+
+
+@pytest.fixture
+def tmp_working_dir(tmp_path):
+    """Create a temporary directory and chdir to it.
+
+    After the test executes, chdir back to the original working directory.
+
+    Args:
+        tmp_path: pytest tmp_path fixture creates per-test temp dir
+
+    Yields:
+        subdirectory of tmp_path which is now the current working directory.
+    """
+    working_dir = tmp_path / "working_dir"
+    working_dir.mkdir()
+    with chdir(working_dir):
+        yield working_dir

+ 126 - 0
tests/test_model.py

@@ -1,6 +1,13 @@
+import subprocess
+import sys
+from unittest import mock
+
 import pytest
+import sqlalchemy
 import sqlmodel
 
+import reflex.constants
+import reflex.model
 from reflex.model import Model
 
 
@@ -49,3 +56,122 @@ def test_custom_primary_key(model_custom_primary):
         model_custom_primary: Fixture.
     """
     assert "id" not in model_custom_primary.__class__.__fields__
+
+
+@pytest.mark.filterwarnings(
+    "ignore:This declarative base already contains a class with the same class name",
+)
+def test_automigration(tmp_working_dir, monkeypatch):
+    """Test alembic automigration with add and drop table and column.
+
+    Args:
+        tmp_working_dir: directory where database and migrations are stored
+        monkeypatch: pytest fixture to overwrite attributes
+    """
+    subprocess.run(
+        [sys.executable, "-m", "alembic", "init", "alembic"],
+        cwd=tmp_working_dir,
+    )
+    alembic_ini = tmp_working_dir / "alembic.ini"
+    versions = tmp_working_dir / "alembic" / "versions"
+    assert alembic_ini.exists()
+    assert versions.exists()
+
+    config_mock = mock.Mock()
+    config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
+    monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
+    monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
+
+    # initial table
+    class AlembicThing(Model, table=True):  # type: ignore
+        t1: str
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 1
+
+    with reflex.model.session() as session:
+        session.add(AlembicThing(id=None, t1="foo"))
+        session.commit()
+
+    sqlmodel.SQLModel.metadata.clear()
+
+    # Create column t2
+    class AlembicThing(Model, table=True):  # type: ignore
+        t1: str
+        t2: str = "bar"
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 2
+
+    with reflex.model.session() as session:
+        result = session.exec(sqlmodel.select(AlembicThing)).all()
+        assert len(result) == 1
+        assert result[0].t2 == "bar"
+
+    sqlmodel.SQLModel.metadata.clear()
+
+    # Drop column t1
+    class AlembicThing(Model, table=True):  # type: ignore
+        t2: str = "bar"
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 3
+
+    with reflex.model.session() as session:
+        result = session.exec(sqlmodel.select(AlembicThing)).all()
+        assert len(result) == 1
+        assert result[0].t2 == "bar"
+
+    # Add table
+    class AlembicSecond(Model, table=True):  # type: ignore
+        a: int = 42
+        b: float = 4.2
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 4
+
+    with reflex.model.session() as session:
+        session.add(AlembicSecond(id=None))
+        session.commit()
+        result = session.exec(sqlmodel.select(AlembicSecond)).all()
+        assert len(result) == 1
+        assert result[0].a == 42
+        assert result[0].b == 4.2
+
+    # No-op
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 4
+
+    # drop table (AlembicSecond)
+    sqlmodel.SQLModel.metadata.clear()
+
+    class AlembicThing(Model, table=True):  # type: ignore
+        t2: str = "bar"
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 5
+
+    with reflex.model.session() as session:
+        with pytest.raises(sqlalchemy.exc.OperationalError) as errctx:  # type: ignore
+            session.exec(sqlmodel.select(AlembicSecond)).all()
+        assert errctx.match(r"no such table: alembicsecond")
+        # first table should still exist
+        result = session.exec(sqlmodel.select(AlembicThing)).all()
+        assert len(result) == 1
+        assert result[0].t2 == "bar"
+
+    sqlmodel.SQLModel.metadata.clear()
+
+    class AlembicThing(Model, table=True):  # type: ignore
+        # changing column type not supported by default
+        t2: int = 42
+
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 5
+
+    # clear all metadata to avoid influencing subsequent tests
+    sqlmodel.SQLModel.metadata.clear()
+
+    # drop remaining tables
+    Model.automigrate()
+    assert len(list(versions.glob("*.py"))) == 6