Parcourir la source

remove format_state and override behavior for bare (#3979)

* remove format_state and override behavior for bare

* pass the test cases

* only do one level of dicting dataclasses

* remove dict and replace list with set

* delete unnecessary serialize calls

* remove serialize for mutable proxy

* dang it darglint
Khaleel Al-Adhami il y a 8 mois
Parent
commit
0ab161c119

+ 1 - 1
reflex/compiler/utils.py

@@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict:
         initial_state = state(_reflex_internal_init=True).dict(
             initial=True, include_computed=False
         )
-    return format.format_state(initial_state)
+    return initial_state
 
 
 def _compile_client_storage_field(

+ 3 - 1
reflex/components/base/bare.py

@@ -7,7 +7,7 @@ from typing import Any, Iterator
 from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
-from reflex.vars.base import Var
+from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var
 
 
 class Bare(Component):
@@ -33,6 +33,8 @@ class Bare(Component):
 
     def _render(self) -> Tag:
         if isinstance(self.contents, Var):
+            if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)):
+                return Tagless(contents=f"{{{str(self.contents.to_string())}}}")
             return Tagless(contents=f"{{{str(self.contents)}}}")
         return Tagless(contents=str(self.contents))
 

+ 1 - 2
reflex/middleware/hydrate_middleware.py

@@ -9,7 +9,6 @@ from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.state import BaseState, StateUpdate
-from reflex.utils import format
 
 if TYPE_CHECKING:
     from reflex.app import App
@@ -43,7 +42,7 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
         # Get the initial state.
-        delta = format.format_state(state.dict())
+        delta = state.dict()
         # since a full dict was captured, clean any dirtiness
         state._clean()
 

+ 6 - 15
reflex/state.py

@@ -73,7 +73,7 @@ from reflex.utils.exceptions import (
     LockExpiredError,
 )
 from reflex.utils.exec import is_testing_env
-from reflex.utils.serializers import SerializedType, serialize, serializer
+from reflex.utils.serializers import serializer
 from reflex.utils.types import override
 from reflex.vars import VarData
 
@@ -1790,9 +1790,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.dirty_substates.union(self._always_dirty_substates):
             delta.update(substates[substate].get_delta())
 
-        # Format the delta.
-        delta = format.format_state(delta)
-
         # Return the delta.
         return delta
 
@@ -2433,7 +2430,7 @@ class StateUpdate:
         Returns:
             The state update as a JSON string.
         """
-        return format.json_dumps(dataclasses.asdict(self))
+        return format.json_dumps(self)
 
 
 class StateManager(Base, ABC):
@@ -3660,22 +3657,16 @@ class MutableProxy(wrapt.ObjectProxy):
 
 
 @serializer
-def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
-    """Serialize the wrapped value of a MutableProxy.
+def serialize_mutable_proxy(mp: MutableProxy):
+    """Return the wrapped value of a MutableProxy.
 
     Args:
         mp: The MutableProxy to serialize.
 
     Returns:
-        The serialized wrapped object.
-
-    Raises:
-        ValueError: when the wrapped object is not serializable.
+        The wrapped object.
     """
-    value = serialize(mp.__wrapped__)
-    if value is None:
-        raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
-    return value
+    return mp.__wrapped__
 
 
 class ImmutableMutableProxy(MutableProxy):

+ 1 - 43
reflex/utils/format.py

@@ -9,7 +9,7 @@ import re
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 from reflex import constants
-from reflex.utils import exceptions, types
+from reflex.utils import exceptions
 from reflex.utils.console import deprecate
 
 if TYPE_CHECKING:
@@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
     return {k.replace("-", "_"): v for k, v in params.items()}
 
 
-def format_state(value: Any, key: Optional[str] = None) -> Any:
-    """Recursively format values in the given state.
-
-    Args:
-        value: The state to format.
-        key: The key associated with the value (optional).
-
-    Returns:
-        The formatted state.
-
-    Raises:
-        TypeError: If the given value is not a valid state.
-    """
-    from reflex.utils import serializers
-
-    # Handle dicts.
-    if isinstance(value, dict):
-        return {k: format_state(v, k) for k, v in value.items()}
-
-    # Handle lists, sets, typles.
-    if isinstance(value, types.StateIterBases):
-        return [format_state(v) for v in value]
-
-    # Return state vars as is.
-    if isinstance(value, types.StateBases):
-        return value
-
-    # Serialize the value.
-    serialized = serializers.serialize(value)
-    if serialized is not None:
-        return serialized
-
-    if key is None:
-        raise TypeError(
-            f"No JSON serializer found for var {value} of type {type(value)}."
-        )
-    else:
-        raise TypeError(
-            f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}."
-        )
-
-
 def format_state_name(state_name: str) -> str:
     """Format a state name, replacing dots with double underscore.
 

+ 7 - 46
reflex/utils/serializers.py

@@ -12,7 +12,6 @@ from pathlib import Path
 from typing import (
     Any,
     Callable,
-    Dict,
     List,
     Literal,
     Optional,
@@ -126,7 +125,8 @@ def serialize(
     # If there is no serializer, return None.
     if serializer is None:
         if dataclasses.is_dataclass(value) and not isinstance(value, type):
-            return serialize(dataclasses.asdict(value))
+            return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}
+
         if get_type:
             return None, None
         return None
@@ -214,32 +214,6 @@ def serialize_type(value: type) -> str:
     return value.__name__
 
 
-@serializer
-def serialize_str(value: str) -> str:
-    """Serialize a string.
-
-    Args:
-        value: The string to serialize.
-
-    Returns:
-        The serialized string.
-    """
-    return value
-
-
-@serializer
-def serialize_primitive(value: Union[bool, int, float, None]):
-    """Serialize a primitive type.
-
-    Args:
-        value: The number/bool/None to serialize.
-
-    Returns:
-        The serialized number/bool/None.
-    """
-    return value
-
-
 @serializer
 def serialize_base(value: Base) -> dict:
     """Serialize a Base instance.
@@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict:
     Returns:
         The serialized Base.
     """
-    return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}
+    return {k: v for k, v in value.dict().items() if not callable(v)}
 
 
 @serializer
-def serialize_list(value: Union[List, Tuple, Set]) -> list:
-    """Serialize a list to a JSON string.
+def serialize_set(value: Set) -> list:
+    """Serialize a set to a JSON serializable list.
 
     Args:
-        value: The list to serialize.
+        value: The set to serialize.
 
     Returns:
         The serialized list.
     """
-    return [serialize(item) for item in value]
-
-
-@serializer
-def serialize_dict(prop: Dict[str, Any]) -> dict:
-    """Serialize a dictionary to a JSON string.
-
-    Args:
-        prop: The dictionary to serialize.
-
-    Returns:
-        The serialized dictionary.
-    """
-    return {k: serialize(v) for k, v in prop.items()}
+    return list(value)
 
 
 @serializer(to=str)

+ 1 - 1
reflex/vars/base.py

@@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar):
     Returns:
         The serialized Literal.
     """
-    return serializers.serialize(value._var_value)
+    return value._var_value
 
 
 P = ParamSpec("P")

+ 2 - 2
tests/integration/test_var_operations.py

@@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("foreach_list_ix", "1\n2"),
         ("foreach_list_nested", "1\n1\n2"),
         # rx.memo component with state
-        ("memo_comp", "1210"),
-        ("memo_comp_nested", "345"),
+        ("memo_comp", "[1,2]10"),
+        ("memo_comp_nested", "[3,4]5"),
         # foreach in a match
         ("foreach_in_match", "first\nsecond\nthird"),
     ]

+ 1 - 2
tests/units/test_app.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import dataclasses
 import functools
 import io
 import json
@@ -1053,7 +1052,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     f"comp_{arg_name}": exp_val,
                     constants.CompileVars.IS_HYDRATED: False,
                     # "side_effect_counter": exp_index,
-                    "router": dataclasses.asdict(exp_router),
+                    "router": exp_router,
                 }
             },
             events=[

+ 2 - 1
tests/units/utils/test_format.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import json
 from typing import Any, List
 
 import plotly.graph_objects as go
@@ -621,7 +622,7 @@ def test_format_state(input, output):
         input: The state to format.
         output: The expected formatted state.
     """
-    assert format.format_state(input) == output
+    assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output))
 
 
 @pytest.mark.parametrize(

+ 5 - 3
tests/units/utils/test_serializers.py

@@ -1,19 +1,21 @@
 import datetime
+import json
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Type
+from typing import Any, Type
 
 import pytest
 
 from reflex.base import Base
 from reflex.components.core.colors import Color
 from reflex.utils import serializers
+from reflex.utils.format import json_dumps
 from reflex.vars.base import LiteralVar
 
 
 @pytest.mark.parametrize(
     "type_,expected",
-    [(str, True), (dict, True), (Dict[int, int], True), (Enum, True)],
+    [(Enum, True)],
 )
 def test_has_serializer(type_: Type, expected: bool):
     """Test that has_serializer returns the correct value.
@@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str):
         value: The value to serialize.
         expected: The expected result.
     """
-    assert serializers.serialize(value) == expected
+    assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected))
 
 
 @pytest.mark.parametrize(