瀏覽代碼

add type validation for state setattr (#4265)

* add type validation for state setattr

* add type to check to state setattr

* add type validation to computed vars
Khaleel Al-Adhami 6 月之前
父節點
當前提交
c8a7ee52bf
共有 3 個文件被更改,包括 104 次插入15 次删除
  1. 15 2
      reflex/state.py
  2. 51 1
      reflex/utils/types.py
  3. 38 12
      reflex/vars/base.py

+ 15 - 2
reflex/state.py

@@ -91,7 +91,7 @@ from reflex.utils.exceptions import (
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import serializer
-from reflex.utils.types import get_origin, override
+from reflex.utils.types import _isinstance, get_origin, override
 from reflex.vars import VarData
 
 if TYPE_CHECKING:
@@ -636,7 +636,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         def computed_var_func(state: Self):
             result = f(state)
 
-            if not isinstance(result, of_type):
+            if not _isinstance(result, of_type):
                 console.warn(
                     f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
                     "You can specify expected type with `of_type` argument."
@@ -1274,6 +1274,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 f"All state variables must be declared before they can be set."
             )
 
+        fields = self.get_fields()
+
+        if name in fields and not _isinstance(
+            value, (field_type := fields[name].outer_type_)
+        ):
+            console.deprecate(
+                "mismatched-type-assignment",
+                f"Tried to assign value {value} of type {type(value)} to field {type(self).__name__}.{name} of type {field_type}."
+                " This might lead to unexpected behavior.",
+                "0.6.5",
+                "0.7.0",
+            )
+
         # Set the attribute.
         super().__setattr__(name, value)
 

+ 51 - 1
reflex/utils/types.py

@@ -510,16 +510,66 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
         raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
 
 
-def _isinstance(obj: Any, cls: GenericType) -> bool:
+def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
     """Check if an object is an instance of a class.
 
     Args:
         obj: The object to check.
         cls: The class to check against.
+        nested: Whether the check is nested.
 
     Returns:
         Whether the object is an instance of the class.
     """
+    if cls is Any:
+        return True
+
+    if cls is None or cls is type(None):
+        return obj is None
+
+    if is_literal(cls):
+        return obj in get_args(cls)
+
+    if is_union(cls):
+        return any(_isinstance(obj, arg) for arg in get_args(cls))
+
+    origin = get_origin(cls)
+
+    if origin is None:
+        # cls is a simple class
+        return isinstance(obj, cls)
+
+    args = get_args(cls)
+
+    if not args:
+        # cls is a simple generic class
+        return isinstance(obj, origin)
+
+    if nested and args:
+        if origin is list:
+            return isinstance(obj, list) and all(
+                _isinstance(item, args[0]) for item in obj
+            )
+        if origin is tuple:
+            if args[-1] is Ellipsis:
+                return isinstance(obj, tuple) and all(
+                    _isinstance(item, args[0]) for item in obj
+                )
+            return (
+                isinstance(obj, tuple)
+                and len(obj) == len(args)
+                and all(_isinstance(item, arg) for item, arg in zip(obj, args))
+            )
+        if origin is dict:
+            return isinstance(obj, dict) and all(
+                _isinstance(key, args[0]) and _isinstance(value, args[1])
+                for key, value in obj.items()
+            )
+        if origin is set:
+            return isinstance(obj, set) and all(
+                _isinstance(item, args[0]) for item in obj
+            )
+
     return isinstance(obj, get_base_class(cls))
 
 

+ 38 - 12
reflex/vars/base.py

@@ -63,7 +63,14 @@ from reflex.utils.imports import (
     ParsedImportDict,
     parse_imports,
 )
-from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
+from reflex.utils.types import (
+    GenericType,
+    Self,
+    _isinstance,
+    get_origin,
+    has_args,
+    unionize,
+)
 
 if TYPE_CHECKING:
     from reflex.state import BaseState
@@ -1833,6 +1840,14 @@ class ComputedVar(Var[RETURN_TYPE]):
             "return", Any
         )
 
+        if hint is Any:
+            console.deprecate(
+                "untyped-computed-var",
+                "ComputedVar should have a return type annotation.",
+                "0.6.5",
+                "0.7.0",
+            )
+
         kwargs.setdefault("_js_expr", fget.__name__)
         kwargs.setdefault("_var_type", hint)
 
@@ -2026,17 +2041,28 @@ class ComputedVar(Var[RETURN_TYPE]):
             )
 
         if not self._cache:
-            return self.fget(instance)
-
-        # handle caching
-        if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
-            # Set cache attr on state instance.
-            setattr(instance, self._cache_attr, self.fget(instance))
-            # Ensure the computed var gets serialized to redis.
-            instance._was_touched = True
-            # Set the last updated timestamp on the state instance.
-            setattr(instance, self._last_updated_attr, datetime.datetime.now())
-        return getattr(instance, self._cache_attr)
+            value = self.fget(instance)
+        else:
+            # handle caching
+            if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
+                # Set cache attr on state instance.
+                setattr(instance, self._cache_attr, self.fget(instance))
+                # Ensure the computed var gets serialized to redis.
+                instance._was_touched = True
+                # Set the last updated timestamp on the state instance.
+                setattr(instance, self._last_updated_attr, datetime.datetime.now())
+            value = getattr(instance, self._cache_attr)
+
+        if not _isinstance(value, self._var_type):
+            console.deprecate(
+                "mismatched-computed-var-return",
+                f"Computed var {type(instance).__name__}.{self._js_expr} returned value of type {type(value)}, "
+                f"expected {self._var_type}. This might cause unexpected behavior.",
+                "0.6.5",
+                "0.7.0",
+            )
+
+        return value
 
     def _deps(
         self,