base.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. """Define the base Reflex class."""
  2. from __future__ import annotations
  3. from typing import Any, Dict
  4. import pydantic
  5. from pydantic.fields import ModelField
  6. class Base(pydantic.BaseModel):
  7. """The base class subclassed by all Reflex classes.
  8. This class wraps Pydantic and provides common methods such as
  9. serialization and setting fields.
  10. Any data structure that needs to be transferred between the
  11. frontend and backend should subclass this class.
  12. """
  13. class Config:
  14. """Pydantic config."""
  15. arbitrary_types_allowed = True
  16. use_enum_values = True
  17. extra = "allow"
  18. def json(self) -> str:
  19. """Convert the object to a json string.
  20. Returns:
  21. The object as a json string.
  22. """
  23. return self.__config__.json_dumps(self.dict(), default=list)
  24. def set(self, **kwargs):
  25. """Set multiple fields and return the object.
  26. Args:
  27. **kwargs: The fields and values to set.
  28. Returns:
  29. The object with the fields set.
  30. """
  31. for key, value in kwargs.items():
  32. setattr(self, key, value)
  33. return self
  34. @classmethod
  35. def get_fields(cls) -> Dict[str, Any]:
  36. """Get the fields of the object.
  37. Returns:
  38. The fields of the object.
  39. """
  40. return cls.__fields__
  41. @classmethod
  42. def add_field(cls, var: Any, default_value: Any):
  43. """Add a pydantic field after class definition.
  44. Used by State.add_var() to correctly handle the new variable.
  45. Args:
  46. var: The variable to add a pydantic field for.
  47. default_value: The default value of the field
  48. """
  49. new_field = ModelField.infer(
  50. name=var.name,
  51. value=default_value,
  52. annotation=var.type_,
  53. class_validators=None,
  54. config=cls.__config__,
  55. )
  56. cls.__fields__.update({var.name: new_field})
  57. def get_value(self, key: str) -> Any:
  58. """Get the value of a field.
  59. Args:
  60. key: The key of the field.
  61. Returns:
  62. The value of the field.
  63. """
  64. return self._get_value(
  65. key,
  66. to_dict=True,
  67. by_alias=False,
  68. include=None,
  69. exclude=None,
  70. exclude_unset=False,
  71. exclude_defaults=False,
  72. exclude_none=False,
  73. )