ソースを参照

[REF-3009] type transforming serializers (#3227)

* wip type transforming serializers

* old python sucks

* typing fixups

* Expose the `to` parameter on `rx.serializer` for type conversion

Serializers can also return a tuple of `(serialized_value, type)`. If both ways
are specified, then the returned value MUST match the `to` parameter.

When initializing a new rx.Var, if `_var_is_string` is not specified and the serializer returns a `str` type, then mark `_var_is_string=True` to indicate that the Var should be treated like a string literal.

Register the datetime, color, type, and path serializers as "serializing to str" types.

Avoid other changes at this point to reduce fallout from this change:

  Notably, the `serialize_str` function does NOT cast to `str`, because doing so
  would cause existing code to treat every Var initialized with a str as a
  str literal even though this was NOT the default before.

Update test cases to accommodate these changes.

* Raise deprecation warning for rx.Var.create with string literal

In the future, we will treat strings as string literals in the JS code. To get
a Var that is not treated like a string, pass _var_is_string=False.

This will allow our serializers to automatically identify cast string literals
with fewer special cases (and the remaining special cases will need to be
explicitly identified).

* Add test case for mismatched serialized types

* fix old python

* Remove serializer returning a tuple feature

Simplify the logic; instead of making a wrapper function that returns
a tuple, just save the type conversions in a separate global.

* Reset the LRU cache when adding new serializers

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 11 ヶ月 前
コミット
e42d4ed9ef

+ 98 - 14
reflex/utils/serializers.py

@@ -2,13 +2,27 @@
 
 from __future__ import annotations
 
+import functools
 import json
 import types as builtin_types
 import warnings
 from datetime import date, datetime, time, timedelta
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    get_type_hints,
+    overload,
+)
 
 from reflex.base import Base
 from reflex.constants.colors import Color, format_color
@@ -17,15 +31,24 @@ from reflex.utils import exceptions, types
 # Mapping from type to a serializer.
 # The serializer should convert the type to a JSON object.
 SerializedType = Union[str, bool, int, float, list, dict]
+
+
 Serializer = Callable[[Type], SerializedType]
+
+
 SERIALIZERS: dict[Type, Serializer] = {}
+SERIALIZER_TYPES: dict[Type, Type] = {}
 
 
-def serializer(fn: Serializer) -> Serializer:
+def serializer(
+    fn: Serializer | None = None,
+    to: Type | None = None,
+) -> Serializer:
     """Decorator to add a serializer for a given type.
 
     Args:
         fn: The function to decorate.
+        to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string.
 
     Returns:
         The decorated function.
@@ -33,8 +56,9 @@ def serializer(fn: Serializer) -> Serializer:
     Raises:
         ValueError: If the function does not take a single argument.
     """
-    # Get the global serializers.
-    global SERIALIZERS
+    if fn is None:
+        # If the function is not provided, return a partial that acts as a decorator.
+        return functools.partial(serializer, to=to)  # type: ignore
 
     # Check the type hints to get the type of the argument.
     type_hints = get_type_hints(fn)
@@ -54,18 +78,47 @@ def serializer(fn: Serializer) -> Serializer:
             f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
         )
 
+    # Apply type transformation if requested
+    if to is not None:
+        SERIALIZER_TYPES[type_] = to
+        get_serializer_type.cache_clear()
+
     # Register the serializer.
     SERIALIZERS[type_] = fn
+    get_serializer.cache_clear()
 
     # Return the function.
     return fn
 
 
-def serialize(value: Any) -> SerializedType | None:
+@overload
+def serialize(
+    value: Any, get_type: Literal[True]
+) -> Tuple[Optional[SerializedType], Optional[types.GenericType]]:
+    ...
+
+
+@overload
+def serialize(value: Any, get_type: Literal[False]) -> Optional[SerializedType]:
+    ...
+
+
+@overload
+def serialize(value: Any) -> Optional[SerializedType]:
+    ...
+
+
+def serialize(
+    value: Any, get_type: bool = False
+) -> Union[
+    Optional[SerializedType],
+    Tuple[Optional[SerializedType], Optional[types.GenericType]],
+]:
     """Serialize the value to a JSON string.
 
     Args:
         value: The value to serialize.
+        get_type: Whether to return the type of the serialized value.
 
     Returns:
         The serialized value, or None if a serializer is not found.
@@ -75,13 +128,22 @@ def serialize(value: Any) -> SerializedType | None:
 
     # If there is no serializer, return None.
     if serializer is None:
+        if get_type:
+            return None, None
         return None
 
     # Serialize the value.
-    return serializer(value)
+    serialized = serializer(value)
 
+    # Return the serialized value and the type.
+    if get_type:
+        return serialized, get_serializer_type(type(value))
+    else:
+        return serialized
 
-def get_serializer(type_: Type) -> Serializer | None:
+
+@functools.lru_cache
+def get_serializer(type_: Type) -> Optional[Serializer]:
     """Get the serializer for the type.
 
     Args:
@@ -90,8 +152,6 @@ def get_serializer(type_: Type) -> Serializer | None:
     Returns:
         The serializer for the type, or None if there is no serializer.
     """
-    global SERIALIZERS
-
     # First, check if the type is registered.
     serializer = SERIALIZERS.get(type_)
     if serializer is not None:
@@ -106,6 +166,30 @@ def get_serializer(type_: Type) -> Serializer | None:
     return None
 
 
+@functools.lru_cache
+def get_serializer_type(type_: Type) -> Optional[Type]:
+    """Get the converted type for the type after serializing.
+
+    Args:
+        type_: The type to get the serializer type for.
+
+    Returns:
+        The serialized type for the type, or None if there is no type conversion registered.
+    """
+    # First, check if the type is registered.
+    serializer = SERIALIZER_TYPES.get(type_)
+    if serializer is not None:
+        return serializer
+
+    # If the type is not registered, check if it is a subclass of a registered type.
+    for registered_type, serializer in reversed(SERIALIZER_TYPES.items()):
+        if types._issubclass(type_, registered_type):
+            return serializer
+
+    # If there is no serializer, return None.
+    return None
+
+
 def has_serializer(type_: Type) -> bool:
     """Check if there is a serializer for the type.
 
@@ -118,7 +202,7 @@ def has_serializer(type_: Type) -> bool:
     return get_serializer(type_) is not None
 
 
-@serializer
+@serializer(to=str)
 def serialize_type(value: type) -> str:
     """Serialize a python type.
 
@@ -226,7 +310,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
     return format.unwrap_vars(fprop)
 
 
-@serializer
+@serializer(to=str)
 def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
     """Serialize a datetime to a JSON string.
 
@@ -239,8 +323,8 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
     return str(dt)
 
 
-@serializer
-def serialize_path(path: Path):
+@serializer(to=str)
+def serialize_path(path: Path) -> str:
     """Serialize a pathlib.Path to a JSON string.
 
     Args:
@@ -265,7 +349,7 @@ def serialize_enum(en: Enum) -> str:
     return en.value
 
 
-@serializer
+@serializer(to=str)
 def serialize_color(color: Color) -> str:
     """Serialize a color.
 

+ 31 - 10
reflex/vars.py

@@ -347,7 +347,7 @@ class Var:
         cls,
         value: Any,
         _var_is_local: bool = True,
-        _var_is_string: bool = False,
+        _var_is_string: bool | None = None,
         _var_data: Optional[VarData] = None,
     ) -> Var | None:
         """Create a var from a value.
@@ -380,18 +380,39 @@ class Var:
 
         # Try to serialize the value.
         type_ = type(value)
-        name = value if type_ in types.JSONType else serializers.serialize(value)
+        if type_ in types.JSONType:
+            name = value
+        else:
+            name, serialized_type = serializers.serialize(value, get_type=True)
+            if (
+                serialized_type is not None
+                and _var_is_string is None
+                and issubclass(serialized_type, str)
+            ):
+                _var_is_string = True
         if name is None:
             raise VarTypeError(
                 f"No JSON serializer found for var {value} of type {type_}."
             )
         name = name if isinstance(name, str) else format.json_dumps(name)
 
+        if _var_is_string is None and type_ is str:
+            console.deprecate(
+                feature_name="Creating a Var from a string without specifying _var_is_string",
+                reason=(
+                    "Specify _var_is_string=False to create a Var that is not a string literal. "
+                    "In the future, creating a Var from a string will be treated as a string literal "
+                    "by default."
+                ),
+                deprecation_version="0.5.4",
+                removal_version="0.6.0",
+            )
+
         return BaseVar(
             _var_name=name,
             _var_type=type_,
             _var_is_local=_var_is_local,
-            _var_is_string=_var_is_string,
+            _var_is_string=_var_is_string if _var_is_string is not None else False,
             _var_data=_var_data,
         )
 
@@ -400,7 +421,7 @@ class Var:
         cls,
         value: Any,
         _var_is_local: bool = True,
-        _var_is_string: bool = False,
+        _var_is_string: bool | None = None,
         _var_data: Optional[VarData] = None,
     ) -> Var:
         """Create a var from a value, asserting that it is not None.
@@ -847,19 +868,19 @@ class Var:
                 if invoke_fn:
                     # invoke the function on left operand.
                     operation_name = (
-                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"
-                    )  # type: ignore
+                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"  # type: ignore
+                    )
                 else:
                     # pass the operands as arguments to the function.
                     operation_name = (
-                        f"{left_operand_full_name} {op} {right_operand_full_name}"
-                    )  # type: ignore
+                        f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                    )
                     operation_name = f"{fn}({operation_name})"
             else:
                 # apply operator to operands (left operand <operator> right_operand)
                 operation_name = (
-                    f"{left_operand_full_name} {op} {right_operand_full_name}"
-                )  # type: ignore
+                    f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                )
                 operation_name = format.wrap(operation_name, "(")
         else:
             # apply operator to left operand (<operator> left_operand)

+ 2 - 2
reflex/vars.pyi

@@ -51,11 +51,11 @@ class Var:
     _var_data: VarData | None = None
     @classmethod
     def create(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
+        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
     ) -> Optional[Var]: ...
     @classmethod
     def create_safe(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
+        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
     ) -> Var: ...
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

+ 10 - 4
tests/components/core/test_colors.py

@@ -2,6 +2,7 @@ import pytest
 
 import reflex as rx
 from reflex.components.datadisplay.code import CodeBlock
+from reflex.constants.colors import Color
 from reflex.vars import Var
 
 
@@ -50,7 +51,12 @@ def create_color_var(color):
     ],
 )
 def test_color(color, expected):
-    assert str(color) == expected
+    assert color._var_is_string or color._var_type is str
+    assert color._var_full_name == expected
+    if color._var_type == Color:
+        assert str(color) == f"{{`{expected}`}}"
+    else:
+        assert str(color) == expected
 
 
 @pytest.mark.parametrize(
@@ -96,9 +102,9 @@ def test_color_with_conditionals(cond_var, expected):
 @pytest.mark.parametrize(
     "color, expected",
     [
-        (create_color_var(rx.color("red")), "var(--red-7)"),
-        (create_color_var(rx.color("green", shade=1)), "var(--green-1)"),
-        (create_color_var(rx.color("blue", alpha=True)), "var(--blue-a7)"),
+        (create_color_var(rx.color("red")), "{`var(--red-7)`}"),
+        (create_color_var(rx.color("green", shade=1)), "{`var(--green-1)`}"),
+        (create_color_var(rx.color("blue", alpha=True)), "{`var(--blue-a7)`}"),
         ("red", "red"),
         ("green", "green"),
         ("blue", "blue"),

+ 40 - 0
tests/utils/test_serializers.py

@@ -1,5 +1,6 @@
 import datetime
 from enum import Enum
+from pathlib import Path
 from typing import Any, Dict, List, Type
 
 import pytest
@@ -90,6 +91,9 @@ def test_add_serializer():
 
     # Remove the serializer.
     serializers.SERIALIZERS.pop(Foo)
+    # LRU cache will still have the serializer, so we need to clear it.
+    assert serializers.has_serializer(Foo)
+    serializers.get_serializer.cache_clear()
     assert not serializers.has_serializer(Foo)
 
 
@@ -194,3 +198,39 @@ def test_serialize(value: Any, expected: str):
         expected: The expected result.
     """
     assert serializers.serialize(value) == expected
+
+
+@pytest.mark.parametrize(
+    "value,expected,exp_var_is_string",
+    [
+        ("test", "test", False),
+        (1, "1", False),
+        (1.0, "1.0", False),
+        (True, "true", False),
+        (False, "false", False),
+        ([1, 2, 3], "[1, 2, 3]", False),
+        ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]', False),
+        (StrEnum.FOO, "foo", False),
+        ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]', False),
+        (
+            BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
+            '{"ts": "1 day, 0:00:01.000001"}',
+            False,
+        ),
+        (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001", True),
+        (Color(color="slate", shade=1), "var(--slate-1)", True),
+        (BaseSubclass, "BaseSubclass", True),
+        (Path("."), ".", True),
+    ],
+)
+def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):
+    """Test that serialize with `to=str` passed to a Var is marked with _var_is_string.
+
+    Args:
+        value: The value to serialize.
+        expected: The expected result.
+        exp_var_is_string: The expected value of _var_is_string.
+    """
+    v = Var.create_safe(value)
+    assert v._var_full_name == expected
+    assert v._var_is_string == exp_var_is_string