|
@@ -4,26 +4,20 @@ from collections import defaultdict
|
|
|
from pathlib import Path
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
+import alembic.autogenerate
|
|
|
+import alembic.command
|
|
|
+import alembic.config
|
|
|
+import alembic.operations.ops
|
|
|
+import alembic.runtime.environment
|
|
|
+import alembic.script
|
|
|
+import alembic.util
|
|
|
import sqlalchemy
|
|
|
import sqlmodel
|
|
|
|
|
|
+from reflex import constants
|
|
|
from reflex.base import Base
|
|
|
from reflex.config import get_config
|
|
|
-
|
|
|
-from . import constants
|
|
|
-
|
|
|
-try:
|
|
|
- import alembic.autogenerate # pyright: ignore [reportMissingImports]
|
|
|
- import alembic.command # pyright: ignore [reportMissingImports]
|
|
|
- import alembic.operations.ops # pyright: ignore [reportMissingImports]
|
|
|
- import alembic.runtime.environment # pyright: ignore [reportMissingImports]
|
|
|
- import alembic.script # pyright: ignore [reportMissingImports]
|
|
|
- import alembic.util # pyright: ignore [reportMissingImports]
|
|
|
- from alembic.config import Config # pyright: ignore [reportMissingImports]
|
|
|
-
|
|
|
- has_alembic = True
|
|
|
-except ImportError:
|
|
|
- has_alembic = False
|
|
|
+from reflex.utils import console
|
|
|
|
|
|
|
|
|
def get_engine(url: Optional[str] = None):
|
|
@@ -42,6 +36,10 @@ def get_engine(url: Optional[str] = None):
|
|
|
url = url or conf.db_url
|
|
|
if url is None:
|
|
|
raise ValueError("No database url configured")
|
|
|
+ if not Path(constants.ALEMBIC_CONFIG).exists():
|
|
|
+ console.print(
|
|
|
+ "[red]Database is not initialized, run [bold]reflex db init[/bold] first."
|
|
|
+ )
|
|
|
return sqlmodel.create_engine(
|
|
|
url,
|
|
|
echo=False,
|
|
@@ -100,7 +98,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
Returns:
|
|
|
tuple of (config, script_directory)
|
|
|
"""
|
|
|
- config = Config(constants.ALEMBIC_CONFIG)
|
|
|
+ config = alembic.config.Config(constants.ALEMBIC_CONFIG)
|
|
|
return config, alembic.script.ScriptDirectory(
|
|
|
config.get_main_option("script_location", default="version"),
|
|
|
)
|
|
@@ -120,27 +118,45 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
|
|
|
|
|
Args:
|
|
|
- type_: one of "schema", "table", "column", "index",
|
|
|
- "unique_constraint", or "foreign_key_constraint"
|
|
|
- obj: the object being rendered
|
|
|
- autogen_context: shared AutogenContext passed to each render_item call
|
|
|
+ type_: One of "schema", "table", "column", "index",
|
|
|
+ "unique_constraint", or "foreign_key_constraint".
|
|
|
+ obj: The object being rendered.
|
|
|
+ autogen_context: Shared AutogenContext passed to each render_item call.
|
|
|
|
|
|
Returns:
|
|
|
- False - indicating that the default rendering should be used.
|
|
|
+ False - Indicating that the default rendering should be used.
|
|
|
"""
|
|
|
autogen_context.imports.add("import sqlmodel")
|
|
|
return False
|
|
|
|
|
|
@classmethod
|
|
|
- def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
|
|
|
+ def alembic_init(cls):
|
|
|
+ """Initialize alembic for the project."""
|
|
|
+ alembic.command.init(
|
|
|
+ config=alembic.config.Config(constants.ALEMBIC_CONFIG),
|
|
|
+ directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"),
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def alembic_autogenerate(
|
|
|
+ cls,
|
|
|
+ connection: sqlalchemy.engine.Connection,
|
|
|
+ message: Optional[str] = None,
|
|
|
+ write_migration_scripts: bool = True,
|
|
|
+ ) -> bool:
|
|
|
"""Generate migration scripts for alembic-detectable changes.
|
|
|
|
|
|
Args:
|
|
|
- connection: sqlalchemy connection to use when detecting changes
|
|
|
+ connection: SQLAlchemy connection to use when detecting changes.
|
|
|
+ message: Human readable identifier describing the generated revision.
|
|
|
+ write_migration_scripts: If True, write autogenerated revisions to script directory.
|
|
|
|
|
|
Returns:
|
|
|
True when changes have been detected.
|
|
|
"""
|
|
|
+ if not Path(constants.ALEMBIC_CONFIG).exists():
|
|
|
+ return False
|
|
|
+
|
|
|
config, script_directory = cls._alembic_config()
|
|
|
revision_context = alembic.autogenerate.api.RevisionContext(
|
|
|
config=config,
|
|
@@ -149,6 +165,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
lambda: None,
|
|
|
autogenerate=True,
|
|
|
head="head",
|
|
|
+ message=message,
|
|
|
),
|
|
|
)
|
|
|
writer = alembic.autogenerate.rewriter.Rewriter()
|
|
@@ -156,7 +173,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
|
|
def render_add_column_with_server_default(context, revision, op):
|
|
|
# Carry the sqlmodel default as server_default so that newly added
|
|
|
- # columns get the desired default value in existing rows
|
|
|
+ # columns get the desired default value in existing rows.
|
|
|
if op.column.default is not None and op.column.server_default is None:
|
|
|
op.column.server_default = sqlalchemy.DefaultClause(
|
|
|
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
|
@@ -184,9 +201,9 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
|
|
if upgrade_ops is not None:
|
|
|
changes_detected = bool(upgrade_ops.ops)
|
|
|
- if changes_detected:
|
|
|
- for _script in revision_context.generate_scripts():
|
|
|
- pass # must iterate to actually generate the scripts
|
|
|
+ if changes_detected and write_migration_scripts:
|
|
|
+ # Must iterate the generator to actually write the scripts.
|
|
|
+ _ = tuple(revision_context.generate_scripts())
|
|
|
return changes_detected
|
|
|
|
|
|
@classmethod
|
|
@@ -198,15 +215,14 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
"""Apply alembic migrations up to the given revision.
|
|
|
|
|
|
Args:
|
|
|
- connection: sqlalchemy connection to use when performing upgrade
|
|
|
- to_rev: revision to migrate towards
|
|
|
+ connection: SQLAlchemy connection to use when performing upgrade.
|
|
|
+ to_rev: Revision to migrate towards.
|
|
|
"""
|
|
|
config, script_directory = cls._alembic_config()
|
|
|
|
|
|
def run_upgrade(rev, context):
|
|
|
return script_directory._upgrade_revs(to_rev, rev)
|
|
|
|
|
|
- # apply updates to database
|
|
|
with alembic.runtime.environment.EnvironmentContext(
|
|
|
config=config,
|
|
|
script=script_directory,
|
|
@@ -216,28 +232,36 @@ class Model(Base, sqlmodel.SQLModel):
|
|
|
env.run_migrations()
|
|
|
|
|
|
@classmethod
|
|
|
- def automigrate(cls) -> Optional[bool]:
|
|
|
- """Generate and execute migrations for all sqlmodel Model classes.
|
|
|
+ def migrate(cls, autogenerate: bool = False) -> Optional[bool]:
|
|
|
+ """Execute alembic migrations for all sqlmodel Model classes.
|
|
|
|
|
|
If alembic is not installed or has not been initialized for the project,
|
|
|
then no action is performed.
|
|
|
|
|
|
+ If there are no revisions currently tracked by alembic, then
|
|
|
+ an initial revision will be created based on sqlmodel metadata.
|
|
|
+
|
|
|
If models in the app have changed in incompatible ways that alembic
|
|
|
cannot automatically generate revisions for, the app may not be able to
|
|
|
start up until migration scripts have been corrected by hand.
|
|
|
|
|
|
+ Args:
|
|
|
+ autogenerate: If True, generate migration script and use it to upgrade schema
|
|
|
+ (otherwise, just bring the schema to current "head" revision).
|
|
|
+
|
|
|
Returns:
|
|
|
- True - indicating the process was successful
|
|
|
- None - indicating the process was skipped
|
|
|
+ True - indicating the process was successful.
|
|
|
+ None - indicating the process was skipped.
|
|
|
"""
|
|
|
- if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
|
|
|
+ if not Path(constants.ALEMBIC_CONFIG).exists():
|
|
|
return
|
|
|
|
|
|
with cls.get_db_engine().connect() as connection:
|
|
|
cls._alembic_upgrade(connection=connection)
|
|
|
- changes_detected = cls._alembic_autogenerate(connection=connection)
|
|
|
- if changes_detected:
|
|
|
- cls._alembic_upgrade(connection=connection)
|
|
|
+ if autogenerate:
|
|
|
+ changes_detected = cls.alembic_autogenerate(connection=connection)
|
|
|
+ if changes_detected:
|
|
|
+ cls._alembic_upgrade(connection=connection)
|
|
|
connection.commit()
|
|
|
return True
|
|
|
|