Преглед на файлове

Reuse the sqlalchemy engine once it's created (#4493)

* Reuse the sqlalchemy engine once it's created

* Implement `rx.asession` for async database support

Requires setting `async_db_url` to the same database as `db_url`, except using
an async driver; this will vary by database.

* resolve the url first, so the key into _ENGINE is correct

* Ping db connections before returning them from pool

Move connect engine kwargs to a separate function

* the param is `echo`

* sanity check that config db_url and async_db_url are the same

throw a warning if the part following the `://` differs between these two

* create_async_engine: use sqlalchemy async API

update types

* redact ASYNC_DB_URL similarly to DB_URL when overridden in config

* update rx.asession docstring

* use async_sessionmaker

* Redact sensitive env vars instead of hiding them
Masen Furer преди 5 месеца
родител
ревизия
4922f7ba05
променени са 4 файла, в които са добавени 141 реда и са изтрити 12 реда
  1. 1 1
      reflex/__init__.py
  2. 1 0
      reflex/__init__.pyi
  3. 18 6
      reflex/config.py
  4. 121 5
      reflex/model.py

+ 1 - 1
reflex/__init__.py

@@ -331,7 +331,7 @@ _MAPPING: dict = {
         "SessionStorage",
         "SessionStorage",
     ],
     ],
     "middleware": ["middleware", "Middleware"],
     "middleware": ["middleware", "Middleware"],
-    "model": ["session", "Model"],
+    "model": ["asession", "session", "Model"],
     "state": [
     "state": [
         "var",
         "var",
         "ComponentState",
         "ComponentState",

+ 1 - 0
reflex/__init__.pyi

@@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
 from .middleware import Middleware as Middleware
 from .middleware import Middleware as Middleware
 from .middleware import middleware as middleware
 from .middleware import middleware as middleware
 from .model import Model as Model
 from .model import Model as Model
+from .model import asession as asession
 from .model import session as session
 from .model import session as session
 from .page import page as page
 from .page import page as page
 from .state import ComponentState as ComponentState
 from .state import ComponentState as ComponentState

+ 18 - 6
reflex/config.py

@@ -512,6 +512,9 @@ class EnvironmentVariables:
     # Whether to print the SQL queries if the log level is INFO or lower.
     # Whether to print the SQL queries if the log level is INFO or lower.
     SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
     SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
 
 
+    # Whether to check db connections before using them.
+    SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
+
     # Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
     # Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
     REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
     REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
 
 
@@ -568,6 +571,10 @@ class EnvironmentVariables:
 environment = EnvironmentVariables()
 environment = EnvironmentVariables()
 
 
 
 
+# These vars are not logged because they may contain sensitive information.
+_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
+
+
 class Config(Base):
 class Config(Base):
     """The config defines runtime settings for the app.
     """The config defines runtime settings for the app.
 
 
@@ -621,6 +628,9 @@ class Config(Base):
     # The database url used by rx.Model.
     # The database url used by rx.Model.
     db_url: Optional[str] = "sqlite:///reflex.db"
     db_url: Optional[str] = "sqlite:///reflex.db"
 
 
+    # The async database url used by rx.Model.
+    async_db_url: Optional[str] = None
+
     # The redis url
     # The redis url
     redis_url: Optional[str] = None
     redis_url: Optional[str] = None
 
 
@@ -748,18 +758,20 @@ class Config(Base):
 
 
             # If the env var is set, override the config value.
             # If the env var is set, override the config value.
             if env_var is not None:
             if env_var is not None:
-                if key.upper() != "DB_URL":
-                    console.info(
-                        f"Overriding config value {key} with env var {key.upper()}={env_var}",
-                        dedupe=True,
-                    )
-
                 # Interpret the value.
                 # Interpret the value.
                 value = interpret_env_var_value(env_var, field.outer_type_, field.name)
                 value = interpret_env_var_value(env_var, field.outer_type_, field.name)
 
 
                 # Set the value.
                 # Set the value.
                 updated_values[key] = value
                 updated_values[key] = value
 
 
+                if key.upper() in _sensitive_env_vars:
+                    env_var = "***"
+
+                console.info(
+                    f"Overriding config value {key} with env var {key.upper()}={env_var}",
+                    dedupe=True,
+                )
+
         return updated_values
         return updated_values
 
 
     def get_event_namespace(self) -> str:
     def get_event_namespace(self) -> str:

+ 121 - 5
reflex/model.py

@@ -2,6 +2,7 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import re
 from collections import defaultdict
 from collections import defaultdict
 from typing import Any, ClassVar, Optional, Type, Union
 from typing import Any, ClassVar, Optional, Type, Union
 
 
@@ -14,6 +15,7 @@ import alembic.script
 import alembic.util
 import alembic.util
 import sqlalchemy
 import sqlalchemy
 import sqlalchemy.exc
 import sqlalchemy.exc
+import sqlalchemy.ext.asyncio
 import sqlalchemy.orm
 import sqlalchemy.orm
 
 
 from reflex.base import Base
 from reflex.base import Base
@@ -21,6 +23,48 @@ from reflex.config import environment, get_config
 from reflex.utils import console
 from reflex.utils import console
 from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
 from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
 
 
+_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
+_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
+_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
+
+# Import AsyncSession _after_ reflex.utils.compat
+from sqlmodel.ext.asyncio.session import AsyncSession  # noqa: E402
+
+
+def _safe_db_url_for_logging(url: str) -> str:
+    """Remove username and password from the database URL for logging.
+
+    Args:
+        url: The database URL.
+
+    Returns:
+        The database URL with the username and password removed.
+    """
+    return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
+
+
+def get_engine_args(url: str | None = None) -> dict[str, Any]:
+    """Get the database engine arguments.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        The database engine arguments as a dict.
+    """
+    kwargs: dict[str, Any] = dict(
+        # Print the SQL queries if the log level is INFO or lower.
+        echo=environment.SQLALCHEMY_ECHO.get(),
+        # Check connections before returning them.
+        pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
+    )
+    conf = get_config()
+    url = url or conf.db_url
+    if url is not None and url.startswith("sqlite"):
+        # Needed for the admin dash on sqlite.
+        kwargs["connect_args"] = {"check_same_thread": False}
+    return kwargs
+
 
 
 def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
 def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     """Get the database engine.
     """Get the database engine.
@@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     url = url or conf.db_url
     url = url or conf.db_url
     if url is None:
     if url is None:
         raise ValueError("No database url configured")
         raise ValueError("No database url configured")
+
+    global _ENGINE
+    if url in _ENGINE:
+        return _ENGINE[url]
+
     if not environment.ALEMBIC_CONFIG.get().exists():
     if not environment.ALEMBIC_CONFIG.get().exists():
         console.warn(
         console.warn(
             "Database is not initialized, run [bold]reflex db init[/bold] first."
             "Database is not initialized, run [bold]reflex db init[/bold] first."
         )
         )
-    # Print the SQL queries if the log level is INFO or lower.
-    echo_db_query = environment.SQLALCHEMY_ECHO.get()
-    # Needed for the admin dash on sqlite.
-    connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
-    return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
+    _ENGINE[url] = sqlmodel.create_engine(
+        url,
+        **get_engine_args(url),
+    )
+    return _ENGINE[url]
+
+
+def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
+    """Get the async database engine.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        The async database engine.
+
+    Raises:
+        ValueError: If the async database url is None.
+    """
+    if url is None:
+        conf = get_config()
+        url = conf.async_db_url
+        if url is not None and conf.db_url is not None:
+            async_db_url_tail = url.partition("://")[2]
+            db_url_tail = conf.db_url.partition("://")[2]
+            if async_db_url_tail != db_url_tail:
+                console.warn(
+                    f"async_db_url `{_safe_db_url_for_logging(url)}` "
+                    "should reference the same database as "
+                    f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
+                )
+    if url is None:
+        raise ValueError("No async database url configured")
+
+    global _ASYNC_ENGINE
+    if url in _ASYNC_ENGINE:
+        return _ASYNC_ENGINE[url]
+
+    if not environment.ALEMBIC_CONFIG.get().exists():
+        console.warn(
+            "Database is not initialized, run [bold]reflex db init[/bold] first."
+        )
+    _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
+        url,
+        **get_engine_args(url),
+    )
+    return _ASYNC_ENGINE[url]
 
 
 
 
 async def get_db_status() -> bool:
 async def get_db_status() -> bool:
@@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
     return sqlmodel.Session(get_engine(url))
     return sqlmodel.Session(get_engine(url))
 
 
 
 
+def asession(url: str | None = None) -> AsyncSession:
+    """Get an async sqlmodel session to interact with the database.
+
+    async with rx.asession() as asession:
+        ...
+
+    Most operations against the `asession` must be awaited.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        An async database session.
+    """
+    global _AsyncSessionLocal
+    if url not in _AsyncSessionLocal:
+        _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
+            bind=get_async_engine(url),
+            class_=AsyncSession,
+            autocommit=False,
+            autoflush=False,
+        )
+    return _AsyncSessionLocal[url]()
+
+
 def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
 def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
     """Get a bare sqlalchemy session to interact with the database.
     """Get a bare sqlalchemy session to interact with the database.