Bläddra i källkod

reflex db migrate CLI and associated config (#1336)

Masen Furer 1 år sedan
förälder
incheckning
4a661a5395
9 ändrade filer med 172 tillägg och 80 borttagningar
  1. 1 1
      poetry.lock
  2. 1 1
      pyproject.toml
  3. 0 4
      reflex/app.py
  4. 4 1
      reflex/compiler/utils.py
  5. 62 38
      reflex/model.py
  6. 49 2
      reflex/reflex.py
  7. 0 14
      reflex/utils/build.py
  8. 37 1
      reflex/utils/prerequisites.py
  9. 18 18
      tests/test_model.py

+ 1 - 1
poetry.lock

@@ -2128,4 +2128,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.7"
-content-hash = "f4586d218b5320f0b595db276e8426db6bfffb406e8108b8a1bd9e785b6407c4"
+content-hash = "ac27016107e8a033aa39d9a712d3ef685132e22ede599a26214b17da6ff35829"

+ 1 - 1
pyproject.toml

@@ -45,6 +45,7 @@ websockets = "^10.4"
 starlette-admin = "^0.9.0"
 python-dotenv = "^0.13.0"
 importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"}
+alembic = "^1.11.1"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.1.2"
@@ -62,7 +63,6 @@ pandas = [
 ]
 asynctest = "^0.13.0"
 pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
-alembic = "^1.11.1"
 selenium = "^4.10.0"
 
 [tool.poetry.scripts]

+ 0 - 4
reflex/app.py

@@ -452,10 +452,6 @@ class App(Base):
         # Get the env mode.
         config = get_config()
 
-        # Update models during hot reload.
-        if config.db_url is not None and not Model.automigrate():
-            Model.create_all()
-
         # Empty the .web pages directory
         compiler.purge_web_pages_dir()
 

+ 4 - 1
reflex/compiler/utils.py

@@ -123,7 +123,10 @@ def compile_state(state: Type[State]) -> Dict:
     Returns:
         A dictionary of the compiled state.
     """
-    initial_state = state().dict()
+    try:
+        initial_state = state().dict()
+    except Exception:
+        initial_state = state().dict(include_computed=False)
     initial_state.update(
         {
             "events": [{"name": get_hydrate_event(state)}],

+ 62 - 38
reflex/model.py

@@ -4,26 +4,20 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Any, Optional
 
+import alembic.autogenerate
+import alembic.command
+import alembic.config
+import alembic.operations.ops
+import alembic.runtime.environment
+import alembic.script
+import alembic.util
 import sqlalchemy
 import sqlmodel
 
+from reflex import constants
 from reflex.base import Base
 from reflex.config import get_config
-
-from . import constants
-
-try:
-    import alembic.autogenerate  # pyright: ignore [reportMissingImports]
-    import alembic.command  # pyright: ignore [reportMissingImports]
-    import alembic.operations.ops  # pyright: ignore [reportMissingImports]
-    import alembic.runtime.environment  # pyright: ignore [reportMissingImports]
-    import alembic.script  # pyright: ignore [reportMissingImports]
-    import alembic.util  # pyright: ignore [reportMissingImports]
-    from alembic.config import Config  # pyright: ignore [reportMissingImports]
-
-    has_alembic = True
-except ImportError:
-    has_alembic = False
+from reflex.utils import console
 
 
 def get_engine(url: Optional[str] = None):
@@ -42,6 +36,10 @@ def get_engine(url: Optional[str] = None):
     url = url or conf.db_url
     if url is None:
         raise ValueError("No database url configured")
+    if not Path(constants.ALEMBIC_CONFIG).exists():
+        console.print(
+            "[red]Database is not initialized, run [bold]reflex db init[/bold] first."
+        )
     return sqlmodel.create_engine(
         url,
         echo=False,
@@ -100,7 +98,7 @@ class Model(Base, sqlmodel.SQLModel):
         Returns:
             tuple of (config, script_directory)
         """
-        config = Config(constants.ALEMBIC_CONFIG)
+        config = alembic.config.Config(constants.ALEMBIC_CONFIG)
         return config, alembic.script.ScriptDirectory(
             config.get_main_option("script_location", default="version"),
         )
@@ -120,27 +118,45 @@ class Model(Base, sqlmodel.SQLModel):
         See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
 
         Args:
-            type_: one of "schema", "table", "column", "index",
-                "unique_constraint", or "foreign_key_constraint"
-            obj: the object being rendered
-            autogen_context: shared AutogenContext passed to each render_item call
+            type_: One of "schema", "table", "column", "index",
+                "unique_constraint", or "foreign_key_constraint".
+            obj: The object being rendered.
+            autogen_context: Shared AutogenContext passed to each render_item call.
 
         Returns:
-            False - indicating that the default rendering should be used.
+            False - Indicating that the default rendering should be used.
         """
         autogen_context.imports.add("import sqlmodel")
         return False
 
     @classmethod
-    def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
+    def alembic_init(cls):
+        """Initialize alembic for the project."""
+        alembic.command.init(
+            config=alembic.config.Config(constants.ALEMBIC_CONFIG),
+            directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"),
+        )
+
+    @classmethod
+    def alembic_autogenerate(
+        cls,
+        connection: sqlalchemy.engine.Connection,
+        message: Optional[str] = None,
+        write_migration_scripts: bool = True,
+    ) -> bool:
         """Generate migration scripts for alembic-detectable changes.
 
         Args:
-            connection: sqlalchemy connection to use when detecting changes
+            connection: SQLAlchemy connection to use when detecting changes.
+            message: Human readable identifier describing the generated revision.
+            write_migration_scripts: If True, write autogenerated revisions to script directory.
 
         Returns:
             True when changes have been detected.
         """
+        if not Path(constants.ALEMBIC_CONFIG).exists():
+            return False
+
         config, script_directory = cls._alembic_config()
         revision_context = alembic.autogenerate.api.RevisionContext(
             config=config,
@@ -149,6 +165,7 @@ class Model(Base, sqlmodel.SQLModel):
                 lambda: None,
                 autogenerate=True,
                 head="head",
+                message=message,
             ),
         )
         writer = alembic.autogenerate.rewriter.Rewriter()
@@ -156,7 +173,7 @@ class Model(Base, sqlmodel.SQLModel):
         @writer.rewrites(alembic.operations.ops.AddColumnOp)
         def render_add_column_with_server_default(context, revision, op):
             # Carry the sqlmodel default as server_default so that newly added
-            # columns get the desired default value in existing rows
+            # columns get the desired default value in existing rows.
             if op.column.default is not None and op.column.server_default is None:
                 op.column.server_default = sqlalchemy.DefaultClause(
                     sqlalchemy.sql.expression.literal(op.column.default.arg),
@@ -184,9 +201,9 @@ class Model(Base, sqlmodel.SQLModel):
             upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
             if upgrade_ops is not None:
                 changes_detected = bool(upgrade_ops.ops)
-        if changes_detected:
-            for _script in revision_context.generate_scripts():
-                pass  # must iterate to actually generate the scripts
+        if changes_detected and write_migration_scripts:
+            # Must iterate the generator to actually write the scripts.
+            _ = tuple(revision_context.generate_scripts())
         return changes_detected
 
     @classmethod
@@ -198,15 +215,14 @@ class Model(Base, sqlmodel.SQLModel):
         """Apply alembic migrations up to the given revision.
 
         Args:
-            connection: sqlalchemy connection to use when performing upgrade
-            to_rev: revision to migrate towards
+            connection: SQLAlchemy connection to use when performing upgrade.
+            to_rev: Revision to migrate towards.
         """
         config, script_directory = cls._alembic_config()
 
         def run_upgrade(rev, context):
             return script_directory._upgrade_revs(to_rev, rev)
 
-        # apply updates to database
         with alembic.runtime.environment.EnvironmentContext(
             config=config,
             script=script_directory,
@@ -216,28 +232,36 @@ class Model(Base, sqlmodel.SQLModel):
             env.run_migrations()
 
     @classmethod
-    def automigrate(cls) -> Optional[bool]:
-        """Generate and execute migrations for all sqlmodel Model classes.
+    def migrate(cls, autogenerate: bool = False) -> Optional[bool]:
+        """Execute alembic migrations for all sqlmodel Model classes.
 
         If alembic is not installed or has not been initialized for the project,
         then no action is performed.
 
+        If there are no revisions currently tracked by alembic, then
+        an initial revision will be created based on sqlmodel metadata.
+
         If models in the app have changed in incompatible ways that alembic
         cannot automatically generate revisions for, the app may not be able to
         start up until migration scripts have been corrected by hand.
 
+        Args:
+            autogenerate: If True, generate migration script and use it to upgrade schema
+                (otherwise, just bring the schema to current "head" revision).
+
         Returns:
-            True - indicating the process was successful
-            None - indicating the process was skipped
+            True - indicating the process was successful.
+            None - indicating the process was skipped.
         """
-        if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
+        if not Path(constants.ALEMBIC_CONFIG).exists():
             return
 
         with cls.get_db_engine().connect() as connection:
             cls._alembic_upgrade(connection=connection)
-            changes_detected = cls._alembic_autogenerate(connection=connection)
-            if changes_detected:
-                cls._alembic_upgrade(connection=connection)
+            if autogenerate:
+                changes_detected = cls.alembic_autogenerate(connection=connection)
+                if changes_detected:
+                    cls._alembic_upgrade(connection=connection)
             connection.commit()
         return True
 

+ 49 - 2
reflex/reflex.py

@@ -9,7 +9,7 @@ from pathlib import Path
 import httpx
 import typer
 
-from reflex import constants
+from reflex import constants, model
 from reflex.config import get_config
 from reflex.utils import build, console, exec, prerequisites, processes, telemetry
 
@@ -132,6 +132,9 @@ def run(
     # Check the admin dashboard settings.
     prerequisites.check_admin_settings()
 
+    # Warn if schema is not up to date.
+    prerequisites.check_schema_up_to_date()
+
     # Get the frontend and backend commands, based on the environment.
     setup_frontend = frontend_cmd = backend_cmd = None
     if env == constants.Env.DEV:
@@ -158,7 +161,6 @@ def run(
             target=frontend_cmd, args=(Path.cwd(), frontend_port, loglevel)
         ).start()
     if backend:
-        build.setup_backend()
         threading.Thread(
             target=backend_cmd,
             args=(app.__name__, backend_host, backend_port, loglevel),
@@ -258,6 +260,51 @@ def export(
         )
 
 
+db_cli = typer.Typer()
+
+
+@db_cli.command(name="init")
+def db_init():
+    """Create database schema and migration configuration."""
+    if get_config().db_url is None:
+        console.print("[red]db_url is not configured, cannot initialize.")
+    if Path(constants.ALEMBIC_CONFIG).exists():
+        console.print(
+            "[red]Database is already initialized. Use "
+            "[bold]reflex db makemigrations[/bold] to create schema change "
+            "scripts and [bold]reflex db migrate[/bold] to apply migrations "
+            "to a new or existing database.",
+        )
+    prerequisites.get_app()
+    model.Model.alembic_init()
+    model.Model.migrate(autogenerate=True)
+
+
+@db_cli.command()
+def migrate():
+    """Create or update database schema based on app models or existing migration scripts."""
+    prerequisites.get_app()
+    if not prerequisites.check_db_initialized():
+        return
+    model.Model.migrate()
+    prerequisites.check_schema_up_to_date()
+
+
+@db_cli.command()
+def makemigrations(
+    message: str = typer.Option(
+        None, help="Human readable identifier for the generated revision."
+    ),
+):
+    """Create autogenerated alembic migration scripts."""
+    prerequisites.get_app()
+    if not prerequisites.check_db_initialized():
+        return
+    with model.Model.get_db_engine().connect() as connection:
+        model.Model.alembic_autogenerate(connection=connection, message=message)
+
+
+cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema")
 main = cli
 
 if __name__ == "__main__":

+ 0 - 14
reflex/utils/build.py

@@ -12,7 +12,6 @@ from typing import Optional, Union
 from rich.progress import Progress
 
 from reflex import constants
-from reflex.config import get_config
 from reflex.utils import path_ops, prerequisites
 from reflex.utils.processes import new_process
 
@@ -240,16 +239,3 @@ def setup_frontend_prod(
     """
     setup_frontend(root, loglevel, disable_telemetry)
     export_app(loglevel=loglevel)
-
-
-def setup_backend():
-    """Set up backend.
-
-    Specifically ensures backend database is updated when running --no-frontend.
-    """
-    # Import here to avoid circular imports.
-    from reflex.model import Model
-
-    config = get_config()
-    if config.db_url is not None:
-        Model.create_all()

+ 37 - 1
reflex/utils/prerequisites.py

@@ -16,10 +16,11 @@ from types import ModuleType
 from typing import Optional
 
 import typer
+from alembic.util.exc import CommandError
 from packaging import version
 from redis import Redis
 
-from reflex import constants
+from reflex import constants, model
 from reflex.config import get_config
 from reflex.utils import console, path_ops
 
@@ -370,6 +371,41 @@ def check_admin_settings():
             )
 
 
+def check_db_initialized() -> bool:
+    """Check if the database migrations are initialized.
+
+    Returns:
+        True if alembic is initialized (or if database is not used).
+    """
+    if get_config().db_url is not None and not Path(constants.ALEMBIC_CONFIG).exists():
+        console.print(
+            "[red]Database is not initialized. Run [bold]reflex db init[/bold] first."
+        )
+        return False
+    return True
+
+
+def check_schema_up_to_date():
+    """Check if the sqlmodel metadata matches the current database schema."""
+    if get_config().db_url is None or not Path(constants.ALEMBIC_CONFIG).exists():
+        return
+    with model.Model.get_db_engine().connect() as connection:
+        try:
+            if model.Model.alembic_autogenerate(
+                connection=connection,
+                write_migration_scripts=False,
+            ):
+                console.print(
+                    "[red]Detected database schema changes. Run [bold]reflex db makemigrations[/bold] "
+                    "to generate migration scripts.",
+                )
+        except CommandError as command_error:
+            if "Target database is not up to date." in str(command_error):
+                console.print(
+                    f"[red]{command_error} Run [bold]reflex db migrate[/bold] to update database."
+                )
+
+
 def migrate_to_reflex():
     """Migration from Pynecone to Reflex."""
     # Check if the old config file exists.

+ 18 - 18
tests/test_model.py

@@ -1,5 +1,3 @@
-import subprocess
-import sys
 from unittest import mock
 
 import pytest
@@ -68,26 +66,28 @@ def test_automigration(tmp_working_dir, monkeypatch):
         tmp_working_dir: directory where database and migrations are stored
         monkeypatch: pytest fixture to overwrite attributes
     """
-    subprocess.run(
-        [sys.executable, "-m", "alembic", "init", "alembic"],
-        cwd=tmp_working_dir,
-    )
     alembic_ini = tmp_working_dir / "alembic.ini"
     versions = tmp_working_dir / "alembic" / "versions"
-    assert alembic_ini.exists()
-    assert versions.exists()
+    monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
 
     config_mock = mock.Mock()
     config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
     monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
-    monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
+
+    Model.alembic_init()
+    assert alembic_ini.exists()
+    assert versions.exists()
 
     # initial table
     class AlembicThing(Model, table=True):  # type: ignore
         t1: str
 
-    Model.automigrate()
-    assert len(list(versions.glob("*.py"))) == 1
+    with Model.get_db_engine().connect() as connection:
+        Model.alembic_autogenerate(connection=connection, message="Initial Revision")
+    Model.migrate()
+    version_scripts = list(versions.glob("*.py"))
+    assert len(version_scripts) == 1
+    assert version_scripts[0].name.endswith("initial_revision.py")
 
     with reflex.model.session() as session:
         session.add(AlembicThing(id=None, t1="foo"))
@@ -100,7 +100,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
         t1: str
         t2: str = "bar"
 
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 2
 
     with reflex.model.session() as session:
@@ -114,7 +114,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
 
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 3
 
     with reflex.model.session() as session:
@@ -127,7 +127,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
         a: int = 42
         b: float = 4.2
 
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
 
     with reflex.model.session() as session:
@@ -139,7 +139,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
         assert result[0].b == 4.2
 
     # No-op
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 4
 
     # drop table (AlembicSecond)
@@ -148,7 +148,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
     class AlembicThing(Model, table=True):  # type: ignore
         t2: str = "bar"
 
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
 
     with reflex.model.session() as session:
@@ -166,12 +166,12 @@ def test_automigration(tmp_working_dir, monkeypatch):
         # changing column type not supported by default
         t2: int = 42
 
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 5
 
     # clear all metadata to avoid influencing subsequent tests
     sqlmodel.SQLModel.metadata.clear()
 
     # drop remaining tables
-    Model.automigrate()
+    Model.migrate(autogenerate=True)
     assert len(list(versions.glob("*.py"))) == 6