Bladeren bron

Strict type checking for indexing with vars (#1333)

Elijah Ahianyo 1 jaar geleden
bovenliggende
commit
40953d05ac
2 gewijzigde bestanden met toevoegingen van 146 en 8 verwijderingen
  1. 19 3
      reflex/vars.py
  2. 127 5
      tests/test_var.py

+ 19 - 3
reflex/vars.py

@@ -206,7 +206,8 @@ class Var(ABC):
         ):
         ):
             if self.type_ == Any:
             if self.type_ == Any:
                 raise TypeError(
                 raise TypeError(
-                    f"Could not index into var of type Any. (If you are trying to index into a state var, add the correct type annotation to the var.)"
+                    f"Could not index into var of type Any. (If you are trying to index into a state var, "
+                    f"add the correct type annotation to the var.)"
                 )
                 )
             raise TypeError(
             raise TypeError(
                 f"Var {self.name} of type {self.type_} does not support indexing."
                 f"Var {self.name} of type {self.type_} does not support indexing."
@@ -222,8 +223,12 @@ class Var(ABC):
         # Handle list/tuple/str indexing.
         # Handle list/tuple/str indexing.
         if types._issubclass(self.type_, Union[List, Tuple, str]):
         if types._issubclass(self.type_, Union[List, Tuple, str]):
             # List/Tuple/String indices must be ints, slices, or vars.
             # List/Tuple/String indices must be ints, slices, or vars.
-            if not isinstance(i, types.get_args(Union[int, slice, Var])):
-                raise TypeError("Index must be an integer.")
+            if (
+                not isinstance(i, types.get_args(Union[int, slice, Var]))
+                or isinstance(i, Var)
+                and not i.type_ == int
+            ):
+                raise TypeError("Index must be an integer or an integer var.")
 
 
             # Handle slices first.
             # Handle slices first.
             if isinstance(i, slice):
             if isinstance(i, slice):
@@ -253,6 +258,17 @@ class Var(ABC):
             )
             )
 
 
         # Dictionary / dataframe indexing.
         # Dictionary / dataframe indexing.
+        # Tuples are currently not supported as indexes.
+        if (
+            (types._issubclass(self.type_, Dict) or types.is_dataframe(self.type_))
+            and not isinstance(i, types.get_args(Union[int, str, float, Var]))
+        ) or (
+            isinstance(i, Var)
+            and not types._issubclass(i.type_, types.get_args(Union[int, str, float]))
+        ):
+            raise TypeError(
+                "Index must be one of the following types: int, str, int or str Var"
+            )
         # Get the type of the indexed var.
         # Get the type of the indexed var.
         if isinstance(i, str):
         if isinstance(i, str):
             i = format.wrap(i, '"')
             i = format.wrap(i, '"')

+ 127 - 5
tests/test_var.py

@@ -1,8 +1,9 @@
 import typing
 import typing
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 
 import cloudpickle
 import cloudpickle
 import pytest
 import pytest
+from pandas import DataFrame
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.state import State
 from reflex.state import State
@@ -293,11 +294,54 @@ def test_var_indexing_lists(var):
     # Test negative indexing.
     # Test negative indexing.
     assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
     assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
 
 
-    # Test non-integer indexing raises an error.
-    with pytest.raises(TypeError):
-        var["a"]
+
+@pytest.mark.parametrize(
+    "var, index",
+    [
+        (BaseVar(name="lst", type_=List[int]), [1, 2]),
+        (BaseVar(name="lst", type_=List[int]), {"name": "dict"}),
+        (BaseVar(name="lst", type_=List[int]), {"set"}),
+        (
+            BaseVar(name="lst", type_=List[int]),
+            (
+                1,
+                2,
+            ),
+        ),
+        (BaseVar(name="lst", type_=List[int]), 1.5),
+        (BaseVar(name="lst", type_=List[int]), "str"),
+        (BaseVar(name="lst", type_=List[int]), BaseVar(name="string_var", type_=str)),
+        (BaseVar(name="lst", type_=List[int]), BaseVar(name="float_var", type_=float)),
+        (
+            BaseVar(name="lst", type_=List[int]),
+            BaseVar(name="list_var", type_=List[int]),
+        ),
+        (BaseVar(name="lst", type_=List[int]), BaseVar(name="set_var", type_=Set[str])),
+        (
+            BaseVar(name="lst", type_=List[int]),
+            BaseVar(name="dict_var", type_=Dict[str, str]),
+        ),
+        (BaseVar(name="str", type_=str), [1, 2]),
+        (BaseVar(name="lst", type_=str), {"name": "dict"}),
+        (BaseVar(name="lst", type_=str), {"set"}),
+        (BaseVar(name="lst", type_=str), BaseVar(name="string_var", type_=str)),
+        (BaseVar(name="lst", type_=str), BaseVar(name="float_var", type_=float)),
+        (BaseVar(name="str", type_=Tuple[str]), [1, 2]),
+        (BaseVar(name="lst", type_=Tuple[str]), {"name": "dict"}),
+        (BaseVar(name="lst", type_=Tuple[str]), {"set"}),
+        (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="string_var", type_=str)),
+        (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="float_var", type_=float)),
+    ],
+)
+def test_var_unsupported_indexing_lists(var, index):
+    """Test unsupported indexing throws a type error.
+
+    Args:
+        var: The base var.
+        index: The base var index.
+    """
     with pytest.raises(TypeError):
     with pytest.raises(TypeError):
-        var[1.5]
+        var[index]
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -328,6 +372,84 @@ def test_dict_indexing():
     assert str(dct["asdf"]) == '{dct["asdf"]}'
     assert str(dct["asdf"]) == '{dct["asdf"]}'
 
 
 
 
+@pytest.mark.parametrize(
+    "var, index",
+    [
+        (
+            BaseVar(name="dict", type_=Dict[str, str]),
+            [1, 2],
+        ),
+        (
+            BaseVar(name="dict", type_=Dict[str, str]),
+            {"name": "dict"},
+        ),
+        (
+            BaseVar(name="dict", type_=Dict[str, str]),
+            {"set"},
+        ),
+        (
+            BaseVar(name="dict", type_=Dict[str, str]),
+            (
+                1,
+                2,
+            ),
+        ),
+        (
+            BaseVar(name="lst", type_=Dict[str, str]),
+            BaseVar(name="list_var", type_=List[int]),
+        ),
+        (
+            BaseVar(name="lst", type_=Dict[str, str]),
+            BaseVar(name="set_var", type_=Set[str]),
+        ),
+        (
+            BaseVar(name="lst", type_=Dict[str, str]),
+            BaseVar(name="dict_var", type_=Dict[str, str]),
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            [1, 2],
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            {"name": "dict"},
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            {"set"},
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            (
+                1,
+                2,
+            ),
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            BaseVar(name="list_var", type_=List[int]),
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            BaseVar(name="set_var", type_=Set[str]),
+        ),
+        (
+            BaseVar(name="df", type_=DataFrame),
+            BaseVar(name="dict_var", type_=Dict[str, str]),
+        ),
+    ],
+)
+def test_var_unsupported_indexing_dicts(var, index):
+    """Test unsupported indexing throws a type error.
+
+    Args:
+        var: The base var.
+        index: The base var index.
+    """
+    with pytest.raises(TypeError):
+        var[index]
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "fixture,full_name",
     "fixture,full_name",
     [
     [