|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
import os
|
|
import os
|
|
from collections import defaultdict
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-from typing import Any, Optional
|
|
|
|
|
|
+from typing import Any, ClassVar, Optional, Type, Union
|
|
|
|
|
|
import alembic.autogenerate
|
|
import alembic.autogenerate
|
|
import alembic.command
|
|
import alembic.command
|
|
@@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
|
|
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
|
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
|
|
|
|
|
|
|
|
|
|
|
+SQLModelOrSqlAlchemy = Union[
|
|
|
|
+ Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
|
|
|
|
+]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class ModelRegistry:
|
|
|
|
+ """Registry for all models."""
|
|
|
|
+
|
|
|
|
+ models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
|
|
|
|
+
|
|
|
|
+ # Cache the metadata to avoid re-creating it.
|
|
|
|
+ _metadata: ClassVar[sqlalchemy.MetaData | None] = None
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def register(cls, model: SQLModelOrSqlAlchemy):
|
|
|
|
+ """Register a model. Can be used directly or as a decorator.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ model: The model to register.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The model passed in as an argument (Allows decorator usage)
|
|
|
|
+ """
|
|
|
|
+ cls.models.add(model)
|
|
|
|
+ return model
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
|
|
|
|
+ """Get registered models.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ include_empty: If True, include models with empty metadata.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The registered models.
|
|
|
|
+ """
|
|
|
|
+ if include_empty:
|
|
|
|
+ return cls.models
|
|
|
|
+ return {
|
|
|
|
+ model for model in cls.models if not cls._model_metadata_is_empty(model)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @staticmethod
|
|
|
|
+ def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
|
|
|
|
+ """Check if the model metadata is empty.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ model: The model to check.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ True if the model metadata is empty, False otherwise.
|
|
|
|
+ """
|
|
|
|
+ return len(model.metadata.tables) == 0
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_metadata(cls) -> sqlalchemy.MetaData:
|
|
|
|
+ """Get the database metadata.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The database metadata.
|
|
|
|
+ """
|
|
|
|
+ if cls._metadata is not None:
|
|
|
|
+ return cls._metadata
|
|
|
|
+
|
|
|
|
+ models = cls.get_models(include_empty=False)
|
|
|
|
+
|
|
|
|
+ if len(models) == 1:
|
|
|
|
+ metadata = next(iter(models)).metadata
|
|
|
|
+ else:
|
|
|
|
+ # Merge the metadata from all the models.
|
|
|
|
+ # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
|
|
|
|
+ metadata = sqlalchemy.MetaData()
|
|
|
|
+ for model in cls.get_models():
|
|
|
|
+ for table in model.metadata.tables.values():
|
|
|
|
+ table.to_metadata(metadata)
|
|
|
|
+
|
|
|
|
+ # Cache the metadata
|
|
|
|
+ cls._metadata = metadata
|
|
|
|
+
|
|
|
|
+ return metadata
|
|
|
|
+
|
|
|
|
+
|
|
class Model(Base, sqlmodel.SQLModel):
|
|
class Model(Base, sqlmodel.SQLModel):
|
|
"""Base class to define a table in the database."""
|
|
"""Base class to define a table in the database."""
|
|
|
|
|
|
@@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|
def create_all():
|
|
def create_all():
|
|
"""Create all the tables."""
|
|
"""Create all the tables."""
|
|
engine = get_engine()
|
|
engine = get_engine()
|
|
- sqlmodel.SQLModel.metadata.create_all(engine)
|
|
|
|
|
|
+ ModelRegistry.get_metadata().create_all(engine)
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
def get_db_engine():
|
|
def get_db_engine():
|
|
@@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|
) as env:
|
|
) as env:
|
|
env.configure(
|
|
env.configure(
|
|
connection=connection,
|
|
connection=connection,
|
|
- target_metadata=sqlmodel.SQLModel.metadata,
|
|
|
|
|
|
+ target_metadata=ModelRegistry.get_metadata(),
|
|
render_item=cls._alembic_render_item,
|
|
render_item=cls._alembic_render_item,
|
|
process_revision_directives=writer, # type: ignore
|
|
process_revision_directives=writer, # type: ignore
|
|
compare_type=False,
|
|
compare_type=False,
|
|
@@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
|
|
return True
|
|
return True
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- @property
|
|
|
|
def select(cls):
|
|
def select(cls):
|
|
"""Select rows from the table.
|
|
"""Select rows from the table.
|
|
|
|
|
|
@@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
|
|
return sqlmodel.select(cls)
|
|
return sqlmodel.select(cls)
|
|
|
|
|
|
|
|
|
|
|
|
+ModelRegistry.register(Model)
|
|
|
|
+
|
|
|
|
+
|
|
def session(url: str | None = None) -> sqlmodel.Session:
|
|
def session(url: str | None = None) -> sqlmodel.Session:
|
|
"""Get a session to interact with the database.
|
|
"""Get a session to interact with the database.
|
|
|
|
|