|
@@ -2,6 +2,7 @@
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
+import re
|
|
|
from collections import defaultdict
|
|
|
from typing import Any, ClassVar, Optional, Type, Union
|
|
|
|
|
@@ -14,6 +15,7 @@ import alembic.script
|
|
|
import alembic.util
|
|
|
import sqlalchemy
|
|
|
import sqlalchemy.exc
|
|
|
+import sqlalchemy.ext.asyncio
|
|
|
import sqlalchemy.orm
|
|
|
|
|
|
from reflex.base import Base
|
|
@@ -21,6 +23,48 @@ from reflex.config import environment, get_config
|
|
|
from reflex.utils import console
|
|
|
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
|
|
|
|
|
+_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
|
|
+_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
|
|
|
+_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
|
|
|
+
|
|
|
+# Import AsyncSession _after_ reflex.utils.compat
|
|
|
+from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
|
|
|
+
|
|
|
+
|
|
|
+def _safe_db_url_for_logging(url: str) -> str:
|
|
|
+ """Remove username and password from the database URL for logging.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ url: The database URL.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The database URL with the username and password removed.
|
|
|
+ """
|
|
|
+ return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
|
|
|
+
|
|
|
+
|
|
|
+def get_engine_args(url: str | None = None) -> dict[str, Any]:
|
|
|
+ """Get the database engine arguments.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ url: The database url.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The database engine arguments as a dict.
|
|
|
+ """
|
|
|
+ kwargs: dict[str, Any] = dict(
|
|
|
+ # Print the SQL queries if the log level is INFO or lower.
|
|
|
+ echo=environment.SQLALCHEMY_ECHO.get(),
|
|
|
+ # Check connections before returning them.
|
|
|
+ pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
|
|
|
+ )
|
|
|
+ conf = get_config()
|
|
|
+ url = url or conf.db_url
|
|
|
+ if url is not None and url.startswith("sqlite"):
|
|
|
+ # Needed for the admin dash on sqlite.
|
|
|
+ kwargs["connect_args"] = {"check_same_thread": False}
|
|
|
+ return kwargs
|
|
|
+
|
|
|
|
|
|
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|
|
"""Get the database engine.
|
|
@@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|
|
url = url or conf.db_url
|
|
|
if url is None:
|
|
|
raise ValueError("No database url configured")
|
|
|
+
|
|
|
+ global _ENGINE
|
|
|
+ if url in _ENGINE:
|
|
|
+ return _ENGINE[url]
|
|
|
+
|
|
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
|
|
console.warn(
|
|
|
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
|
|
)
|
|
|
- # Print the SQL queries if the log level is INFO or lower.
|
|
|
- echo_db_query = environment.SQLALCHEMY_ECHO.get()
|
|
|
- # Needed for the admin dash on sqlite.
|
|
|
- connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
|
|
|
- return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
|
|
+ _ENGINE[url] = sqlmodel.create_engine(
|
|
|
+ url,
|
|
|
+ **get_engine_args(url),
|
|
|
+ )
|
|
|
+ return _ENGINE[url]
|
|
|
+
|
|
|
+
|
|
|
+def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
|
|
|
+ """Get the async database engine.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ url: The database url.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The async database engine.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: If the async database url is None.
|
|
|
+ """
|
|
|
+ if url is None:
|
|
|
+ conf = get_config()
|
|
|
+ url = conf.async_db_url
|
|
|
+ if url is not None and conf.db_url is not None:
|
|
|
+ async_db_url_tail = url.partition("://")[2]
|
|
|
+ db_url_tail = conf.db_url.partition("://")[2]
|
|
|
+ if async_db_url_tail != db_url_tail:
|
|
|
+ console.warn(
|
|
|
+ f"async_db_url `{_safe_db_url_for_logging(url)}` "
|
|
|
+ "should reference the same database as "
|
|
|
+ f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
|
|
|
+ )
|
|
|
+ if url is None:
|
|
|
+ raise ValueError("No async database url configured")
|
|
|
+
|
|
|
+ global _ASYNC_ENGINE
|
|
|
+ if url in _ASYNC_ENGINE:
|
|
|
+ return _ASYNC_ENGINE[url]
|
|
|
+
|
|
|
+ if not environment.ALEMBIC_CONFIG.get().exists():
|
|
|
+ console.warn(
|
|
|
+ "Database is not initialized, run [bold]reflex db init[/bold] first."
|
|
|
+ )
|
|
|
+ _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
|
|
|
+ url,
|
|
|
+ **get_engine_args(url),
|
|
|
+ )
|
|
|
+ return _ASYNC_ENGINE[url]
|
|
|
|
|
|
|
|
|
async def get_db_status() -> bool:
|
|
@@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
|
|
|
return sqlmodel.Session(get_engine(url))
|
|
|
|
|
|
|
|
|
+def asession(url: str | None = None) -> AsyncSession:
|
|
|
+ """Get an async sqlmodel session to interact with the database.
|
|
|
+
|
|
|
+ async with rx.asession() as asession:
|
|
|
+ ...
|
|
|
+
|
|
|
+ Most operations against the `asession` must be awaited.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ url: The database url.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ An async database session.
|
|
|
+ """
|
|
|
+ global _AsyncSessionLocal
|
|
|
+ if url not in _AsyncSessionLocal:
|
|
|
+ _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
|
|
|
+ bind=get_async_engine(url),
|
|
|
+ class_=AsyncSession,
|
|
|
+ autocommit=False,
|
|
|
+ autoflush=False,
|
|
|
+ )
|
|
|
+ return _AsyncSessionLocal[url]()
|
|
|
+
|
|
|
+
|
|
|
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
|
|
"""Get a bare sqlalchemy session to interact with the database.
|
|
|
|