浏览代码

improve typing for serializer decorator (#4317)

* improve typing for serializer decorator

* use wrapped logic

* dang it darglint

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
benedikt-bartscher 6 月之前
父节点
当前提交
8fd5c9f200
共有 1 个文件被更改,包括 58 次插入40 次删除
  1. 58 40
      reflex/utils/serializers.py

+ 58 - 40
reflex/utils/serializers.py

@@ -18,6 +18,7 @@ from typing import (
     Set,
     Tuple,
     Type,
+    TypeVar,
     Union,
     get_type_hints,
     overload,
@@ -32,17 +33,33 @@ from reflex.utils import types
 SerializedType = Union[str, bool, int, float, list, dict, None]
 
 
-Serializer = Callable[[Type], SerializedType]
+Serializer = Callable[[Any], SerializedType]
 
 
 SERIALIZERS: dict[Type, Serializer] = {}
 SERIALIZER_TYPES: dict[Type, Type] = {}
 
+SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
 
+
+@overload
+def serializer(
+    fn: None = None,
+    to: Type[SerializedType] | None = None,
+) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
+
+
+@overload
 def serializer(
-    fn: Serializer | None = None,
-    to: Type | None = None,
-) -> Serializer:
+    fn: SERIALIZED_FUNCTION,
+    to: Type[SerializedType] | None = None,
+) -> SERIALIZED_FUNCTION: ...
+
+
+def serializer(
+    fn: SERIALIZED_FUNCTION | None = None,
+    to: Any = None,
+) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
     """Decorator to add a serializer for a given type.
 
     Args:
@@ -51,43 +68,44 @@ def serializer(
 
     Returns:
         The decorated function.
-
-    Raises:
-        ValueError: If the function does not take a single argument.
     """
-    if fn is None:
-        # If the function is not provided, return a partial that acts as a decorator.
-        return functools.partial(serializer, to=to)  # type: ignore
-
-    # Check the type hints to get the type of the argument.
-    type_hints = get_type_hints(fn)
-    args = [arg for arg in type_hints if arg != "return"]
-
-    # Make sure the function takes a single argument.
-    if len(args) != 1:
-        raise ValueError("Serializer must take a single argument.")
-
-    # Get the type of the argument.
-    type_ = type_hints[args[0]]
-
-    # Make sure the type is not already registered.
-    registered_fn = SERIALIZERS.get(type_)
-    if registered_fn is not None and registered_fn != fn:
-        raise ValueError(
-            f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
-        )
-
-    # Apply type transformation if requested
-    if to is not None or ((to := type_hints.get("return")) is not None):
-        SERIALIZER_TYPES[type_] = to
-        get_serializer_type.cache_clear()
-
-    # Register the serializer.
-    SERIALIZERS[type_] = fn
-    get_serializer.cache_clear()
-
-    # Return the function.
-    return fn
+
+    def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
+        # Check the type hints to get the type of the argument.
+        type_hints = get_type_hints(fn)
+        args = [arg for arg in type_hints if arg != "return"]
+
+        # Make sure the function takes a single argument.
+        if len(args) != 1:
+            raise ValueError("Serializer must take a single argument.")
+
+        # Get the type of the argument.
+        type_ = type_hints[args[0]]
+
+        # Make sure the type is not already registered.
+        registered_fn = SERIALIZERS.get(type_)
+        if registered_fn is not None and registered_fn != fn:
+            raise ValueError(
+                f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
+            )
+
+        to_type = to or type_hints.get("return")
+
+        # Apply type transformation if requested
+        if to_type:
+            SERIALIZER_TYPES[type_] = to_type
+            get_serializer_type.cache_clear()
+
+        # Register the serializer.
+        SERIALIZERS[type_] = fn
+        get_serializer.cache_clear()
+
+        # Return the function.
+        return fn
+
+    if fn is not None:
+        return wrapper(fn)
+    return wrapper
 
 
 @overload