base.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """Define the base Reflex class."""
  2. from __future__ import annotations
  3. import os
  4. from typing import TYPE_CHECKING, Any, List, Type
  5. try:
  6. import pydantic.v1.main as pydantic_main
  7. from pydantic.v1 import BaseModel
  8. from pydantic.v1.fields import ModelField
  9. except ModuleNotFoundError:
  10. if not TYPE_CHECKING:
  11. import pydantic.main as pydantic_main
  12. from pydantic import BaseModel
  13. from pydantic.fields import ModelField # type: ignore
  14. def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
  15. """Ensure that the field's name does not shadow an existing attribute of the model.
  16. Args:
  17. bases: List of base models to check for shadowed attrs.
  18. field_name: name of attribute
  19. Raises:
  20. VarNameError: If state var field shadows another in its parent state
  21. """
  22. from reflex.utils.exceptions import VarNameError
  23. # can't use reflex.config.environment here cause of circular import
  24. reload = os.getenv("__RELOAD_CONFIG", "").lower() == "true"
  25. base = None
  26. try:
  27. for base in bases:
  28. if not reload and getattr(base, field_name, None):
  29. pass
  30. except TypeError as te:
  31. raise VarNameError(
  32. f'State var "{field_name}" in {base} has been shadowed by a substate var; '
  33. f'use a different field name instead".'
  34. ) from te
# Monkeypatch pydantic's validate_field_name with the function defined in this
# module, so that validation of shadowed state vars is skipped when the app is
# reloaded via utils.prerequisites.get_app(reload=True).
pydantic_main.validate_field_name = validate_field_name  # type: ignore
  38. if TYPE_CHECKING:
  39. from reflex.vars import Var
  40. class Base(BaseModel): # pyright: ignore [reportUnboundVariable]
  41. """The base class subclassed by all Reflex classes.
  42. This class wraps Pydantic and provides common methods such as
  43. serialization and setting fields.
  44. Any data structure that needs to be transferred between the
  45. frontend and backend should subclass this class.
  46. """
  47. class Config:
  48. """Pydantic config."""
  49. arbitrary_types_allowed = True
  50. use_enum_values = True
  51. extra = "allow"
  52. def json(self) -> str:
  53. """Convert the object to a json string.
  54. Returns:
  55. The object as a json string.
  56. """
  57. from reflex.utils.serializers import serialize
  58. return self.__config__.json_dumps( # type: ignore
  59. self.dict(),
  60. default=serialize,
  61. )
  62. def set(self, **kwargs):
  63. """Set multiple fields and return the object.
  64. Args:
  65. **kwargs: The fields and values to set.
  66. Returns:
  67. The object with the fields set.
  68. """
  69. for key, value in kwargs.items():
  70. setattr(self, key, value)
  71. return self
  72. @classmethod
  73. def get_fields(cls) -> dict[str, ModelField]:
  74. """Get the fields of the object.
  75. Returns:
  76. The fields of the object.
  77. """
  78. return cls.__fields__
  79. @classmethod
  80. def add_field(cls, var: Var, default_value: Any):
  81. """Add a pydantic field after class definition.
  82. Used by State.add_var() to correctly handle the new variable.
  83. Args:
  84. var: The variable to add a pydantic field for.
  85. default_value: The default value of the field
  86. """
  87. var_name = var._var_field_name
  88. new_field = ModelField.infer(
  89. name=var_name,
  90. value=default_value,
  91. annotation=var._var_type,
  92. class_validators=None,
  93. config=cls.__config__, # type: ignore
  94. )
  95. cls.__fields__.update({var_name: new_field})
  96. def get_value(self, key: str) -> Any:
  97. """Get the value of a field.
  98. Args:
  99. key: The key of the field.
  100. Returns:
  101. The value of the field.
  102. """
  103. if isinstance(key, str):
  104. # Seems like this function signature was wrong all along?
  105. # If the user wants a field that we know of, get it and pass it off to _get_value
  106. return getattr(self, key, key)
  107. return key