Browse Source

use getattr when given str in getitem (#4761)

* use getattr when given str in getitem

* stronger checking and tests

* switch ordering

* use safe issubclass

* calculate origin differently
Khaleel Al-Adhami 3 months ago
parent
commit
1651289485
2 changed files with 32 additions and 4 deletions
  1. 13 4
      reflex/vars/object.py
  2. 19 0
      tests/integration/test_var_operations.py

+ 13 - 4
reflex/vars/object.py

@@ -22,7 +22,12 @@ from typing_extensions import is_typeddict
 
 
 from reflex.utils import types
 from reflex.utils import types
 from reflex.utils.exceptions import VarAttributeError
 from reflex.utils.exceptions import VarAttributeError
-from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
+from reflex.utils.types import (
+    GenericType,
+    get_attribute_access_type,
+    get_origin,
+    safe_issubclass,
+)
 
 
 from .base import (
 from .base import (
     CachedVarOperation,
     CachedVarOperation,
@@ -187,10 +192,14 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
         Returns:
         Returns:
             The item from the object.
             The item from the object.
         """
         """
+        from .sequence import LiteralStringVar
+
         if not isinstance(key, (StringVar, str, int, NumberVar)) or (
         if not isinstance(key, (StringVar, str, int, NumberVar)) or (
             isinstance(key, NumberVar) and key._is_strict_float()
             isinstance(key, NumberVar) and key._is_strict_float()
         ):
         ):
             raise_unsupported_operand_types("[]", (type(self), type(key)))
             raise_unsupported_operand_types("[]", (type(self), type(key)))
+        if isinstance(key, str) and isinstance(Var.create(key), LiteralStringVar):
+            return self.__getattr__(key)
         return ObjectItemOperation.create(self, key).guess_type()
         return ObjectItemOperation.create(self, key).guess_type()
 
 
     # NoReturn is used here to catch when key value is Any
     # NoReturn is used here to catch when key value is Any
@@ -260,12 +269,12 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
         if types.is_optional(var_type):
         if types.is_optional(var_type):
             var_type = get_args(var_type)[0]
             var_type = get_args(var_type)[0]
 
 
-        fixed_type = var_type if isclass(var_type) else get_origin(var_type)
+        fixed_type = get_origin(var_type) or var_type
 
 
         if (
         if (
-            (isclass(fixed_type) and not issubclass(fixed_type, Mapping))
+            is_typeddict(fixed_type)
+            or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping))
             or (fixed_type in types.UnionTypes)
             or (fixed_type in types.UnionTypes)
-            or is_typeddict(fixed_type)
         ):
         ):
             attribute_type = get_attribute_access_type(var_type, name)
             attribute_type = get_attribute_access_type(var_type, name)
             if attribute_type is None:
             if attribute_type is None:

+ 19 - 0
tests/integration/test_var_operations.py

@@ -10,6 +10,8 @@ from reflex.testing import AppHarness
 
 
 def VarOperations():
 def VarOperations():
     """App with var operations."""
     """App with var operations."""
+    from typing import TypedDict
+
     import reflex as rx
     import reflex as rx
     from reflex.vars.base import LiteralVar
     from reflex.vars.base import LiteralVar
     from reflex.vars.sequence import ArrayVar
     from reflex.vars.sequence import ArrayVar
@@ -17,6 +19,10 @@ def VarOperations():
     class Object(rx.Base):
     class Object(rx.Base):
         name: str = "hello"
         name: str = "hello"
 
 
+    class Person(TypedDict):
+        name: str
+        age: int
+
     class VarOperationState(rx.State):
     class VarOperationState(rx.State):
         int_var1: rx.Field[int] = rx.field(10)
         int_var1: rx.Field[int] = rx.field(10)
         int_var2: rx.Field[int] = rx.field(5)
         int_var2: rx.Field[int] = rx.field(5)
@@ -34,6 +40,9 @@ def VarOperations():
         dict1: rx.Field[dict[int, int]] = rx.field({1: 2})
         dict1: rx.Field[dict[int, int]] = rx.field({1: 2})
         dict2: rx.Field[dict[int, int]] = rx.field({3: 4})
         dict2: rx.Field[dict[int, int]] = rx.field({3: 4})
         html_str: rx.Field[str] = rx.field("<div>hello</div>")
         html_str: rx.Field[str] = rx.field("<div>hello</div>")
+        people: rx.Field[list[Person]] = rx.field(
+            [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
+        )
 
 
     app = rx.App(_state=rx.State)
     app = rx.App(_state=rx.State)
 
 
@@ -619,6 +628,15 @@ def VarOperations():
                 ),
                 ),
                 id="dict_in_foreach3",
                 id="dict_in_foreach3",
             ),
             ),
+            rx.box(
+                rx.foreach(
+                    VarOperationState.people,
+                    lambda person: rx.text.span(
+                        "Hello " + person["name"], person["age"] + 3
+                    ),
+                ),
+                id="typed_dict_in_foreach",
+            ),
         )
         )
 
 
 
 
@@ -826,6 +844,7 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("dict_in_foreach1", "a1b2"),
         ("dict_in_foreach1", "a1b2"),
         ("dict_in_foreach2", "12"),
         ("dict_in_foreach2", "12"),
         ("dict_in_foreach3", "1234"),
         ("dict_in_foreach3", "1234"),
+        ("typed_dict_in_foreach", "Hello Alice33Hello Bob28"),
     ]
     ]
 
 
     for tag, expected in tests:
     for tag, expected in tests: