model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. """Database built into Reflex."""
  2. from __future__ import annotations
  3. import re
  4. from collections import defaultdict
  5. from contextlib import suppress
  6. from typing import Any, ClassVar, Optional, Type, Union
  7. import alembic.autogenerate
  8. import alembic.command
  9. import alembic.config
  10. import alembic.operations.ops
  11. import alembic.runtime.environment
  12. import alembic.script
  13. import alembic.util
  14. import sqlalchemy
  15. import sqlalchemy.exc
  16. import sqlalchemy.ext.asyncio
  17. import sqlalchemy.orm
  18. from alembic.runtime.migration import MigrationContext
  19. from reflex.base import Base
  20. from reflex.config import environment, get_config
  21. from reflex.utils import console
  22. from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
# Sync engines created by get_engine, keyed by database url, so that each url
# gets exactly one engine per process.
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
# Async engines created by get_async_engine, keyed by database url.
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
# Async session factories created by asession, keyed by the url argument it
# was called with (None stands for "use the configured async url").
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
  26. # Import AsyncSession _after_ reflex.utils.compat
  27. from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
  28. def _safe_db_url_for_logging(url: str) -> str:
  29. """Remove username and password from the database URL for logging.
  30. Args:
  31. url: The database URL.
  32. Returns:
  33. The database URL with the username and password removed.
  34. """
  35. return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
  36. def get_engine_args(url: str | None = None) -> dict[str, Any]:
  37. """Get the database engine arguments.
  38. Args:
  39. url: The database url.
  40. Returns:
  41. The database engine arguments as a dict.
  42. """
  43. kwargs: dict[str, Any] = {
  44. # Print the SQL queries if the log level is INFO or lower.
  45. "echo": environment.SQLALCHEMY_ECHO.get(),
  46. # Check connections before returning them.
  47. "pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
  48. }
  49. conf = get_config()
  50. url = url or conf.db_url
  51. if url is not None and url.startswith("sqlite"):
  52. # Needed for the admin dash on sqlite.
  53. kwargs["connect_args"] = {"check_same_thread": False}
  54. return kwargs
  55. def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
  56. """Get the database engine.
  57. Args:
  58. url: the DB url to use.
  59. Returns:
  60. The database engine.
  61. Raises:
  62. ValueError: If the database url is None.
  63. """
  64. conf = get_config()
  65. url = url or conf.db_url
  66. if url is None:
  67. raise ValueError("No database url configured")
  68. global _ENGINE
  69. if url in _ENGINE:
  70. return _ENGINE[url]
  71. if not environment.ALEMBIC_CONFIG.get().exists():
  72. console.warn(
  73. "Database is not initialized, run [bold]reflex db init[/bold] first."
  74. )
  75. _ENGINE[url] = sqlmodel.create_engine(
  76. url,
  77. **get_engine_args(url),
  78. )
  79. return _ENGINE[url]
  80. def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
  81. """Get the async database engine.
  82. Args:
  83. url: The database url.
  84. Returns:
  85. The async database engine.
  86. Raises:
  87. ValueError: If the async database url is None.
  88. """
  89. if url is None:
  90. conf = get_config()
  91. url = conf.async_db_url
  92. if url is not None and conf.db_url is not None:
  93. async_db_url_tail = url.partition("://")[2]
  94. db_url_tail = conf.db_url.partition("://")[2]
  95. if async_db_url_tail != db_url_tail:
  96. console.warn(
  97. f"async_db_url `{_safe_db_url_for_logging(url)}` "
  98. "should reference the same database as "
  99. f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
  100. )
  101. if url is None:
  102. raise ValueError("No async database url configured")
  103. global _ASYNC_ENGINE
  104. if url in _ASYNC_ENGINE:
  105. return _ASYNC_ENGINE[url]
  106. if not environment.ALEMBIC_CONFIG.get().exists():
  107. console.warn(
  108. "Database is not initialized, run [bold]reflex db init[/bold] first."
  109. )
  110. _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
  111. url,
  112. **get_engine_args(url),
  113. )
  114. return _ASYNC_ENGINE[url]
  115. async def get_db_status() -> dict[str, bool]:
  116. """Checks the status of the database connection.
  117. Attempts to connect to the database and execute a simple query to verify connectivity.
  118. Returns:
  119. The status of the database connection.
  120. """
  121. status = True
  122. try:
  123. engine = get_engine()
  124. with engine.connect() as connection:
  125. connection.execute(sqlalchemy.text("SELECT 1"))
  126. except sqlalchemy.exc.OperationalError:
  127. status = False
  128. return {"db": status}
# A model *class* (not an instance): either a SQLModel class or a bare
# SQLAlchemy declarative class. This is the type ModelRegistry stores.
SQLModelOrSqlAlchemy = Union[
    Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
]
  132. class ModelRegistry:
  133. """Registry for all models."""
  134. models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
  135. # Cache the metadata to avoid re-creating it.
  136. _metadata: ClassVar[sqlalchemy.MetaData | None] = None
  137. @classmethod
  138. def register(cls, model: SQLModelOrSqlAlchemy):
  139. """Register a model. Can be used directly or as a decorator.
  140. Args:
  141. model: The model to register.
  142. Returns:
  143. The model passed in as an argument (Allows decorator usage)
  144. """
  145. cls.models.add(model)
  146. return model
  147. @classmethod
  148. def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
  149. """Get registered models.
  150. Args:
  151. include_empty: If True, include models with empty metadata.
  152. Returns:
  153. The registered models.
  154. """
  155. if include_empty:
  156. return cls.models
  157. return {
  158. model for model in cls.models if not cls._model_metadata_is_empty(model)
  159. }
  160. @staticmethod
  161. def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
  162. """Check if the model metadata is empty.
  163. Args:
  164. model: The model to check.
  165. Returns:
  166. True if the model metadata is empty, False otherwise.
  167. """
  168. return len(model.metadata.tables) == 0
  169. @classmethod
  170. def get_metadata(cls) -> sqlalchemy.MetaData:
  171. """Get the database metadata.
  172. Returns:
  173. The database metadata.
  174. """
  175. if cls._metadata is not None:
  176. return cls._metadata
  177. models = cls.get_models(include_empty=False)
  178. if len(models) == 1:
  179. metadata = next(iter(models)).metadata
  180. else:
  181. # Merge the metadata from all the models.
  182. # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
  183. metadata = sqlalchemy.MetaData()
  184. for model in cls.get_models():
  185. for table in model.metadata.tables.values():
  186. table.to_metadata(metadata)
  187. # Cache the metadata
  188. cls._metadata = metadata
  189. return metadata
  190. class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
  191. """Base class to define a table in the database."""
  192. # The primary key for the table.
  193. id: Optional[int] = sqlmodel.Field(default=None, primary_key=True)
  194. def __init_subclass__(cls):
  195. """Drop the default primary key field if any primary key field is defined."""
  196. non_default_primary_key_fields = [
  197. field_name
  198. for field_name, field in cls.__fields__.items()
  199. if field_name != "id" and sqlmodel_field_has_primary_key(field)
  200. ]
  201. if non_default_primary_key_fields:
  202. cls.__fields__.pop("id", None)
  203. super().__init_subclass__()
  204. @classmethod
  205. def _dict_recursive(cls, value: Any):
  206. """Recursively serialize the relationship object(s).
  207. Args:
  208. value: The value to serialize.
  209. Returns:
  210. The serialized value.
  211. """
  212. if hasattr(value, "dict"):
  213. return value.dict()
  214. elif isinstance(value, list):
  215. return [cls._dict_recursive(item) for item in value]
  216. return value
  217. def dict(self, **kwargs):
  218. """Convert the object to a dictionary.
  219. Args:
  220. kwargs: Ignored but needed for compatibility.
  221. Returns:
  222. The object as a dictionary.
  223. """
  224. base_fields = {name: getattr(self, name) for name in self.__fields__}
  225. relationships = {}
  226. # SQLModel relationships do not appear in __fields__, but should be included if present.
  227. for name in self.__sqlmodel_relationships__:
  228. with suppress(
  229. sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
  230. ):
  231. relationships[name] = self._dict_recursive(getattr(self, name))
  232. return {
  233. **base_fields,
  234. **relationships,
  235. }
  236. @staticmethod
  237. def create_all():
  238. """Create all the tables."""
  239. engine = get_engine()
  240. ModelRegistry.get_metadata().create_all(engine)
  241. @staticmethod
  242. def get_db_engine():
  243. """Get the database engine.
  244. Returns:
  245. The database engine.
  246. """
  247. return get_engine()
  248. @staticmethod
  249. def _alembic_config():
  250. """Get the alembic configuration and script_directory.
  251. Returns:
  252. tuple of (config, script_directory)
  253. """
  254. config = alembic.config.Config(environment.ALEMBIC_CONFIG.get())
  255. return config, alembic.script.ScriptDirectory(
  256. config.get_main_option("script_location", default="version"),
  257. )
  258. @staticmethod
  259. def _alembic_render_item(
  260. type_: str,
  261. obj: Any,
  262. autogen_context: "alembic.autogenerate.api.AutogenContext",
  263. ):
  264. """Alembic render_item hook call.
  265. This method is called to provide python code for the given obj,
  266. but currently it is only used to add `sqlmodel` to the import list
  267. when generating migration scripts.
  268. See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
  269. Args:
  270. type_: One of "schema", "table", "column", "index",
  271. "unique_constraint", or "foreign_key_constraint".
  272. obj: The object being rendered.
  273. autogen_context: Shared AutogenContext passed to each render_item call.
  274. Returns:
  275. False - Indicating that the default rendering should be used.
  276. """
  277. autogen_context.imports.add("import sqlmodel")
  278. return False
  279. @classmethod
  280. def alembic_init(cls):
  281. """Initialize alembic for the project."""
  282. alembic.command.init(
  283. config=alembic.config.Config(environment.ALEMBIC_CONFIG.get()),
  284. directory=str(environment.ALEMBIC_CONFIG.get().parent / "alembic"),
  285. )
  286. @classmethod
  287. def alembic_autogenerate(
  288. cls,
  289. connection: sqlalchemy.engine.Connection,
  290. message: str | None = None,
  291. write_migration_scripts: bool = True,
  292. ) -> bool:
  293. """Generate migration scripts for alembic-detectable changes.
  294. Args:
  295. connection: SQLAlchemy connection to use when detecting changes.
  296. message: Human readable identifier describing the generated revision.
  297. write_migration_scripts: If True, write autogenerated revisions to script directory.
  298. Returns:
  299. True when changes have been detected.
  300. """
  301. if not environment.ALEMBIC_CONFIG.get().exists():
  302. return False
  303. config, script_directory = cls._alembic_config()
  304. revision_context = alembic.autogenerate.api.RevisionContext(
  305. config=config,
  306. script_directory=script_directory,
  307. command_args=defaultdict(
  308. lambda: None,
  309. autogenerate=True,
  310. head="head",
  311. message=message,
  312. ),
  313. )
  314. writer = alembic.autogenerate.rewriter.Rewriter()
  315. @writer.rewrites(alembic.operations.ops.AddColumnOp)
  316. def render_add_column_with_server_default(
  317. context: MigrationContext,
  318. revision: str | None,
  319. op: Any,
  320. ):
  321. # Carry the sqlmodel default as server_default so that newly added
  322. # columns get the desired default value in existing rows.
  323. if op.column.default is not None and op.column.server_default is None:
  324. op.column.server_default = sqlalchemy.DefaultClause(
  325. sqlalchemy.sql.expression.literal(op.column.default.arg),
  326. )
  327. return op
  328. def run_autogenerate(rev: str, context: MigrationContext):
  329. revision_context.run_autogenerate(rev, context)
  330. return []
  331. with alembic.runtime.environment.EnvironmentContext(
  332. config=config,
  333. script=script_directory,
  334. fn=run_autogenerate,
  335. ) as env:
  336. env.configure(
  337. connection=connection,
  338. target_metadata=ModelRegistry.get_metadata(),
  339. render_item=cls._alembic_render_item,
  340. process_revision_directives=writer,
  341. compare_type=False,
  342. render_as_batch=True, # for sqlite compatibility
  343. )
  344. env.run_migrations()
  345. changes_detected = False
  346. if revision_context.generated_revisions:
  347. upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
  348. if upgrade_ops is not None:
  349. changes_detected = bool(upgrade_ops.ops)
  350. if changes_detected and write_migration_scripts:
  351. # Must iterate the generator to actually write the scripts.
  352. _ = tuple(revision_context.generate_scripts())
  353. return changes_detected
  354. @classmethod
  355. def _alembic_upgrade(
  356. cls,
  357. connection: sqlalchemy.engine.Connection,
  358. to_rev: str = "head",
  359. ) -> None:
  360. """Apply alembic migrations up to the given revision.
  361. Args:
  362. connection: SQLAlchemy connection to use when performing upgrade.
  363. to_rev: Revision to migrate towards.
  364. """
  365. config, script_directory = cls._alembic_config()
  366. def run_upgrade(rev: str, context: MigrationContext):
  367. return script_directory._upgrade_revs(to_rev, rev)
  368. with alembic.runtime.environment.EnvironmentContext(
  369. config=config,
  370. script=script_directory,
  371. fn=run_upgrade,
  372. ) as env:
  373. env.configure(connection=connection)
  374. env.run_migrations()
  375. @classmethod
  376. def migrate(cls, autogenerate: bool = False) -> bool | None:
  377. """Execute alembic migrations for all sqlmodel Model classes.
  378. If alembic is not installed or has not been initialized for the project,
  379. then no action is performed.
  380. If there are no revisions currently tracked by alembic, then
  381. an initial revision will be created based on sqlmodel metadata.
  382. If models in the app have changed in incompatible ways that alembic
  383. cannot automatically generate revisions for, the app may not be able to
  384. start up until migration scripts have been corrected by hand.
  385. Args:
  386. autogenerate: If True, generate migration script and use it to upgrade schema
  387. (otherwise, just bring the schema to current "head" revision).
  388. Returns:
  389. True - indicating the process was successful.
  390. None - indicating the process was skipped.
  391. """
  392. if not environment.ALEMBIC_CONFIG.get().exists():
  393. return
  394. with cls.get_db_engine().connect() as connection:
  395. cls._alembic_upgrade(connection=connection)
  396. if autogenerate:
  397. changes_detected = cls.alembic_autogenerate(connection=connection)
  398. if changes_detected:
  399. cls._alembic_upgrade(connection=connection)
  400. connection.commit()
  401. return True
  402. @classmethod
  403. def select(cls):
  404. """Select rows from the table.
  405. Returns:
  406. The select statement.
  407. """
  408. return sqlmodel.select(cls)
# Register the base Model itself so that ModelRegistry.get_metadata() takes
# its metadata into account.
ModelRegistry.register(Model)
  410. def session(url: str | None = None) -> sqlmodel.Session:
  411. """Get a sqlmodel session to interact with the database.
  412. Args:
  413. url: The database url.
  414. Returns:
  415. A database session.
  416. """
  417. return sqlmodel.Session(get_engine(url))
  418. def asession(url: str | None = None) -> AsyncSession:
  419. """Get an async sqlmodel session to interact with the database.
  420. async with rx.asession() as asession:
  421. ...
  422. Most operations against the `asession` must be awaited.
  423. Args:
  424. url: The database url.
  425. Returns:
  426. An async database session.
  427. """
  428. global _AsyncSessionLocal
  429. if url not in _AsyncSessionLocal:
  430. _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
  431. bind=get_async_engine(url),
  432. class_=AsyncSession,
  433. expire_on_commit=False,
  434. autocommit=False,
  435. autoflush=False,
  436. )
  437. return _AsyncSessionLocal[url]()
  438. def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
  439. """Get a bare sqlalchemy session to interact with the database.
  440. Args:
  441. url: The database url.
  442. Returns:
  443. A database session.
  444. """
  445. return sqlalchemy.orm.Session(get_engine(url))