ソースを参照

remove format_state and override behavior for bare (#3979)

* remove format_state and override behavior for bare

* pass the test cases

* only do one level of dicting dataclasses

* remove dict and replace list with set

* delete unnecessary serialize calls

* remove serialize for mutable proxy

* dang it darglint
Khaleel Al-Adhami 8 ヶ月 前
コミット
0ab161c119

+ 1 - 1
reflex/compiler/utils.py

@@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict:
         initial_state = state(_reflex_internal_init=True).dict(
         initial_state = state(_reflex_internal_init=True).dict(
             initial=True, include_computed=False
             initial=True, include_computed=False
         )
         )
-    return format.format_state(initial_state)
+    return initial_state
 
 
 
 
 def _compile_client_storage_field(
 def _compile_client_storage_field(

+ 3 - 1
reflex/components/base/bare.py

@@ -7,7 +7,7 @@ from typing import Any, Iterator
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.components.tags.tagless import Tagless
-from reflex.vars.base import Var
+from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var
 
 
 
 
 class Bare(Component):
 class Bare(Component):
@@ -33,6 +33,8 @@ class Bare(Component):
 
 
     def _render(self) -> Tag:
     def _render(self) -> Tag:
         if isinstance(self.contents, Var):
         if isinstance(self.contents, Var):
+            if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)):
+                return Tagless(contents=f"{{{str(self.contents.to_string())}}}")
             return Tagless(contents=f"{{{str(self.contents)}}}")
             return Tagless(contents=f"{{{str(self.contents)}}}")
         return Tagless(contents=str(self.contents))
         return Tagless(contents=str(self.contents))
 
 

+ 1 - 2
reflex/middleware/hydrate_middleware.py

@@ -9,7 +9,6 @@ from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.middleware.middleware import Middleware
 from reflex.state import BaseState, StateUpdate
 from reflex.state import BaseState, StateUpdate
-from reflex.utils import format
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.app import App
     from reflex.app import App
@@ -43,7 +42,7 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
 
         # Get the initial state.
         # Get the initial state.
-        delta = format.format_state(state.dict())
+        delta = state.dict()
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness
         state._clean()
         state._clean()
 
 

+ 6 - 15
reflex/state.py

@@ -73,7 +73,7 @@ from reflex.utils.exceptions import (
     LockExpiredError,
     LockExpiredError,
 )
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.exec import is_testing_env
-from reflex.utils.serializers import SerializedType, serialize, serializer
+from reflex.utils.serializers import serializer
 from reflex.utils.types import override
 from reflex.utils.types import override
 from reflex.vars import VarData
 from reflex.vars import VarData
 
 
@@ -1790,9 +1790,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.dirty_substates.union(self._always_dirty_substates):
         for substate in self.dirty_substates.union(self._always_dirty_substates):
             delta.update(substates[substate].get_delta())
             delta.update(substates[substate].get_delta())
 
 
-        # Format the delta.
-        delta = format.format_state(delta)
-
         # Return the delta.
         # Return the delta.
         return delta
         return delta
 
 
@@ -2433,7 +2430,7 @@ class StateUpdate:
         Returns:
         Returns:
             The state update as a JSON string.
             The state update as a JSON string.
         """
         """
-        return format.json_dumps(dataclasses.asdict(self))
+        return format.json_dumps(self)
 
 
 
 
 class StateManager(Base, ABC):
 class StateManager(Base, ABC):
@@ -3660,22 +3657,16 @@ class MutableProxy(wrapt.ObjectProxy):
 
 
 
 
 @serializer
 @serializer
-def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
-    """Serialize the wrapped value of a MutableProxy.
+def serialize_mutable_proxy(mp: MutableProxy):
+    """Return the wrapped value of a MutableProxy.
 
 
     Args:
     Args:
         mp: The MutableProxy to serialize.
         mp: The MutableProxy to serialize.
 
 
     Returns:
     Returns:
-        The serialized wrapped object.
-
-    Raises:
-        ValueError: when the wrapped object is not serializable.
+        The wrapped object.
     """
     """
-    value = serialize(mp.__wrapped__)
-    if value is None:
-        raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
-    return value
+    return mp.__wrapped__
 
 
 
 
 class ImmutableMutableProxy(MutableProxy):
 class ImmutableMutableProxy(MutableProxy):

+ 1 - 43
reflex/utils/format.py

@@ -9,7 +9,7 @@ import re
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 
 from reflex import constants
 from reflex import constants
-from reflex.utils import exceptions, types
+from reflex.utils import exceptions
 from reflex.utils.console import deprecate
 from reflex.utils.console import deprecate
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
     return {k.replace("-", "_"): v for k, v in params.items()}
     return {k.replace("-", "_"): v for k, v in params.items()}
 
 
 
 
-def format_state(value: Any, key: Optional[str] = None) -> Any:
-    """Recursively format values in the given state.
-
-    Args:
-        value: The state to format.
-        key: The key associated with the value (optional).
-
-    Returns:
-        The formatted state.
-
-    Raises:
-        TypeError: If the given value is not a valid state.
-    """
-    from reflex.utils import serializers
-
-    # Handle dicts.
-    if isinstance(value, dict):
-        return {k: format_state(v, k) for k, v in value.items()}
-
-    # Handle lists, sets, typles.
-    if isinstance(value, types.StateIterBases):
-        return [format_state(v) for v in value]
-
-    # Return state vars as is.
-    if isinstance(value, types.StateBases):
-        return value
-
-    # Serialize the value.
-    serialized = serializers.serialize(value)
-    if serialized is not None:
-        return serialized
-
-    if key is None:
-        raise TypeError(
-            f"No JSON serializer found for var {value} of type {type(value)}."
-        )
-    else:
-        raise TypeError(
-            f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}."
-        )
-
-
 def format_state_name(state_name: str) -> str:
 def format_state_name(state_name: str) -> str:
     """Format a state name, replacing dots with double underscore.
     """Format a state name, replacing dots with double underscore.
 
 

+ 7 - 46
reflex/utils/serializers.py

@@ -12,7 +12,6 @@ from pathlib import Path
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
-    Dict,
     List,
     List,
     Literal,
     Literal,
     Optional,
     Optional,
@@ -126,7 +125,8 @@ def serialize(
     # If there is no serializer, return None.
     # If there is no serializer, return None.
     if serializer is None:
     if serializer is None:
         if dataclasses.is_dataclass(value) and not isinstance(value, type):
         if dataclasses.is_dataclass(value) and not isinstance(value, type):
-            return serialize(dataclasses.asdict(value))
+            return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}
+
         if get_type:
         if get_type:
             return None, None
             return None, None
         return None
         return None
@@ -214,32 +214,6 @@ def serialize_type(value: type) -> str:
     return value.__name__
     return value.__name__
 
 
 
 
-@serializer
-def serialize_str(value: str) -> str:
-    """Serialize a string.
-
-    Args:
-        value: The string to serialize.
-
-    Returns:
-        The serialized string.
-    """
-    return value
-
-
-@serializer
-def serialize_primitive(value: Union[bool, int, float, None]):
-    """Serialize a primitive type.
-
-    Args:
-        value: The number/bool/None to serialize.
-
-    Returns:
-        The serialized number/bool/None.
-    """
-    return value
-
-
 @serializer
 @serializer
 def serialize_base(value: Base) -> dict:
 def serialize_base(value: Base) -> dict:
     """Serialize a Base instance.
     """Serialize a Base instance.
@@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict:
     Returns:
     Returns:
         The serialized Base.
         The serialized Base.
     """
     """
-    return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}
+    return {k: v for k, v in value.dict().items() if not callable(v)}
 
 
 
 
 @serializer
 @serializer
-def serialize_list(value: Union[List, Tuple, Set]) -> list:
-    """Serialize a list to a JSON string.
+def serialize_set(value: Set) -> list:
+    """Serialize a set to a JSON serializable list.
 
 
     Args:
     Args:
-        value: The list to serialize.
+        value: The set to serialize.
 
 
     Returns:
     Returns:
         The serialized list.
         The serialized list.
     """
     """
-    return [serialize(item) for item in value]
-
-
-@serializer
-def serialize_dict(prop: Dict[str, Any]) -> dict:
-    """Serialize a dictionary to a JSON string.
-
-    Args:
-        prop: The dictionary to serialize.
-
-    Returns:
-        The serialized dictionary.
-    """
-    return {k: serialize(v) for k, v in prop.items()}
+    return list(value)
 
 
 
 
 @serializer(to=str)
 @serializer(to=str)

+ 1 - 1
reflex/vars/base.py

@@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar):
     Returns:
     Returns:
         The serialized Literal.
         The serialized Literal.
     """
     """
-    return serializers.serialize(value._var_value)
+    return value._var_value
 
 
 
 
 P = ParamSpec("P")
 P = ParamSpec("P")

+ 2 - 2
tests/integration/test_var_operations.py

@@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("foreach_list_ix", "1\n2"),
         ("foreach_list_ix", "1\n2"),
         ("foreach_list_nested", "1\n1\n2"),
         ("foreach_list_nested", "1\n1\n2"),
         # rx.memo component with state
         # rx.memo component with state
-        ("memo_comp", "1210"),
-        ("memo_comp_nested", "345"),
+        ("memo_comp", "[1,2]10"),
+        ("memo_comp_nested", "[3,4]5"),
         # foreach in a match
         # foreach in a match
         ("foreach_in_match", "first\nsecond\nthird"),
         ("foreach_in_match", "first\nsecond\nthird"),
     ]
     ]

+ 1 - 2
tests/units/test_app.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-import dataclasses
 import functools
 import functools
 import io
 import io
 import json
 import json
@@ -1053,7 +1052,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     f"comp_{arg_name}": exp_val,
                     f"comp_{arg_name}": exp_val,
                     constants.CompileVars.IS_HYDRATED: False,
                     constants.CompileVars.IS_HYDRATED: False,
                     # "side_effect_counter": exp_index,
                     # "side_effect_counter": exp_index,
-                    "router": dataclasses.asdict(exp_router),
+                    "router": exp_router,
                 }
                 }
             },
             },
             events=[
             events=[

+ 2 - 1
tests/units/utils/test_format.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import datetime
 import datetime
+import json
 from typing import Any, List
 from typing import Any, List
 
 
 import plotly.graph_objects as go
 import plotly.graph_objects as go
@@ -621,7 +622,7 @@ def test_format_state(input, output):
         input: The state to format.
         input: The state to format.
         output: The expected formatted state.
         output: The expected formatted state.
     """
     """
-    assert format.format_state(input) == output
+    assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output))
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(

+ 5 - 3
tests/units/utils/test_serializers.py

@@ -1,19 +1,21 @@
 import datetime
 import datetime
+import json
 from enum import Enum
 from enum import Enum
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Dict, Type
+from typing import Any, Type
 
 
 import pytest
 import pytest
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.core.colors import Color
 from reflex.components.core.colors import Color
 from reflex.utils import serializers
 from reflex.utils import serializers
+from reflex.utils.format import json_dumps
 from reflex.vars.base import LiteralVar
 from reflex.vars.base import LiteralVar
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "type_,expected",
     "type_,expected",
-    [(str, True), (dict, True), (Dict[int, int], True), (Enum, True)],
+    [(Enum, True)],
 )
 )
 def test_has_serializer(type_: Type, expected: bool):
 def test_has_serializer(type_: Type, expected: bool):
     """Test that has_serializer returns the correct value.
     """Test that has_serializer returns the correct value.
@@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str):
         value: The value to serialize.
         value: The value to serialize.
         expected: The expected result.
         expected: The expected result.
     """
     """
-    assert serializers.serialize(value) == expected
+    assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected))
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(