Explorar o código

don't treat vars as their types for setting state fields (#4861)

Khaleel Al-Adhami hai 2 meses
pai
achega
62b3076dc1

+ 1 - 1
reflex/components/component.py

@@ -188,7 +188,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
     Returns:
         Whether the object satisfies the type hint.
     """
-    return types._isinstance(obj, type_hint, nested=1)
+    return types._isinstance(obj, type_hint, nested=1, treat_var_as_type=True)
 
 
 def _components_from(

+ 3 - 3
reflex/components/tags/tag.py

@@ -6,7 +6,7 @@ import dataclasses
 from typing import Any, List, Mapping, Sequence
 
 from reflex.event import EventChain
-from reflex.utils import format, types
+from reflex.utils import format
 from reflex.vars.base import LiteralVar, Var
 
 
@@ -103,9 +103,9 @@ class Tag:
             {
                 format.to_camel_case(name, treat_hyphens_as_underscores=False): (
                     prop
-                    if types._isinstance(prop, (EventChain, Mapping))
+                    if isinstance(prop, (EventChain, Mapping))
                     else LiteralVar.create(prop)
-                )  # rx.color is always a string
+                )
                 for name, prop in kwargs.items()
                 if self.is_valid_prop(prop)
             }

+ 2 - 2
reflex/state.py

@@ -692,7 +692,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         def computed_var_func(state: Self):
             result = f(state)
 
-            if not _isinstance(result, of_type):
+            if not _isinstance(result, of_type, nested=1, treat_var_as_type=False):
                 console.warn(
                     f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
                     "You can specify expected type with `of_type` argument."
@@ -1353,7 +1353,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             field_type = _unwrap_field_type(field.outer_type_)
             if field.allow_none and not is_optional(field_type):
                 field_type = field_type | None
-            if not _isinstance(value, field_type):
+            if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
                 console.error(
                     f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
                     f" but got '{value}' of type '{type(value)}'."

+ 51 - 11
reflex/utils/types.py

@@ -509,13 +509,16 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
     return required_keys.issubset(required_keys)
 
 
-def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
+def _isinstance(
+    obj: Any, cls: GenericType, *, nested: int = 0, treat_var_as_type: bool = True
+) -> bool:
     """Check if an object is an instance of a class.
 
     Args:
         obj: The object to check.
         cls: The class to check against.
         nested: How many levels deep to check.
+        treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
 
     Returns:
         Whether the object is an instance of the class.
@@ -528,15 +531,20 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
     if cls is Var:
         return isinstance(obj, Var)
     if isinstance(obj, LiteralVar):
-        return _isinstance(obj._var_value, cls, nested=nested)
+        return treat_var_as_type and _isinstance(
+            obj._var_value, cls, nested=nested, treat_var_as_type=True
+        )
     if isinstance(obj, Var):
-        return _issubclass(obj._var_type, cls)
+        return treat_var_as_type and _issubclass(obj._var_type, cls)
 
     if cls is None or cls is type(None):
         return obj is None
 
     if cls and is_union(cls):
-        return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls))
+        return any(
+            _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
+            for arg in get_args(cls)
+        )
 
     if is_literal(cls):
         return obj in get_args(cls)
@@ -566,37 +574,69 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
     if nested > 0 and args:
         if origin is list:
             return isinstance(obj, list) and all(
-                _isinstance(item, args[0], nested=nested - 1) for item in obj
+                _isinstance(
+                    item,
+                    args[0],
+                    nested=nested - 1,
+                    treat_var_as_type=treat_var_as_type,
+                )
+                for item in obj
             )
         if origin is tuple:
             if args[-1] is Ellipsis:
                 return isinstance(obj, tuple) and all(
-                    _isinstance(item, args[0], nested=nested - 1) for item in obj
+                    _isinstance(
+                        item,
+                        args[0],
+                        nested=nested - 1,
+                        treat_var_as_type=treat_var_as_type,
+                    )
+                    for item in obj
                 )
             return (
                 isinstance(obj, tuple)
                 and len(obj) == len(args)
                 and all(
-                    _isinstance(item, arg, nested=nested - 1)
+                    _isinstance(
+                        item,
+                        arg,
+                        nested=nested - 1,
+                        treat_var_as_type=treat_var_as_type,
+                    )
                     for item, arg in zip(obj, args, strict=True)
                 )
             )
         if origin in (dict, Mapping, Breakpoints):
             return isinstance(obj, Mapping) and all(
-                _isinstance(key, args[0], nested=nested - 1)
-                and _isinstance(value, args[1], nested=nested - 1)
+                _isinstance(
+                    key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
+                )
+                and _isinstance(
+                    value,
+                    args[1],
+                    nested=nested - 1,
+                    treat_var_as_type=treat_var_as_type,
+                )
                 for key, value in obj.items()
             )
         if origin is set:
             return isinstance(obj, set) and all(
-                _isinstance(item, args[0], nested=nested - 1) for item in obj
+                _isinstance(
+                    item,
+                    args[0],
+                    nested=nested - 1,
+                    treat_var_as_type=treat_var_as_type,
+                )
+                for item in obj
             )
 
     if args:
         from reflex.vars import Field
 
         if origin is Field:
-            return _isinstance(obj, args[0], nested=nested)
+            return _isinstance(
+                obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
+            )
 
     return isinstance(obj, get_base_class(cls))
 

+ 1 - 1
reflex/vars/base.py

@@ -2289,7 +2289,7 @@ class ComputedVar(Var[RETURN_TYPE]):
         return value
 
     def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
-        if not _isinstance(value, self._var_type):
+        if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False):
             console.error(
                 f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
                 f" type '{self._var_type}', got '{type(value)}'."