model.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. """Database built into Pynecone."""
  2. from typing import Optional
  3. import sqlmodel
  4. from pynecone.base import Base
  5. from pynecone.config import get_config
  def get_engine():
      """Build a database engine from the configured database URL.

      Returns:
          A new engine created with ``sqlmodel.create_engine`` and echo disabled.

      Raises:
          ValueError: If the config's ``db_url`` is empty or None.
      """
      config = get_config()
      db_url = config.db_url
      if db_url:
          # echo=False keeps the engine from logging every SQL statement it emits.
          return sqlmodel.create_engine(db_url, echo=False)
      raise ValueError("No database url in config")
  class _ClassProperty(property):
      """A read-only property whose getter receives the owning class.

      Stacking ``@classmethod`` on ``@property`` only works on Python 3.9-3.12
      (deprecated in 3.11, removed in 3.13), so this descriptor provides the same
      "attribute on the class" behaviour on every version. It subclasses
      ``property`` so pydantic's model metaclass leaves it alone instead of
      treating it as a field default.
      """

      def __get__(self, instance, owner=None):
          # Class access passes owner; fall back to the instance's type otherwise.
          if owner is None:
              owner = type(instance)
          return self.fget(owner)


  class Model(Base, sqlmodel.SQLModel):
      """Base class to define a table in the database."""

      # The primary key for the table.
      id: Optional[int] = sqlmodel.Field(primary_key=True)

      def __init_subclass__(cls, **kwargs):
          """Drop the default primary key field if any primary key field is defined.

          Args:
              kwargs: Class keyword arguments, forwarded unchanged to the parent
                  ``__init_subclass__``.
          """
          # Any field other than the built-in ``id`` that declares primary_key=True.
          non_default_primary_key_fields = [
              field_name
              for field_name, field in cls.__fields__.items()
              if field_name != "id" and getattr(field.field_info, "primary_key", None)
          ]
          if non_default_primary_key_fields:
              # The subclass supplies its own key, so the default ``id`` must go.
              cls.__fields__.pop("id", None)

          super().__init_subclass__(**kwargs)

      def dict(self, **kwargs):
          """Convert the object to a dictionary.

          Only the model's declared fields are included.

          Args:
              kwargs: Ignored but needed for compatibility.

          Returns:
              The object as a dictionary mapping field names to their values.
          """
          return {name: getattr(self, name) for name in self.__fields__}

      @staticmethod
      def create_all():
          """Create all the tables.

          Uses the engine from ``get_engine`` and every table registered on
          ``sqlmodel.SQLModel.metadata``.
          """
          engine = get_engine()
          sqlmodel.SQLModel.metadata.create_all(engine)

      @_ClassProperty
      def select(cls):
          """Select rows from the table.

          Accessed as a class attribute (``MyModel.select``), not called.

          Returns:
              The select statement.
          """
          return sqlmodel.select(cls)
  def session(url=None):
      """Open a session for interacting with the database.

      Args:
          url: Optional database URL. When given, a new engine is created for it;
              otherwise the engine returned by ``get_engine`` is used.

      Returns:
          A new ``sqlmodel.Session`` bound to the selected engine.
      """
      engine = get_engine() if url is None else sqlmodel.create_engine(url)
      return sqlmodel.Session(engine)