Bläddra i källkod

[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var (#4338)

* [ENG-3953] Support pydantic BaseModel (v1 and v2) as state var

Provide serializers and mutable proxy tracking for pydantic models directly.

* conditionally define v2 serializer

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>

* Add `MutableProxy._is_mutable_value` to avoid duplicate logic

* Conditionally import BaseModel to handle older pydantic v1 versions

* pre-commit fu

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Masen Furer 5 månader sedan
förälder
incheckning
a6b324bd3e
3 ändrade filer med 128 tillägg och 6 borttagningar
  1. 32 6
      reflex/state.py
  2. 47 0
      reflex/utils/serializers.py
  3. 49 0
      tests/units/test_state.py

+ 32 - 6
reflex/state.py

@@ -62,6 +62,13 @@ try:
 except ModuleNotFoundError:
 except ModuleNotFoundError:
     import pydantic
     import pydantic
 
 
+from pydantic import BaseModel as BaseModelV2
+
+try:
+    from pydantic.v1 import BaseModel as BaseModelV1
+except ModuleNotFoundError:
+    BaseModelV1 = BaseModelV2
+
 import wrapt
 import wrapt
 from redis.asyncio import Redis
 from redis.asyncio import Redis
 from redis.exceptions import ResponseError
 from redis.exceptions import ResponseError
@@ -1250,7 +1257,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if parent_state is not None:
             if parent_state is not None:
                 return getattr(parent_state, name)
                 return getattr(parent_state, name)
 
 
-        if isinstance(value, MutableProxy.__mutable_types__) and (
+        if MutableProxy._is_mutable_type(value) and (
             name in super().__getattribute__("base_vars") or name in backend_vars
             name in super().__getattribute__("base_vars") or name in backend_vars
         ):
         ):
             # track changes in mutable containers (list, dict, set, etc)
             # track changes in mutable containers (list, dict, set, etc)
@@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy):
         pydantic.BaseModel.__dict__
         pydantic.BaseModel.__dict__
     )
     )
 
 
-    __mutable_types__ = (list, dict, set, Base, DeclarativeBase)
+    # These types will be wrapped in MutableProxy
+    __mutable_types__ = (
+        list,
+        dict,
+        set,
+        Base,
+        DeclarativeBase,
+        BaseModelV2,
+        BaseModelV1,
+    )
 
 
     def __init__(self, wrapped: Any, state: BaseState, field_name: str):
     def __init__(self, wrapped: Any, state: BaseState, field_name: str):
         """Create a proxy for a mutable object that tracks changes.
         """Create a proxy for a mutable object that tracks changes.
@@ -3598,6 +3614,18 @@ class MutableProxy(wrapt.ObjectProxy):
         if wrapped is not None:
         if wrapped is not None:
             return wrapped(*args, **(kwargs or {}))
             return wrapped(*args, **(kwargs or {}))
 
 
+    @classmethod
+    def _is_mutable_type(cls, value: Any) -> bool:
+        """Check if a value is of a mutable type and should be wrapped.
+
+        Args:
+            value: The value to check.
+
+        Returns:
+            Whether the value is of a mutable type.
+        """
+        return isinstance(value, cls.__mutable_types__)
+
     def _wrap_recursive(self, value: Any) -> Any:
     def _wrap_recursive(self, value: Any) -> Any:
         """Wrap a value recursively if it is mutable.
         """Wrap a value recursively if it is mutable.
 
 
@@ -3608,9 +3636,7 @@ class MutableProxy(wrapt.ObjectProxy):
             The wrapped value.
             The wrapped value.
         """
         """
         # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
         # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
-        if isinstance(value, self.__mutable_types__) and not isinstance(
-            value, MutableProxy
-        ):
+        if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
             return type(self)(
             return type(self)(
                 wrapped=value,
                 wrapped=value,
                 state=self._self_state,
                 state=self._self_state,
@@ -3668,7 +3694,7 @@ class MutableProxy(wrapt.ObjectProxy):
                     self._wrap_recursive_decorator,
                     self._wrap_recursive_decorator,
                 )
                 )
 
 
-        if isinstance(value, self.__mutable_types__) and __name not in (
+        if self._is_mutable_type(value) and __name not in (
             "__wrapped__",
             "__wrapped__",
             "_self_state",
             "_self_state",
         ):
         ):

+ 47 - 0
reflex/utils/serializers.py

@@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict:
     }
     }
 
 
 
 
+try:
+    from pydantic.v1 import BaseModel as BaseModelV1
+
+    @serializer(to=dict)
+    def serialize_base_model_v1(model: BaseModelV1) -> dict:
+        """Serialize a pydantic v1 BaseModel instance.
+
+        Args:
+            model: The BaseModel to serialize.
+
+        Returns:
+            The serialized BaseModel.
+        """
+        return model.dict()
+
+    from pydantic import BaseModel as BaseModelV2
+
+    if BaseModelV1 is not BaseModelV2:
+
+        @serializer(to=dict)
+        def serialize_base_model_v2(model: BaseModelV2) -> dict:
+            """Serialize a pydantic v2 BaseModel instance.
+
+            Args:
+                model: The BaseModel to serialize.
+
+            Returns:
+                The serialized BaseModel.
+            """
+            return model.model_dump()
+except ImportError:
+    # Older pydantic v1 import
+    from pydantic import BaseModel as BaseModelV1
+
+    @serializer(to=dict)
+    def serialize_base_model_v1(model: BaseModelV1) -> dict:
+        """Serialize a pydantic v1 BaseModel instance.
+
+        Args:
+            model: The BaseModel to serialize.
+
+        Returns:
+            The serialized BaseModel.
+        """
+        return model.dict()
+
+
 @serializer
 @serializer
 def serialize_set(value: Set) -> list:
 def serialize_set(value: Set) -> list:
     """Serialize a set to a JSON serializable list.
     """Serialize a set to a JSON serializable list.

+ 49 - 0
tests/units/test_state.py

@@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock
 import pytest
 import pytest
 import pytest_asyncio
 import pytest_asyncio
 from plotly.graph_objects import Figure
 from plotly.graph_objects import Figure
+from pydantic import BaseModel as BaseModelV2
+from pydantic.v1 import BaseModel as BaseModelV1
 
 
 import reflex as rx
 import reflex as rx
 import reflex.config
 import reflex.config
@@ -3413,6 +3415,53 @@ def test_typed_state() -> None:
     _ = TypedState(field="str")
     _ = TypedState(field="str")
 
 
 
 
+class ModelV1(BaseModelV1):
+    """A pydantic BaseModel v1."""
+
+    foo: str = "bar"
+
+
+class ModelV2(BaseModelV2):
+    """A pydantic BaseModel v2."""
+
+    foo: str = "bar"
+
+
+@dataclasses.dataclass
+class ModelDC:
+    """A dataclass."""
+
+    foo: str = "bar"
+
+
+class PydanticState(rx.State):
+    """A state with pydantic BaseModel vars."""
+
+    v1: ModelV1 = ModelV1()
+    v2: ModelV2 = ModelV2()
+    dc: ModelDC = ModelDC()
+
+
+def test_mutable_models():
+    """Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
+    state = PydanticState()
+    assert isinstance(state.v1, MutableProxy)
+    state.v1.foo = "baz"
+    assert state.dirty_vars == {"v1"}
+    state.dirty_vars.clear()
+
+    assert isinstance(state.v2, MutableProxy)
+    state.v2.foo = "baz"
+    assert state.dirty_vars == {"v2"}
+    state.dirty_vars.clear()
+
+    # Not yet supported ENG-4083
+    # assert isinstance(state.dc, MutableProxy)
+    # state.dc.foo = "baz"
+    # assert state.dirty_vars == {"dc"}
+    # state.dirty_vars.clear()
+
+
 def test_get_value():
 def test_get_value():
     class GetValueState(rx.State):
     class GetValueState(rx.State):
         foo: str = "FOO"
         foo: str = "FOO"