model.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """Database built into Reflex."""
  2. from __future__ import annotations
  3. import os
  4. from collections import defaultdict
  5. from pathlib import Path
  6. from typing import Any, Optional
  7. import alembic.autogenerate
  8. import alembic.command
  9. import alembic.config
  10. import alembic.operations.ops
  11. import alembic.runtime.environment
  12. import alembic.script
  13. import alembic.util
  14. import sqlalchemy
  15. import sqlmodel
  16. from reflex import constants
  17. from reflex.base import Base
  18. from reflex.config import get_config
  19. from reflex.utils import console
  20. def get_engine(url: str | None = None):
  21. """Get the database engine.
  22. Args:
  23. url: the DB url to use.
  24. Returns:
  25. The database engine.
  26. Raises:
  27. ValueError: If the database url is None.
  28. """
  29. conf = get_config()
  30. url = url or conf.db_url
  31. if url is None:
  32. raise ValueError("No database url configured")
  33. if not Path(constants.ALEMBIC_CONFIG).exists():
  34. console.warn(
  35. "Database is not initialized, run [bold]reflex db init[/bold] first."
  36. )
  37. # Print the SQL queries if the log level is INFO or lower.
  38. echo_db_query = os.environ.get("SQLALCHEMY_ECHO") == "True"
  39. # Needed for the admin dash on sqlite.
  40. connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
  41. return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
  42. class Model(Base, sqlmodel.SQLModel):
  43. """Base class to define a table in the database."""
  44. # The primary key for the table.
  45. id: Optional[int] = sqlmodel.Field(primary_key=True)
  46. def __init_subclass__(cls):
  47. """Drop the default primary key field if any primary key field is defined."""
  48. non_default_primary_key_fields = [
  49. field_name
  50. for field_name, field in cls.__fields__.items()
  51. if field_name != "id" and getattr(field.field_info, "primary_key", None)
  52. ]
  53. if non_default_primary_key_fields:
  54. cls.__fields__.pop("id", None)
  55. super().__init_subclass__()
  56. def dict(self, **kwargs):
  57. """Convert the object to a dictionary.
  58. Args:
  59. kwargs: Ignored but needed for compatibility.
  60. Returns:
  61. The object as a dictionary.
  62. """
  63. return {name: getattr(self, name) for name in self.__fields__}
  64. @staticmethod
  65. def create_all():
  66. """Create all the tables."""
  67. engine = get_engine()
  68. sqlmodel.SQLModel.metadata.create_all(engine)
  69. @staticmethod
  70. def get_db_engine():
  71. """Get the database engine.
  72. Returns:
  73. The database engine.
  74. """
  75. return get_engine()
  76. @staticmethod
  77. def _alembic_config():
  78. """Get the alembic configuration and script_directory.
  79. Returns:
  80. tuple of (config, script_directory)
  81. """
  82. config = alembic.config.Config(constants.ALEMBIC_CONFIG)
  83. return config, alembic.script.ScriptDirectory(
  84. config.get_main_option("script_location", default="version"),
  85. )
  86. @staticmethod
  87. def _alembic_render_item(
  88. type_: str,
  89. obj: Any,
  90. autogen_context: "alembic.autogenerate.api.AutogenContext",
  91. ):
  92. """Alembic render_item hook call.
  93. This method is called to provide python code for the given obj,
  94. but currently it is only used to add `sqlmodel` to the import list
  95. when generating migration scripts.
  96. See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
  97. Args:
  98. type_: One of "schema", "table", "column", "index",
  99. "unique_constraint", or "foreign_key_constraint".
  100. obj: The object being rendered.
  101. autogen_context: Shared AutogenContext passed to each render_item call.
  102. Returns:
  103. False - Indicating that the default rendering should be used.
  104. """
  105. autogen_context.imports.add("import sqlmodel")
  106. return False
  107. @classmethod
  108. def alembic_init(cls):
  109. """Initialize alembic for the project."""
  110. alembic.command.init(
  111. config=alembic.config.Config(constants.ALEMBIC_CONFIG),
  112. directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"),
  113. )
  114. @classmethod
  115. def alembic_autogenerate(
  116. cls,
  117. connection: sqlalchemy.engine.Connection,
  118. message: str | None = None,
  119. write_migration_scripts: bool = True,
  120. ) -> bool:
  121. """Generate migration scripts for alembic-detectable changes.
  122. Args:
  123. connection: SQLAlchemy connection to use when detecting changes.
  124. message: Human readable identifier describing the generated revision.
  125. write_migration_scripts: If True, write autogenerated revisions to script directory.
  126. Returns:
  127. True when changes have been detected.
  128. """
  129. if not Path(constants.ALEMBIC_CONFIG).exists():
  130. return False
  131. config, script_directory = cls._alembic_config()
  132. revision_context = alembic.autogenerate.api.RevisionContext(
  133. config=config,
  134. script_directory=script_directory,
  135. command_args=defaultdict(
  136. lambda: None,
  137. autogenerate=True,
  138. head="head",
  139. message=message,
  140. ),
  141. )
  142. writer = alembic.autogenerate.rewriter.Rewriter()
  143. @writer.rewrites(alembic.operations.ops.AddColumnOp)
  144. def render_add_column_with_server_default(context, revision, op):
  145. # Carry the sqlmodel default as server_default so that newly added
  146. # columns get the desired default value in existing rows.
  147. if op.column.default is not None and op.column.server_default is None:
  148. op.column.server_default = sqlalchemy.DefaultClause(
  149. sqlalchemy.sql.expression.literal(op.column.default.arg),
  150. )
  151. return op
  152. def run_autogenerate(rev, context):
  153. revision_context.run_autogenerate(rev, context)
  154. return []
  155. with alembic.runtime.environment.EnvironmentContext(
  156. config=config,
  157. script=script_directory,
  158. fn=run_autogenerate,
  159. ) as env:
  160. env.configure(
  161. connection=connection,
  162. target_metadata=sqlmodel.SQLModel.metadata,
  163. render_item=cls._alembic_render_item,
  164. process_revision_directives=writer, # type: ignore
  165. )
  166. env.run_migrations()
  167. changes_detected = False
  168. if revision_context.generated_revisions:
  169. upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
  170. if upgrade_ops is not None:
  171. changes_detected = bool(upgrade_ops.ops)
  172. if changes_detected and write_migration_scripts:
  173. # Must iterate the generator to actually write the scripts.
  174. _ = tuple(revision_context.generate_scripts())
  175. return changes_detected
  176. @classmethod
  177. def _alembic_upgrade(
  178. cls,
  179. connection: sqlalchemy.engine.Connection,
  180. to_rev: str = "head",
  181. ) -> None:
  182. """Apply alembic migrations up to the given revision.
  183. Args:
  184. connection: SQLAlchemy connection to use when performing upgrade.
  185. to_rev: Revision to migrate towards.
  186. """
  187. config, script_directory = cls._alembic_config()
  188. def run_upgrade(rev, context):
  189. return script_directory._upgrade_revs(to_rev, rev)
  190. with alembic.runtime.environment.EnvironmentContext(
  191. config=config,
  192. script=script_directory,
  193. fn=run_upgrade,
  194. ) as env:
  195. env.configure(connection=connection)
  196. env.run_migrations()
  197. @classmethod
  198. def migrate(cls, autogenerate: bool = False) -> bool | None:
  199. """Execute alembic migrations for all sqlmodel Model classes.
  200. If alembic is not installed or has not been initialized for the project,
  201. then no action is performed.
  202. If there are no revisions currently tracked by alembic, then
  203. an initial revision will be created based on sqlmodel metadata.
  204. If models in the app have changed in incompatible ways that alembic
  205. cannot automatically generate revisions for, the app may not be able to
  206. start up until migration scripts have been corrected by hand.
  207. Args:
  208. autogenerate: If True, generate migration script and use it to upgrade schema
  209. (otherwise, just bring the schema to current "head" revision).
  210. Returns:
  211. True - indicating the process was successful.
  212. None - indicating the process was skipped.
  213. """
  214. if not Path(constants.ALEMBIC_CONFIG).exists():
  215. return
  216. with cls.get_db_engine().connect() as connection:
  217. cls._alembic_upgrade(connection=connection)
  218. if autogenerate:
  219. changes_detected = cls.alembic_autogenerate(connection=connection)
  220. if changes_detected:
  221. cls._alembic_upgrade(connection=connection)
  222. connection.commit()
  223. return True
  224. @classmethod
  225. @property
  226. def select(cls):
  227. """Select rows from the table.
  228. Returns:
  229. The select statement.
  230. """
  231. return sqlmodel.select(cls)
  232. def session(url: str | None = None) -> sqlmodel.Session:
  233. """Get a session to interact with the database.
  234. Args:
  235. url: The database url.
  236. Returns:
  237. A database session.
  238. """
  239. return sqlmodel.Session(get_engine(url))