model.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """Database built into Pynecone."""
  2. from typing import Optional
  3. import sqlmodel
  4. from pynecone.base import Base
  5. from pynecone.config import get_config
  6. def get_engine():
  7. """Get the database engine.
  8. Returns:
  9. The database engine.
  10. Raises:
  11. ValueError: If the database url is None.
  12. """
  13. url = get_config().db_url
  14. enable_admin = get_config().enable_admin
  15. if not url:
  16. raise ValueError("No database url in config")
  17. return sqlmodel.create_engine(
  18. url,
  19. echo=False,
  20. connect_args={"check_same_thread": False} if enable_admin else {},
  21. )
  22. class Model(Base, sqlmodel.SQLModel):
  23. """Base class to define a table in the database."""
  24. # The primary key for the table.
  25. id: Optional[int] = sqlmodel.Field(primary_key=True)
  26. def __init_subclass__(cls):
  27. """Drop the default primary key field if any primary key field is defined."""
  28. non_default_primary_key_fields = [
  29. field_name
  30. for field_name, field in cls.__fields__.items()
  31. if field_name != "id" and getattr(field.field_info, "primary_key", None)
  32. ]
  33. if non_default_primary_key_fields:
  34. cls.__fields__.pop("id", None)
  35. super().__init_subclass__()
  36. def dict(self, **kwargs):
  37. """Convert the object to a dictionary.
  38. Args:
  39. kwargs: Ignored but needed for compatibility.
  40. Returns:
  41. The object as a dictionary.
  42. """
  43. return {name: getattr(self, name) for name in self.__fields__}
  44. @staticmethod
  45. def create_all():
  46. """Create all the tables."""
  47. engine = get_engine()
  48. sqlmodel.SQLModel.metadata.create_all(engine)
  49. @staticmethod
  50. def get_db_engine():
  51. """Get the database engine.
  52. Returns:
  53. The database engine.
  54. """
  55. return get_engine()
  56. @classmethod
  57. @property
  58. def select(cls):
  59. """Select rows from the table.
  60. Returns:
  61. The select statement.
  62. """
  63. return sqlmodel.select(cls)
  64. def session(url=None):
  65. """Get a session to interact with the database.
  66. Args:
  67. url: The database url.
  68. Returns:
  69. A database session.
  70. """
  71. enable_admin = get_config().enable_admin
  72. if url is not None:
  73. return sqlmodel.Session(
  74. sqlmodel.create_engine(
  75. url,
  76. connect_args={"check_same_thread": False} if enable_admin else {},
  77. ),
  78. )
  79. engine = get_engine()
  80. return sqlmodel.Session(engine)