Sfoglia il codice sorgente

[REF-3009] type transforming serializers (#3227)

* wip type transforming serializers

* old python sucks

* typing fixups

* Expose the `to` parameter on `rx.serializer` for type conversion

Serializers can also return a tuple of `(serialized_value, type)`. If both ways
are specified, the returned value MUST match the `to` parameter.

When initializing a new rx.Var, if `_var_is_string` is not specified and the serializer returns a `str` type, then mark `_var_is_string=True` to indicate that the Var should be treated like a string literal.

Include datetime, color, types, and paths as "serializing to str" type.

Avoid other changes at this point to reduce fallout from this change:

  Notably, the `serialize_str` function does NOT cast to `str`, which
  would cause existing code to treat all Var initialized with a str as a
  str literal even though this was NOT the default before.

Update test cases to accommodate these changes.

* Raise deprecation warning for rx.Var.create with string literal

In the future, we will treat strings as string literals in the JS code. To get
a Var that is not treated like a string, pass _var_is_string=False.

This will allow our serializers to automatically identify cast string literals
with fewer special cases (and the special cases need to be explicitly
identified).

* Add test case for mismatched serialized types

* fix old python

* Remove serializer returning a tuple feature

Simplify the logic; instead of making a wrapper function that returns
a tuple, just save the type conversions in a separate global.

* Reset the LRU cache when adding new serializers

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 11 mesi fa
parent
commit
e42d4ed9ef

+ 98 - 14
reflex/utils/serializers.py

@@ -2,13 +2,27 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import functools
 import json
 import json
 import types as builtin_types
 import types as builtin_types
 import warnings
 import warnings
 from datetime import date, datetime, time, timedelta
 from datetime import date, datetime, time, timedelta
 from enum import Enum
 from enum import Enum
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    get_type_hints,
+    overload,
+)
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.constants.colors import Color, format_color
 from reflex.constants.colors import Color, format_color
@@ -17,15 +31,24 @@ from reflex.utils import exceptions, types
 # Mapping from type to a serializer.
 # Mapping from type to a serializer.
 # The serializer should convert the type to a JSON object.
 # The serializer should convert the type to a JSON object.
 SerializedType = Union[str, bool, int, float, list, dict]
 SerializedType = Union[str, bool, int, float, list, dict]
+
+
 Serializer = Callable[[Type], SerializedType]
 Serializer = Callable[[Type], SerializedType]
+
+
 SERIALIZERS: dict[Type, Serializer] = {}
 SERIALIZERS: dict[Type, Serializer] = {}
+SERIALIZER_TYPES: dict[Type, Type] = {}
 
 
 
 
-def serializer(fn: Serializer) -> Serializer:
+def serializer(
+    fn: Serializer | None = None,
+    to: Type | None = None,
+) -> Serializer:
     """Decorator to add a serializer for a given type.
     """Decorator to add a serializer for a given type.
 
 
     Args:
     Args:
         fn: The function to decorate.
         fn: The function to decorate.
+        to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string.
 
 
     Returns:
     Returns:
         The decorated function.
         The decorated function.
@@ -33,8 +56,9 @@ def serializer(fn: Serializer) -> Serializer:
     Raises:
     Raises:
         ValueError: If the function does not take a single argument.
         ValueError: If the function does not take a single argument.
     """
     """
-    # Get the global serializers.
-    global SERIALIZERS
+    if fn is None:
+        # If the function is not provided, return a partial that acts as a decorator.
+        return functools.partial(serializer, to=to)  # type: ignore
 
 
     # Check the type hints to get the type of the argument.
     # Check the type hints to get the type of the argument.
     type_hints = get_type_hints(fn)
     type_hints = get_type_hints(fn)
@@ -54,18 +78,47 @@ def serializer(fn: Serializer) -> Serializer:
             f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
             f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
         )
         )
 
 
+    # Apply type transformation if requested
+    if to is not None:
+        SERIALIZER_TYPES[type_] = to
+        get_serializer_type.cache_clear()
+
     # Register the serializer.
     # Register the serializer.
     SERIALIZERS[type_] = fn
     SERIALIZERS[type_] = fn
+    get_serializer.cache_clear()
 
 
     # Return the function.
     # Return the function.
     return fn
     return fn
 
 
 
 
-def serialize(value: Any) -> SerializedType | None:
+@overload
+def serialize(
+    value: Any, get_type: Literal[True]
+) -> Tuple[Optional[SerializedType], Optional[types.GenericType]]:
+    ...
+
+
+@overload
+def serialize(value: Any, get_type: Literal[False]) -> Optional[SerializedType]:
+    ...
+
+
+@overload
+def serialize(value: Any) -> Optional[SerializedType]:
+    ...
+
+
+def serialize(
+    value: Any, get_type: bool = False
+) -> Union[
+    Optional[SerializedType],
+    Tuple[Optional[SerializedType], Optional[types.GenericType]],
+]:
     """Serialize the value to a JSON string.
     """Serialize the value to a JSON string.
 
 
     Args:
     Args:
         value: The value to serialize.
         value: The value to serialize.
+        get_type: Whether to return the type of the serialized value.
 
 
     Returns:
     Returns:
         The serialized value, or None if a serializer is not found.
         The serialized value, or None if a serializer is not found.
@@ -75,13 +128,22 @@ def serialize(value: Any) -> SerializedType | None:
 
 
     # If there is no serializer, return None.
     # If there is no serializer, return None.
     if serializer is None:
     if serializer is None:
+        if get_type:
+            return None, None
         return None
         return None
 
 
     # Serialize the value.
     # Serialize the value.
-    return serializer(value)
+    serialized = serializer(value)
 
 
+    # Return the serialized value and the type.
+    if get_type:
+        return serialized, get_serializer_type(type(value))
+    else:
+        return serialized
 
 
-def get_serializer(type_: Type) -> Serializer | None:
+
+@functools.lru_cache
+def get_serializer(type_: Type) -> Optional[Serializer]:
     """Get the serializer for the type.
     """Get the serializer for the type.
 
 
     Args:
     Args:
@@ -90,8 +152,6 @@ def get_serializer(type_: Type) -> Serializer | None:
     Returns:
     Returns:
         The serializer for the type, or None if there is no serializer.
         The serializer for the type, or None if there is no serializer.
     """
     """
-    global SERIALIZERS
-
     # First, check if the type is registered.
     # First, check if the type is registered.
     serializer = SERIALIZERS.get(type_)
     serializer = SERIALIZERS.get(type_)
     if serializer is not None:
     if serializer is not None:
@@ -106,6 +166,30 @@ def get_serializer(type_: Type) -> Serializer | None:
     return None
     return None
 
 
 
 
+@functools.lru_cache
+def get_serializer_type(type_: Type) -> Optional[Type]:
+    """Get the converted type for the type after serializing.
+
+    Args:
+        type_: The type to get the serializer type for.
+
+    Returns:
+        The serialized type for the type, or None if there is no type conversion registered.
+    """
+    # First, check if the type is registered.
+    serializer = SERIALIZER_TYPES.get(type_)
+    if serializer is not None:
+        return serializer
+
+    # If the type is not registered, check if it is a subclass of a registered type.
+    for registered_type, serializer in reversed(SERIALIZER_TYPES.items()):
+        if types._issubclass(type_, registered_type):
+            return serializer
+
+    # If there is no serializer, return None.
+    return None
+
+
 def has_serializer(type_: Type) -> bool:
 def has_serializer(type_: Type) -> bool:
     """Check if there is a serializer for the type.
     """Check if there is a serializer for the type.
 
 
@@ -118,7 +202,7 @@ def has_serializer(type_: Type) -> bool:
     return get_serializer(type_) is not None
     return get_serializer(type_) is not None
 
 
 
 
-@serializer
+@serializer(to=str)
 def serialize_type(value: type) -> str:
 def serialize_type(value: type) -> str:
     """Serialize a python type.
     """Serialize a python type.
 
 
@@ -226,7 +310,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
     return format.unwrap_vars(fprop)
     return format.unwrap_vars(fprop)
 
 
 
 
-@serializer
+@serializer(to=str)
 def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
 def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
     """Serialize a datetime to a JSON string.
     """Serialize a datetime to a JSON string.
 
 
@@ -239,8 +323,8 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
     return str(dt)
     return str(dt)
 
 
 
 
-@serializer
-def serialize_path(path: Path):
+@serializer(to=str)
+def serialize_path(path: Path) -> str:
     """Serialize a pathlib.Path to a JSON string.
     """Serialize a pathlib.Path to a JSON string.
 
 
     Args:
     Args:
@@ -265,7 +349,7 @@ def serialize_enum(en: Enum) -> str:
     return en.value
     return en.value
 
 
 
 
-@serializer
+@serializer(to=str)
 def serialize_color(color: Color) -> str:
 def serialize_color(color: Color) -> str:
     """Serialize a color.
     """Serialize a color.
 
 

+ 31 - 10
reflex/vars.py

@@ -347,7 +347,7 @@ class Var:
         cls,
         cls,
         value: Any,
         value: Any,
         _var_is_local: bool = True,
         _var_is_local: bool = True,
-        _var_is_string: bool = False,
+        _var_is_string: bool | None = None,
         _var_data: Optional[VarData] = None,
         _var_data: Optional[VarData] = None,
     ) -> Var | None:
     ) -> Var | None:
         """Create a var from a value.
         """Create a var from a value.
@@ -380,18 +380,39 @@ class Var:
 
 
         # Try to serialize the value.
         # Try to serialize the value.
         type_ = type(value)
         type_ = type(value)
-        name = value if type_ in types.JSONType else serializers.serialize(value)
+        if type_ in types.JSONType:
+            name = value
+        else:
+            name, serialized_type = serializers.serialize(value, get_type=True)
+            if (
+                serialized_type is not None
+                and _var_is_string is None
+                and issubclass(serialized_type, str)
+            ):
+                _var_is_string = True
         if name is None:
         if name is None:
             raise VarTypeError(
             raise VarTypeError(
                 f"No JSON serializer found for var {value} of type {type_}."
                 f"No JSON serializer found for var {value} of type {type_}."
             )
             )
         name = name if isinstance(name, str) else format.json_dumps(name)
         name = name if isinstance(name, str) else format.json_dumps(name)
 
 
+        if _var_is_string is None and type_ is str:
+            console.deprecate(
+                feature_name="Creating a Var from a string without specifying _var_is_string",
+                reason=(
+                    "Specify _var_is_string=False to create a Var that is not a string literal. "
+                    "In the future, creating a Var from a string will be treated as a string literal "
+                    "by default."
+                ),
+                deprecation_version="0.5.4",
+                removal_version="0.6.0",
+            )
+
         return BaseVar(
         return BaseVar(
             _var_name=name,
             _var_name=name,
             _var_type=type_,
             _var_type=type_,
             _var_is_local=_var_is_local,
             _var_is_local=_var_is_local,
-            _var_is_string=_var_is_string,
+            _var_is_string=_var_is_string if _var_is_string is not None else False,
             _var_data=_var_data,
             _var_data=_var_data,
         )
         )
 
 
@@ -400,7 +421,7 @@ class Var:
         cls,
         cls,
         value: Any,
         value: Any,
         _var_is_local: bool = True,
         _var_is_local: bool = True,
-        _var_is_string: bool = False,
+        _var_is_string: bool | None = None,
         _var_data: Optional[VarData] = None,
         _var_data: Optional[VarData] = None,
     ) -> Var:
     ) -> Var:
         """Create a var from a value, asserting that it is not None.
         """Create a var from a value, asserting that it is not None.
@@ -847,19 +868,19 @@ class Var:
                 if invoke_fn:
                 if invoke_fn:
                     # invoke the function on left operand.
                     # invoke the function on left operand.
                     operation_name = (
                     operation_name = (
-                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"
-                    )  # type: ignore
+                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"  # type: ignore
+                    )
                 else:
                 else:
                     # pass the operands as arguments to the function.
                     # pass the operands as arguments to the function.
                     operation_name = (
                     operation_name = (
-                        f"{left_operand_full_name} {op} {right_operand_full_name}"
-                    )  # type: ignore
+                        f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                    )
                     operation_name = f"{fn}({operation_name})"
                     operation_name = f"{fn}({operation_name})"
             else:
             else:
                 # apply operator to operands (left operand <operator> right_operand)
                 # apply operator to operands (left operand <operator> right_operand)
                 operation_name = (
                 operation_name = (
-                    f"{left_operand_full_name} {op} {right_operand_full_name}"
-                )  # type: ignore
+                    f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                )
                 operation_name = format.wrap(operation_name, "(")
                 operation_name = format.wrap(operation_name, "(")
         else:
         else:
             # apply operator to left operand (<operator> left_operand)
             # apply operator to left operand (<operator> left_operand)

+ 2 - 2
reflex/vars.pyi

@@ -51,11 +51,11 @@ class Var:
     _var_data: VarData | None = None
     _var_data: VarData | None = None
     @classmethod
     @classmethod
     def create(
     def create(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
+        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
     ) -> Optional[Var]: ...
     ) -> Optional[Var]: ...
     @classmethod
     @classmethod
     def create_safe(
     def create_safe(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
+        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
     ) -> Var: ...
     ) -> Var: ...
     @classmethod
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

+ 10 - 4
tests/components/core/test_colors.py

@@ -2,6 +2,7 @@ import pytest
 
 
 import reflex as rx
 import reflex as rx
 from reflex.components.datadisplay.code import CodeBlock
 from reflex.components.datadisplay.code import CodeBlock
+from reflex.constants.colors import Color
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 
 
@@ -50,7 +51,12 @@ def create_color_var(color):
     ],
     ],
 )
 )
 def test_color(color, expected):
 def test_color(color, expected):
-    assert str(color) == expected
+    assert color._var_is_string or color._var_type is str
+    assert color._var_full_name == expected
+    if color._var_type == Color:
+        assert str(color) == f"{{`{expected}`}}"
+    else:
+        assert str(color) == expected
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -96,9 +102,9 @@ def test_color_with_conditionals(cond_var, expected):
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "color, expected",
     "color, expected",
     [
     [
-        (create_color_var(rx.color("red")), "var(--red-7)"),
-        (create_color_var(rx.color("green", shade=1)), "var(--green-1)"),
-        (create_color_var(rx.color("blue", alpha=True)), "var(--blue-a7)"),
+        (create_color_var(rx.color("red")), "{`var(--red-7)`}"),
+        (create_color_var(rx.color("green", shade=1)), "{`var(--green-1)`}"),
+        (create_color_var(rx.color("blue", alpha=True)), "{`var(--blue-a7)`}"),
         ("red", "red"),
         ("red", "red"),
         ("green", "green"),
         ("green", "green"),
         ("blue", "blue"),
         ("blue", "blue"),

+ 40 - 0
tests/utils/test_serializers.py

@@ -1,5 +1,6 @@
 import datetime
 import datetime
 from enum import Enum
 from enum import Enum
+from pathlib import Path
 from typing import Any, Dict, List, Type
 from typing import Any, Dict, List, Type
 
 
 import pytest
 import pytest
@@ -90,6 +91,9 @@ def test_add_serializer():
 
 
     # Remove the serializer.
     # Remove the serializer.
     serializers.SERIALIZERS.pop(Foo)
     serializers.SERIALIZERS.pop(Foo)
+    # LRU cache will still have the serializer, so we need to clear it.
+    assert serializers.has_serializer(Foo)
+    serializers.get_serializer.cache_clear()
     assert not serializers.has_serializer(Foo)
     assert not serializers.has_serializer(Foo)
 
 
 
 
@@ -194,3 +198,39 @@ def test_serialize(value: Any, expected: str):
         expected: The expected result.
         expected: The expected result.
     """
     """
     assert serializers.serialize(value) == expected
     assert serializers.serialize(value) == expected
+
+
+@pytest.mark.parametrize(
+    "value,expected,exp_var_is_string",
+    [
+        ("test", "test", False),
+        (1, "1", False),
+        (1.0, "1.0", False),
+        (True, "true", False),
+        (False, "false", False),
+        ([1, 2, 3], "[1, 2, 3]", False),
+        ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]', False),
+        (StrEnum.FOO, "foo", False),
+        ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]', False),
+        (
+            BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
+            '{"ts": "1 day, 0:00:01.000001"}',
+            False,
+        ),
+        (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001", True),
+        (Color(color="slate", shade=1), "var(--slate-1)", True),
+        (BaseSubclass, "BaseSubclass", True),
+        (Path("."), ".", True),
+    ],
+)
+def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):
+    """Test that serialize with `to=str` passed to a Var is marked with _var_is_string.
+
+    Args:
+        value: The value to serialize.
+        expected: The expected result.
+        exp_var_is_string: The expected value of _var_is_string.
+    """
+    v = Var.create_safe(value)
+    assert v._var_full_name == expected
+    assert v._var_is_string == exp_var_is_string