Jelajahi Sumber

use serializer for state update and rework serializers (#3934)

* use serializer for state update and rework serializers

* format
Khaleel Al-Adhami 8 bulan lalu
induk
melakukan
a57095ffe8

+ 1 - 23
reflex/state.py

@@ -8,7 +8,6 @@ import copy
 import dataclasses
 import dataclasses
 import functools
 import functools
 import inspect
 import inspect
-import json
 import os
 import os
 import uuid
 import uuid
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
@@ -206,27 +205,6 @@ class RouterData:
         object.__setattr__(self, "headers", HeaderData(router_data))
         object.__setattr__(self, "headers", HeaderData(router_data))
         object.__setattr__(self, "page", PageData(router_data))
         object.__setattr__(self, "page", PageData(router_data))
 
 
-    def toJson(self) -> str:
-        """Convert the object to a JSON string.
-
-        Returns:
-            The JSON string.
-        """
-        return json.dumps(dataclasses.asdict(self))
-
-
-@serializer
-def serialize_routerdata(value: RouterData) -> str:
-    """Serialize a RouterData instance.
-
-    Args:
-        value: The RouterData to serialize.
-
-    Returns:
-        The serialized RouterData.
-    """
-    return value.toJson()
-
 
 
 def _no_chain_background_task(
 def _no_chain_background_task(
     state_cls: Type["BaseState"], name: str, fn: Callable
     state_cls: Type["BaseState"], name: str, fn: Callable
@@ -2415,7 +2393,7 @@ class StateUpdate:
         Returns:
         Returns:
             The state update as a JSON string.
             The state update as a JSON string.
         """
         """
-        return json.dumps(dataclasses.asdict(self))
+        return format.json_dumps(dataclasses.asdict(self))
 
 
 
 
 class StateManager(Base, ABC):
 class StateManager(Base, ABC):

+ 2 - 22
reflex/utils/format.py

@@ -2,7 +2,6 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-import dataclasses
 import inspect
 import inspect
 import json
 import json
 import os
 import os
@@ -410,22 +409,11 @@ def format_props(*single_props, **key_value_props) -> list[str]:
 
 
     return [
     return [
         (
         (
-            f"{name}={format_prop(prop)}"
-            if isinstance(prop, Var) and not isinstance(prop, Var)
-            else (
-                f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}"
-            )
+            f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}"
         )
         )
         for name, prop in sorted(key_value_props.items())
         for name, prop in sorted(key_value_props.items())
         if prop is not None
         if prop is not None
-    ] + [
-        (
-            str(prop)
-            if isinstance(prop, Var) and not isinstance(prop, Var)
-            else f"{str(LiteralVar.create(prop))}"
-        )
-        for prop in single_props
-    ]
+    ] + [(f"{str(LiteralVar.create(prop))}") for prop in single_props]
 
 
 
 
 def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
 def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
@@ -623,14 +611,6 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
     if isinstance(value, dict):
     if isinstance(value, dict):
         return {k: format_state(v, k) for k, v in value.items()}
         return {k: format_state(v, k) for k, v in value.items()}
 
 
-    # Hand dataclasses.
-    if dataclasses.is_dataclass(value):
-        if isinstance(value, type):
-            raise TypeError(
-                f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass."
-            )
-        return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}
-
     # Handle lists, sets, typles.
     # Handle lists, sets, typles.
     if isinstance(value, types.StateIterBases):
     if isinstance(value, types.StateIterBases):
         return [format_state(v) for v in value]
         return [format_state(v) for v in value]

+ 12 - 17
reflex/utils/serializers.py

@@ -2,6 +2,7 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import dataclasses
 import functools
 import functools
 import json
 import json
 import warnings
 import warnings
@@ -29,7 +30,7 @@ from reflex.utils import types
 
 
 # Mapping from type to a serializer.
 # Mapping from type to a serializer.
 # The serializer should convert the type to a JSON object.
 # The serializer should convert the type to a JSON object.
-SerializedType = Union[str, bool, int, float, list, dict]
+SerializedType = Union[str, bool, int, float, list, dict, None]
 
 
 
 
 Serializer = Callable[[Type], SerializedType]
 Serializer = Callable[[Type], SerializedType]
@@ -124,6 +125,8 @@ def serialize(
 
 
     # If there is no serializer, return None.
     # If there is no serializer, return None.
     if serializer is None:
     if serializer is None:
+        if dataclasses.is_dataclass(value) and not isinstance(value, type):
+            return serialize(dataclasses.asdict(value))
         if get_type:
         if get_type:
             return None, None
             return None, None
         return None
         return None
@@ -225,7 +228,7 @@ def serialize_str(value: str) -> str:
 
 
 
 
 @serializer
 @serializer
-def serialize_primitive(value: Union[bool, int, float, None]) -> str:
+def serialize_primitive(value: Union[bool, int, float, None]):
     """Serialize a primitive type.
     """Serialize a primitive type.
 
 
     Args:
     Args:
@@ -234,13 +237,11 @@ def serialize_primitive(value: Union[bool, int, float, None]) -> str:
     Returns:
     Returns:
         The serialized number/bool/None.
         The serialized number/bool/None.
     """
     """
-    from reflex.utils import format
-
-    return format.json_dumps(value)
+    return value
 
 
 
 
 @serializer
 @serializer
-def serialize_base(value: Base) -> str:
+def serialize_base(value: Base) -> dict:
     """Serialize a Base instance.
     """Serialize a Base instance.
 
 
     Args:
     Args:
@@ -249,13 +250,11 @@ def serialize_base(value: Base) -> str:
     Returns:
     Returns:
         The serialized Base.
         The serialized Base.
     """
     """
-    from reflex.vars import LiteralVar
-
-    return str(LiteralVar.create(value))
+    return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}
 
 
 
 
 @serializer
 @serializer
-def serialize_list(value: Union[List, Tuple, Set]) -> str:
+def serialize_list(value: Union[List, Tuple, Set]) -> list:
     """Serialize a list to a JSON string.
     """Serialize a list to a JSON string.
 
 
     Args:
     Args:
@@ -264,13 +263,11 @@ def serialize_list(value: Union[List, Tuple, Set]) -> str:
     Returns:
     Returns:
         The serialized list.
         The serialized list.
     """
     """
-    from reflex.vars import LiteralArrayVar
-
-    return str(LiteralArrayVar.create(value))
+    return [serialize(item) for item in value]
 
 
 
 
 @serializer
 @serializer
-def serialize_dict(prop: Dict[str, Any]) -> str:
+def serialize_dict(prop: Dict[str, Any]) -> dict:
     """Serialize a dictionary to a JSON string.
     """Serialize a dictionary to a JSON string.
 
 
     Args:
     Args:
@@ -279,9 +276,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
     Returns:
     Returns:
         The serialized dictionary.
         The serialized dictionary.
     """
     """
-    from reflex.vars import LiteralObjectVar
-
-    return str(LiteralObjectVar.create(prop))
+    return {k: serialize(v) for k, v in prop.items()}
 
 
 
 
 @serializer(to=str)
 @serializer(to=str)

+ 13 - 15
reflex/vars/base.py

@@ -936,21 +936,6 @@ class Var(Generic[VAR_TYPE]):
 OUTPUT = TypeVar("OUTPUT", bound=Var)
 OUTPUT = TypeVar("OUTPUT", bound=Var)
 
 
 
 
-def _encode_var(value: Var) -> str:
-    """Encode the state name into a formatted var.
-
-    Args:
-        value: The value to encode the state name into.
-
-    Returns:
-        The encoded var.
-    """
-    return f"{value}"
-
-
-serializers.serializer(_encode_var)
-
-
 class LiteralVar(Var):
 class LiteralVar(Var):
     """Base class for immutable literal vars."""
     """Base class for immutable literal vars."""
 
 
@@ -1101,6 +1086,19 @@ class LiteralVar(Var):
         )
         )
 
 
 
 
+@serializers.serializer
+def serialize_literal(value: LiteralVar):
+    """Serialize a Literal type.
+
+    Args:
+        value: The Literal to serialize.
+
+    Returns:
+        The serialized Literal.
+    """
+    return serializers.serialize(value._var_value)
+
+
 P = ParamSpec("P")
 P = ParamSpec("P")
 T = TypeVar("T")
 T = TypeVar("T")
 
 

+ 21 - 21
tests/utils/test_format.py

@@ -352,28 +352,28 @@ def test_format_match(
     "prop,formatted",
     "prop,formatted",
     [
     [
         ("string", '"string"'),
         ("string", '"string"'),
-        ("{wrapped_string}", "{wrapped_string}"),
-        (True, "{true}"),
-        (False, "{false}"),
-        (123, "{123}"),
-        (3.14, "{3.14}"),
-        ([1, 2, 3], "{[1, 2, 3]}"),
-        (["a", "b", "c"], '{["a", "b", "c"]}'),
-        ({"a": 1, "b": 2, "c": 3}, '{({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })}'),
-        ({"a": 'foo "bar" baz'}, r'{({ ["a"] : "foo \"bar\" baz" })}'),
+        ("{wrapped_string}", '"{wrapped_string}"'),
+        (True, "true"),
+        (False, "false"),
+        (123, "123"),
+        (3.14, "3.14"),
+        ([1, 2, 3], "[1, 2, 3]"),
+        (["a", "b", "c"], '["a", "b", "c"]'),
+        ({"a": 1, "b": 2, "c": 3}, '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })'),
+        ({"a": 'foo "bar" baz'}, r'({ ["a"] : "foo \"bar\" baz" })'),
         (
         (
             {
             {
                 "a": 'foo "{ "bar" }" baz',
                 "a": 'foo "{ "bar" }" baz',
                 "b": Var(_js_expr="val", _var_type=str).guess_type(),
                 "b": Var(_js_expr="val", _var_type=str).guess_type(),
             },
             },
-            r'{({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })}',
+            r'({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 args_spec=lambda: [],
                 args_spec=lambda: [],
             ),
             ),
-            '{(...args) => addEvents([Event("mock_event", {})], args, {})}',
+            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({  })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -382,7 +382,7 @@ def test_format_match(
                         handler=EventHandler(fn=mock_event),
                         handler=EventHandler(fn=mock_event),
                         args=(
                         args=(
                             (
                             (
-                                LiteralVar.create("arg"),
+                                Var(_js_expr="arg"),
                                 Var(
                                 Var(
                                     _js_expr="_e",
                                     _js_expr="_e",
                                 )
                                 )
@@ -394,7 +394,7 @@ def test_format_match(
                 ],
                 ],
                 args_spec=lambda e: [e.target.value],
                 args_spec=lambda e: [e.target.value],
             ),
             ),
-            '{(_e) => addEvents([Event("mock_event", {"arg":_e["target"]["value"]})], [_e], {})}',
+            '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({  })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -402,7 +402,7 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"stopPropagation": True},
                 event_actions={"stopPropagation": True},
             ),
             ),
-            '{(...args) => addEvents([Event("mock_event", {})], args, {"stopPropagation": true})}',
+            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({ ["stopPropagation"] : true })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -410,9 +410,9 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"preventDefault": True},
                 event_actions={"preventDefault": True},
             ),
             ),
-            '{(...args) => addEvents([Event("mock_event", {})], args, {"preventDefault": true})}',
+            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({ ["preventDefault"] : true })))))',
         ),
         ),
-        ({"a": "red", "b": "blue"}, '{({ ["a"] : "red", ["b"] : "blue" })}'),
+        ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),
         (
         (
             Var(
             Var(
@@ -427,15 +427,15 @@ def test_format_match(
         ),
         ),
         (
         (
             {"a": Var(_js_expr="val", _var_type=str).guess_type()},
             {"a": Var(_js_expr="val", _var_type=str).guess_type()},
-            '{({ ["a"] : val })}',
+            '({ ["a"] : val })',
         ),
         ),
         (
         (
             {"a": Var(_js_expr='"val"', _var_type=str).guess_type()},
             {"a": Var(_js_expr='"val"', _var_type=str).guess_type()},
-            '{({ ["a"] : "val" })}',
+            '({ ["a"] : "val" })',
         ),
         ),
         (
         (
             {"a": Var(_js_expr='state.colors["val"]', _var_type=str).guess_type()},
             {"a": Var(_js_expr='state.colors["val"]', _var_type=str).guess_type()},
-            '{({ ["a"] : state.colors["val"] })}',
+            '({ ["a"] : state.colors["val"] })',
         ),
         ),
         # tricky real-world case from markdown component
         # tricky real-world case from markdown component
         (
         (
@@ -444,7 +444,7 @@ def test_format_match(
                     _js_expr=f"(({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />)"
                     _js_expr=f"(({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />)"
                 ),
                 ),
             },
             },
-            '{({ ["h1"] : (({node, ...props}) => <Heading {...props} as={"h1"} />) })}',
+            '({ ["h1"] : (({node, ...props}) => <Heading {...props} as={"h1"} />) })',
         ),
         ),
     ],
     ],
 )
 )
@@ -455,7 +455,7 @@ def test_format_prop(prop: Var, formatted: str):
         prop: The prop to test.
         prop: The prop to test.
         formatted: The expected formatted value.
         formatted: The expected formatted value.
     """
     """
-    assert format.format_prop(prop) == formatted
+    assert format.format_prop(LiteralVar.create(prop)) == formatted
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(

+ 35 - 24
tests/utils/test_serializers.py

@@ -8,7 +8,7 @@ import pytest
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.core.colors import Color
 from reflex.components.core.colors import Color
 from reflex.utils import serializers
 from reflex.utils import serializers
-from reflex.vars.base import LiteralVar, Var
+from reflex.vars.base import LiteralVar
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -123,48 +123,59 @@ class BaseSubclass(Base):
     "value,expected",
     "value,expected",
     [
     [
         ("test", "test"),
         ("test", "test"),
-        (1, "1"),
-        (1.0, "1.0"),
-        (True, "true"),
-        (False, "false"),
-        (None, "null"),
-        ([1, 2, 3], "[1, 2, 3]"),
-        ([1, "2", 3.0], '[1, "2", 3.0]'),
-        ([{"key": 1}, {"key": 2}], '[({ ["key"] : 1 }), ({ ["key"] : 2 })]'),
+        (1, 1),
+        (1.0, 1.0),
+        (True, True),
+        (False, False),
+        (None, None),
+        ([1, 2, 3], [1, 2, 3]),
+        ([1, "2", 3.0], [1, "2", 3.0]),
+        ([{"key": 1}, {"key": 2}], [{"key": 1}, {"key": 2}]),
         (StrEnum.FOO, "foo"),
         (StrEnum.FOO, "foo"),
-        ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]'),
+        ([StrEnum.FOO, StrEnum.BAR], ["foo", "bar"]),
         (
         (
             {"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]},
             {"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]},
-            '({ ["key1"] : [1, 2, 3], ["key2"] : ["foo", "bar"] })',
+            {
+                "key1": [1, 2, 3],
+                "key2": ["foo", "bar"],
+            },
         ),
         ),
         (EnumWithPrefix.FOO, "prefix_foo"),
         (EnumWithPrefix.FOO, "prefix_foo"),
-        ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], '["prefix_foo", "prefix_bar"]'),
+        ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], ["prefix_foo", "prefix_bar"]),
         (
         (
             {"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR},
             {"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR},
-            '({ ["key1"] : "prefix_foo", ["key2"] : "prefix_bar" })',
+            {
+                "key1": "prefix_foo",
+                "key2": "prefix_bar",
+            },
         ),
         ),
         (TestEnum.FOO, "foo"),
         (TestEnum.FOO, "foo"),
-        ([TestEnum.FOO, TestEnum.BAR], '["foo", "bar"]'),
+        ([TestEnum.FOO, TestEnum.BAR], ["foo", "bar"]),
         (
         (
             {"key1": TestEnum.FOO, "key2": TestEnum.BAR},
             {"key1": TestEnum.FOO, "key2": TestEnum.BAR},
-            '({ ["key1"] : "foo", ["key2"] : "bar" })',
+            {
+                "key1": "foo",
+                "key2": "bar",
+            },
         ),
         ),
         (
         (
             BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
             BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
-            '({ ["ts"] : "1 day, 0:00:01.000001" })',
+            {
+                "ts": "1 day, 0:00:01.000001",
+            },
         ),
         ),
         (
         (
-            [1, LiteralVar.create("hi"), Var(_js_expr="bye")],
-            '[1, "hi", bye]',
+            [1, LiteralVar.create("hi")],
+            [1, "hi"],
         ),
         ),
         (
         (
-            (1, LiteralVar.create("hi"), Var(_js_expr="bye")),
-            '[1, "hi", bye]',
+            (1, LiteralVar.create("hi")),
+            [1, "hi"],
         ),
         ),
-        ({1: 2, 3: 4}, "({ [1] : 2, [3] : 4 })"),
+        ({1: 2, 3: 4}, {1: 2, 3: 4}),
         (
         (
-            {1: LiteralVar.create("hi"), 3: Var(_js_expr="bye")},
-            '({ [1] : "hi", [3] : bye })',
+            {1: LiteralVar.create("hi")},
+            {1: "hi"},
         ),
         ),
         (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"),
         (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"),
         (datetime.date(2021, 1, 1), "2021-01-01"),
         (datetime.date(2021, 1, 1), "2021-01-01"),
@@ -172,7 +183,7 @@ class BaseSubclass(Base):
         (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
         (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
         (
         (
             [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
             [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
-            '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
+            ["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"],
         ),
         ),
         (Color(color="slate", shade=1), "var(--slate-1)"),
         (Color(color="slate", shade=1), "var(--slate-1)"),
         (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),
         (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),