Преглед изворни кода

Reuse the sqlalchemy engine once it's created (#4493)

* Reuse the sqlalchemy engine once it's created

* Implement `rx.asession` for async database support

Requires setting `async_db_url` to the same database as `db_url`, except using
an async driver; this will vary by database.

* resolve the url first, so the key into _ENGINE is correct

* Ping db connections before returning them from pool

Move connect engine kwargs to a separate function

* the param is `echo`

* sanity check that config db_url and async_db_url are the same

throw a warning if the part following the `://` differs between these two

* create_async_engine: use sqlalchemy async API

update types

* redact ASYNC_DB_URL similarly to DB_URL when overridden in config

* update rx.asession docstring

* use async_sessionmaker

* Redact sensitive env vars instead of hiding them
Masen Furer пре 5 месеци
родитељ
комит
4922f7ba05
4 измењених фајлова са 141 додато и 12 уклоњено
  1. 1 1
      reflex/__init__.py
  2. 1 0
      reflex/__init__.pyi
  3. 18 6
      reflex/config.py
  4. 121 5
      reflex/model.py

+ 1 - 1
reflex/__init__.py

@@ -331,7 +331,7 @@ _MAPPING: dict = {
         "SessionStorage",
     ],
     "middleware": ["middleware", "Middleware"],
-    "model": ["session", "Model"],
+    "model": ["asession", "session", "Model"],
     "state": [
         "var",
         "ComponentState",

+ 1 - 0
reflex/__init__.pyi

@@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
 from .middleware import Middleware as Middleware
 from .middleware import middleware as middleware
 from .model import Model as Model
+from .model import asession as asession
 from .model import session as session
 from .page import page as page
 from .state import ComponentState as ComponentState

+ 18 - 6
reflex/config.py

@@ -512,6 +512,9 @@ class EnvironmentVariables:
     # Whether to print the SQL queries if the log level is INFO or lower.
     SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
 
+    # Whether to check db connections before using them.
+    SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
+
     # Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
     REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
 
@@ -568,6 +571,10 @@ class EnvironmentVariables:
 environment = EnvironmentVariables()
 
 
+# These vars are not logged because they may contain sensitive information.
+_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
+
+
 class Config(Base):
     """The config defines runtime settings for the app.
 
@@ -621,6 +628,9 @@ class Config(Base):
     # The database url used by rx.Model.
     db_url: Optional[str] = "sqlite:///reflex.db"
 
+    # The async database url used by rx.Model.
+    async_db_url: Optional[str] = None
+
     # The redis url
     redis_url: Optional[str] = None
 
@@ -748,18 +758,20 @@ class Config(Base):
 
             # If the env var is set, override the config value.
             if env_var is not None:
-                if key.upper() != "DB_URL":
-                    console.info(
-                        f"Overriding config value {key} with env var {key.upper()}={env_var}",
-                        dedupe=True,
-                    )
-
                 # Interpret the value.
                 value = interpret_env_var_value(env_var, field.outer_type_, field.name)
 
                 # Set the value.
                 updated_values[key] = value
 
+                if key.upper() in _sensitive_env_vars:
+                    env_var = "***"
+
+                console.info(
+                    f"Overriding config value {key} with env var {key.upper()}={env_var}",
+                    dedupe=True,
+                )
+
         return updated_values
 
     def get_event_namespace(self) -> str:

+ 121 - 5
reflex/model.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import re
 from collections import defaultdict
 from typing import Any, ClassVar, Optional, Type, Union
 
@@ -14,6 +15,7 @@ import alembic.script
 import alembic.util
 import sqlalchemy
 import sqlalchemy.exc
+import sqlalchemy.ext.asyncio
 import sqlalchemy.orm
 
 from reflex.base import Base
@@ -21,6 +23,48 @@ from reflex.config import environment, get_config
 from reflex.utils import console
 from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
 
+_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
+_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
+_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
+
+# Import AsyncSession _after_ reflex.utils.compat
+from sqlmodel.ext.asyncio.session import AsyncSession  # noqa: E402
+
+
+def _safe_db_url_for_logging(url: str) -> str:
+    """Remove username and password from the database URL for logging.
+
+    Args:
+        url: The database URL.
+
+    Returns:
+        The database URL with the username and password removed.
+    """
+    return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
+
+
+def get_engine_args(url: str | None = None) -> dict[str, Any]:
+    """Get the database engine arguments.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        The database engine arguments as a dict.
+    """
+    kwargs: dict[str, Any] = dict(
+        # Print the SQL queries if the log level is INFO or lower.
+        echo=environment.SQLALCHEMY_ECHO.get(),
+        # Check connections before returning them.
+        pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
+    )
+    conf = get_config()
+    url = url or conf.db_url
+    if url is not None and url.startswith("sqlite"):
+        # Needed for the admin dash on sqlite.
+        kwargs["connect_args"] = {"check_same_thread": False}
+    return kwargs
+
 
 def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     """Get the database engine.
@@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     url = url or conf.db_url
     if url is None:
         raise ValueError("No database url configured")
+
+    global _ENGINE
+    if url in _ENGINE:
+        return _ENGINE[url]
+
     if not environment.ALEMBIC_CONFIG.get().exists():
         console.warn(
             "Database is not initialized, run [bold]reflex db init[/bold] first."
         )
-    # Print the SQL queries if the log level is INFO or lower.
-    echo_db_query = environment.SQLALCHEMY_ECHO.get()
-    # Needed for the admin dash on sqlite.
-    connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
-    return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
+    _ENGINE[url] = sqlmodel.create_engine(
+        url,
+        **get_engine_args(url),
+    )
+    return _ENGINE[url]
+
+
+def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
+    """Get the async database engine.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        The async database engine.
+
+    Raises:
+        ValueError: If the async database url is None.
+    """
+    if url is None:
+        conf = get_config()
+        url = conf.async_db_url
+        if url is not None and conf.db_url is not None:
+            async_db_url_tail = url.partition("://")[2]
+            db_url_tail = conf.db_url.partition("://")[2]
+            if async_db_url_tail != db_url_tail:
+                console.warn(
+                    f"async_db_url `{_safe_db_url_for_logging(url)}` "
+                    "should reference the same database as "
+                    f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
+                )
+    if url is None:
+        raise ValueError("No async database url configured")
+
+    global _ASYNC_ENGINE
+    if url in _ASYNC_ENGINE:
+        return _ASYNC_ENGINE[url]
+
+    if not environment.ALEMBIC_CONFIG.get().exists():
+        console.warn(
+            "Database is not initialized, run [bold]reflex db init[/bold] first."
+        )
+    _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
+        url,
+        **get_engine_args(url),
+    )
+    return _ASYNC_ENGINE[url]
 
 
 async def get_db_status() -> bool:
@@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
     return sqlmodel.Session(get_engine(url))
 
 
+def asession(url: str | None = None) -> AsyncSession:
+    """Get an async sqlmodel session to interact with the database.
+
+    async with rx.asession() as asession:
+        ...
+
+    Most operations against the `asession` must be awaited.
+
+    Args:
+        url: The database url.
+
+    Returns:
+        An async database session.
+    """
+    global _AsyncSessionLocal
+    if url not in _AsyncSessionLocal:
+        _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
+            bind=get_async_engine(url),
+            class_=AsyncSession,
+            autocommit=False,
+            autoflush=False,
+        )
+    return _AsyncSessionLocal[url]()
+
+
 def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
     """Get a bare sqlalchemy session to interact with the database.