فهرست منبع

fix serialization as a whole for list/dict/Base containing custom items to serialize (#1984)

Thomas Brandého 1 سال پیش
والد
کامیت
df09c716c6
5فایلهای تغییر یافته به همراه78 افزوده شده و 6 حذف شده
  1. 3 1
      reflex/base.py
  2. 1 1
      reflex/utils/format.py
  3. 17 4
      reflex/utils/serializers.py
  4. 9 0
      tests/utils/test_format.py
  5. 48 0
      tests/utils/test_serializers.py

+ 3 - 1
reflex/base.py

@@ -30,7 +30,9 @@ class Base(pydantic.BaseModel):
         Returns:
             The object as a json string.
         """
-        return self.__config__.json_dumps(self.dict(), default=list)
+        from reflex.utils.serializers import serialize
+
+        return self.__config__.json_dumps(self.dict(), default=serialize)
 
     def set(self, **kwargs):
         """Set multiple fields and return the object.

+ 1 - 1
reflex/utils/format.py

@@ -597,7 +597,7 @@ def json_dumps(obj: Any) -> str:
     Returns:
         A string
     """
-    return json.dumps(obj, ensure_ascii=False, default=list)
+    return json.dumps(obj, ensure_ascii=False, default=serialize)
 
 
 def unwrap_vars(value: str) -> str:

+ 17 - 4
reflex/utils/serializers.py

@@ -93,7 +93,7 @@ def get_serializer(type_: Type) -> Serializer | None:
         return serializer
 
     # If the type is not registered, check if it is a subclass of a registered type.
-    for registered_type, serializer in SERIALIZERS.items():
+    for registered_type, serializer in reversed(SERIALIZERS.items()):
         if types._issubclass(type_, registered_type):
             return serializer
 
@@ -127,18 +127,31 @@ def serialize_str(value: str) -> str:
 
 
 @serializer
-def serialize_primitive(value: Union[bool, int, float, Base, None]) -> str:
+def serialize_primitive(value: Union[bool, int, float, None]) -> str:
     """Serialize a primitive type.
 
     Args:
-        value: The number to serialize.
+        value: The number/bool/None to serialize.
 
     Returns:
-        The serialized number.
+        The serialized number/bool/None.
     """
     return format.json_dumps(value)
 
 
+@serializer
+def serialize_base(value: Base) -> str:
+    """Serialize a Base instance.
+
+    Args:
+        value : The Base to serialize.
+
+    Returns:
+        The serialized Base.
+    """
+    return value.json()
+
+
 @serializer
 def serialize_list(value: Union[List, Tuple, Set]) -> str:
     """Serialize a list to a JSON string.

+ 9 - 0
tests/utils/test_format.py

@@ -1,3 +1,4 @@
+import datetime
 from typing import Any
 
 import pytest
@@ -604,6 +605,14 @@ def test_format_library_name(input: str, output: str):
         ([1, 2, 3], "[1, 2, 3]"),
         ({}, "{}"),
         ({"k1": False, "k2": True}, '{"k1": false, "k2": true}'),
+        (
+            [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
+            '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
+        ),
+        (
+            {"key1": datetime.timedelta(1, 1, 1), "key2": datetime.timedelta(1, 1, 2)},
+            '{"key1": "1 day, 0:00:01.000001", "key2": "1 day, 0:00:01.000002"}',
+        ),
     ],
 )
 def test_json_dumps(input, output):

+ 48 - 0
tests/utils/test_serializers.py

@@ -1,8 +1,10 @@
 import datetime
+from enum import Enum
 from typing import Any, Dict, List, Type
 
 import pytest
 
+from reflex.base import Base
 from reflex.utils import serializers
 from reflex.vars import Var
 
@@ -93,6 +95,31 @@ def test_add_serializer():
     assert not serializers.has_serializer(Foo)
 
 
+class StrEnum(str, Enum):
+    """An enum also inheriting from str."""
+
+    FOO = "foo"
+    BAR = "bar"
+
+
+class EnumWithPrefix(Enum):
+    """An enum with a serializer adding a prefix."""
+
+    FOO = "foo"
+    BAR = "bar"
+
+
+@serializers.serializer
+def serialize_EnumWithPrefix(enum: EnumWithPrefix) -> str:
+    return "prefix_" + enum.value
+
+
+class BaseSubclass(Base):
+    """A class inheriting from Base for testing."""
+
+    ts: datetime.timedelta = datetime.timedelta(1, 1, 1)
+
+
 @pytest.mark.parametrize(
     "value,expected",
     [
@@ -104,6 +131,23 @@ def test_add_serializer():
         (None, "null"),
         ([1, 2, 3], "[1, 2, 3]"),
         ([1, "2", 3.0], '[1, "2", 3.0]'),
+        ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]'),
+        (StrEnum.FOO, "foo"),
+        ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]'),
+        (
+            {"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]},
+            '{"key1": [1, 2, 3], "key2": ["foo", "bar"]}',
+        ),
+        (EnumWithPrefix.FOO, "prefix_foo"),
+        ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], '["prefix_foo", "prefix_bar"]'),
+        (
+            {"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR},
+            '{"key1": "prefix_foo", "key2": "prefix_bar"}',
+        ),
+        (
+            BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
+            '{"ts": "1 day, 0:00:01.000001"}',
+        ),
         (
             [1, Var.create_safe("hi"), Var.create_safe("bye", _var_is_local=False)],
             '[1, "hi", bye]',
@@ -121,6 +165,10 @@ def test_add_serializer():
         (datetime.date(2021, 1, 1), "2021-01-01"),
         (datetime.time(1, 1, 1, 1), "01:01:01.000001"),
         (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
+        (
+            [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
+            '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
+        ),
     ],
 )
 def test_serialize(value: Any, expected: str):