Ver Fonte

Do not auto-determine generic args if already supplied (#4148)

* add failing test for figure_out_type

* do not auto-determine generic args if already supplied

* move has_args to utils.types, add tests for it
benedikt-bartscher há 7 meses atrás
pai
commit
0889276e24

+ 21 - 0
reflex/utils/types.py

@@ -220,6 +220,27 @@ def is_literal(cls: GenericType) -> bool:
     return get_origin(cls) is Literal
 
 
+def has_args(cls) -> bool:
+    """Check if the class has generic parameters.
+
+    Args:
+        cls: The class to check.
+
+    Returns:
+        Whether the class has generic
+    """
+    if get_args(cls):
+        return True
+
+    # Check if the class inherits from a generic class (using __orig_bases__)
+    if hasattr(cls, "__orig_bases__"):
+        for base in cls.__orig_bases__:
+            if get_args(base):
+                return True
+
+    return False
+
+
 def is_optional(cls: GenericType) -> bool:
     """Check if a class is an Optional.
 

+ 6 - 3
reflex/vars/base.py

@@ -56,7 +56,7 @@ from reflex.utils.imports import (
     ParsedImportDict,
     parse_imports,
 )
-from reflex.utils.types import GenericType, Self, get_origin
+from reflex.utils.types import GenericType, Self, get_origin, has_args
 
 if TYPE_CHECKING:
     from reflex.state import BaseState
@@ -1266,6 +1266,11 @@ def figure_out_type(value: Any) -> types.GenericType:
     Returns:
         The type of the value.
     """
+    if isinstance(value, Var):
+        return value._var_type
+    type_ = type(value)
+    if has_args(type_):
+        return type_
     if isinstance(value, list):
         return List[unionize(*(figure_out_type(v) for v in value))]
     if isinstance(value, set):
@@ -1277,8 +1282,6 @@ def figure_out_type(value: Any) -> types.GenericType:
             unionize(*(figure_out_type(k) for k in value)),
             unionize(*(figure_out_type(v) for v in value.values())),
         ]
-    if isinstance(value, Var):
-        return value._var_type
     return type(value)
 
 

+ 46 - 1
tests/units/utils/test_types.py

@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, Tuple, Union
+from typing import Any, Dict, List, Literal, Tuple, Union
 
 import pytest
 
@@ -45,3 +45,48 @@ def test_issubclass(
     cls: types.GenericType, cls_check: types.GenericType, expected: bool
 ) -> None:
     assert types._issubclass(cls, cls_check) == expected
+
+
+class CustomDict(dict[str, str]):
+    """A custom dict with generic arguments."""
+
+    pass
+
+
+class ChildCustomDict(CustomDict):
+    """A child of CustomDict."""
+
+    pass
+
+
+class GenericDict(dict):
+    """A generic dict with no generic arguments."""
+
+    pass
+
+
+class ChildGenericDict(GenericDict):
+    """A child of GenericDict."""
+
+    pass
+
+
+@pytest.mark.parametrize(
+    "cls,expected",
+    [
+        (int, False),
+        (str, False),
+        (float, False),
+        (Tuple[int], True),
+        (List[int], True),
+        (Union[int, str], True),
+        (Union[str, int], True),
+        (Dict[str, int], True),
+        (CustomDict, True),
+        (ChildCustomDict, True),
+        (GenericDict, False),
+        (ChildGenericDict, False),
+    ],
+)
+def test_has_args(cls, expected: bool) -> None:
+    assert types.has_args(cls) == expected

+ 28 - 0
tests/units/vars/test_base.py

@@ -5,6 +5,30 @@ import pytest
 from reflex.vars.base import figure_out_type
 
 
+class CustomDict(dict[str, str]):
+    """A custom dict with generic arguments."""
+
+    pass
+
+
+class ChildCustomDict(CustomDict):
+    """A child of CustomDict."""
+
+    pass
+
+
+class GenericDict(dict):
+    """A generic dict with no generic arguments."""
+
+    pass
+
+
+class ChildGenericDict(GenericDict):
+    """A child of GenericDict."""
+
+    pass
+
+
 @pytest.mark.parametrize(
     ("value", "expected"),
     [
@@ -15,6 +39,10 @@ from reflex.vars.base import figure_out_type
         ([1, 2.0, "a"], List[Union[int, float, str]]),
         ({"a": 1, "b": 2}, Dict[str, int]),
         ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
+        (CustomDict(), CustomDict),
+        (ChildCustomDict(), ChildCustomDict),
+        (GenericDict({1: 1}), Dict[int, int]),
+        (ChildGenericDict({1: 1}), Dict[int, int]),
     ],
 )
 def test_figure_out_type(value, expected):