Bladeren bron

Bare sqlalchemy metadata (#2355)

benedikt-bartscher 1 jaar geleden
bovenliggende
commit
5701a72c8f
1 gewijzigde bestanden met toevoegingen van 88 en 4 verwijderingen
  1. 88 4
      reflex/model.py

+ 88 - 4
reflex/model.py

@@ -5,7 +5,7 @@ from __future__ import annotations
 import os
 import os
 from collections import defaultdict
 from collections import defaultdict
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional, Type, Union
 
 
 import alembic.autogenerate
 import alembic.autogenerate
 import alembic.command
 import alembic.command
@@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
     return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
     return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
 
 
 
 
+SQLModelOrSqlAlchemy = Union[
+    Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
+]
+
+
+class ModelRegistry:
+    """Registry for all models."""
+
+    models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
+
+    # Cache the metadata to avoid re-creating it.
+    _metadata: ClassVar[sqlalchemy.MetaData | None] = None
+
+    @classmethod
+    def register(cls, model: SQLModelOrSqlAlchemy):
+        """Register a model. Can be used directly or as a decorator.
+
+        Args:
+            model: The model to register.
+
+        Returns:
+            The model passed in as an argument (Allows decorator usage)
+        """
+        cls.models.add(model)
+        return model
+
+    @classmethod
+    def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
+        """Get registered models.
+
+        Args:
+            include_empty: If True, include models with empty metadata.
+
+        Returns:
+            The registered models.
+        """
+        if include_empty:
+            return cls.models
+        return {
+            model for model in cls.models if not cls._model_metadata_is_empty(model)
+        }
+
+    @staticmethod
+    def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
+        """Check if the model metadata is empty.
+
+        Args:
+            model: The model to check.
+
+        Returns:
+            True if the model metadata is empty, False otherwise.
+        """
+        return len(model.metadata.tables) == 0
+
+    @classmethod
+    def get_metadata(cls) -> sqlalchemy.MetaData:
+        """Get the database metadata.
+
+        Returns:
+            The database metadata.
+        """
+        if cls._metadata is not None:
+            return cls._metadata
+
+        models = cls.get_models(include_empty=False)
+
+        if len(models) == 1:
+            metadata = next(iter(models)).metadata
+        else:
+            # Merge the metadata from all the models.
+            # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
+            metadata = sqlalchemy.MetaData()
+            for model in cls.get_models():
+                for table in model.metadata.tables.values():
+                    table.to_metadata(metadata)
+
+        # Cache the metadata
+        cls._metadata = metadata
+
+        return metadata
+
+
 class Model(Base, sqlmodel.SQLModel):
 class Model(Base, sqlmodel.SQLModel):
     """Base class to define a table in the database."""
     """Base class to define a table in the database."""
 
 
@@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
     def create_all():
     def create_all():
         """Create all the tables."""
         """Create all the tables."""
         engine = get_engine()
         engine = get_engine()
-        sqlmodel.SQLModel.metadata.create_all(engine)
+        ModelRegistry.get_metadata().create_all(engine)
 
 
     @staticmethod
     @staticmethod
     def get_db_engine():
     def get_db_engine():
@@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
         ) as env:
         ) as env:
             env.configure(
             env.configure(
                 connection=connection,
                 connection=connection,
-                target_metadata=sqlmodel.SQLModel.metadata,
+                target_metadata=ModelRegistry.get_metadata(),
                 render_item=cls._alembic_render_item,
                 render_item=cls._alembic_render_item,
                 process_revision_directives=writer,  # type: ignore
                 process_revision_directives=writer,  # type: ignore
                 compare_type=False,
                 compare_type=False,
@@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
         return True
         return True
 
 
     @classmethod
     @classmethod
-    @property
     def select(cls):
     def select(cls):
         """Select rows from the table.
         """Select rows from the table.
 
 
@@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
         return sqlmodel.select(cls)
         return sqlmodel.select(cls)
 
 
 
 
+ModelRegistry.register(Model)
+
+
 def session(url: str | None = None) -> sqlmodel.Session:
 def session(url: str | None = None) -> sqlmodel.Session:
     """Get a session to interact with the database.
     """Get a session to interact with the database.