Selaa lähdekoodia

improve var base typing (#4718)

* improve var base typing

* fix pyi

* dang it darglint

* drain _process in tests

* fixes #4576

* dang it darglint
Khaleel Al-Adhami 3 kuukautta sitten
vanhempi
säilyke
8663dbcb97

+ 2 - 1
reflex/components/base/error_boundary.py

@@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard
 from reflex.state import FrontendEventExceptionState
 from reflex.vars.base import Var
 from reflex.vars.function import ArgsFunctionOperation
+from reflex.vars.object import ObjectVar
 
 
 def on_error_spec(
-    error: Var[Dict[str, str]], info: Var[Dict[str, str]]
+    error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
 ) -> Tuple[Var[str], Var[str]]:
     """The spec for the on_error event handler.
 

+ 2 - 1
reflex/components/base/error_boundary.pyi

@@ -9,9 +9,10 @@ from reflex.components.component import Component
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.vars.base import Var
+from reflex.vars.object import ObjectVar
 
 def on_error_spec(
-    error: Var[Dict[str, str]], info: Var[Dict[str, str]]
+    error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
 ) -> Tuple[Var[str], Var[str]]: ...
 
 class ErrorBoundary(Component):

+ 1 - 0
reflex/components/component.py

@@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
+    slots=True,
 )
 class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
     """A Var that represents a Component."""

+ 10 - 2
reflex/components/core/foreach.py

@@ -11,6 +11,7 @@ from reflex.components.component import Component
 from reflex.components.tags import IterTag
 from reflex.constants import MemoizationMode
 from reflex.state import ComponentState
+from reflex.utils.exceptions import UntypedVarError
 from reflex.vars.base import LiteralVar, Var
 
 
@@ -51,6 +52,7 @@ class Foreach(Component):
         Raises:
             ForeachVarError: If the iterable is of type Any.
             TypeError: If the render function is a ComponentState.
+            UntypedVarError: If the iterable is of type Any without a type annotation.
         """
         iterable = LiteralVar.create(iterable)
         if iterable._var_type == Any:
@@ -72,8 +74,14 @@ class Foreach(Component):
             iterable=iterable,
             render_fn=render_fn,
         )
-        # Keep a ref to a rendered component to determine correct imports/hooks/styles.
-        component.children = [component._render().render_component()]
+        try:
+            # Keep a ref to a rendered component to determine correct imports/hooks/styles.
+            component.children = [component._render().render_component()]
+        except UntypedVarError as e:
+            raise UntypedVarError(
+                f"Could not foreach over var `{iterable!s}` without a type annotation. "
+                "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
+            ) from e
         return component
 
     def _render(self) -> IterTag:

+ 2 - 1
reflex/components/datadisplay/dataeditor.py

@@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent):
                 raise ValueError(
                     "DataEditor data must be an ArrayVar if rows is not provided."
                 )
-            props["rows"] = data.length() if isinstance(data, Var) else len(data)
+
+            props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data)
 
         if not isinstance(columns, Var) and len(columns):
             if types.is_dataframe(type(data)) or (

+ 8 - 4
reflex/components/datadisplay/shiki_code_block.py

@@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap):
 
         Returns:
             Imports for the component.
+
+        Raises:
+            ValueError: If the transformers are not of type LiteralVar.
         """
         imports = defaultdict(list)
+        if not isinstance(self.transformers, LiteralVar):
+            raise ValueError(
+                f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead."
+            )
         for transformer in self.transformers._var_value:
             if isinstance(transformer, ShikiBaseTransformers):
                 imports[transformer.library].extend(
                     [ImportVar(tag=str(fn)) for fn in transformer.fns]
                 )
-                (
+                if transformer.library not in self.lib_dependencies:
                     self.lib_dependencies.append(transformer.library)
-                    if transformer.library not in self.lib_dependencies
-                    else None
-                )
         return imports
 
     @classmethod

+ 17 - 8
reflex/event.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 import dataclasses
 import inspect
-import sys
 import types
 import urllib.parse
 from base64 import b64encode
@@ -541,7 +540,7 @@ class JavasciptKeyboardEvent:
     shiftKey: bool = False  # noqa: N815
 
 
-def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]:
+def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]:
     """Get the value from an input event.
 
     Args:
@@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict):
     shift_key: bool
 
 
-def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]:
+def key_event(
+    e: ObjectVar[JavasciptKeyboardEvent],
+) -> Tuple[Var[str], Var[KeyInputInfo]]:
     """Get the key from a keyboard event.
 
     Args:
@@ -572,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
         The key from the keyboard event.
     """
     return (
-        e.key,
+        e.key.to(str),
         Var.create(
             {
                 "alt_key": e.altKey,
@@ -580,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
                 "meta_key": e.metaKey,
                 "shift_key": e.shiftKey,
             },
-        ),
+        ).to(KeyInputInfo),
     )
 
 
@@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType):
     Returns:
         The unwrapped annotation.
     """
-    if get_origin(annotation) is Var and (args := get_args(annotation)):
+    if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)):
         return args[0]
     return annotation
 
@@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
     """A literal event var."""
@@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 # Note: LiteralVar is second in the inheritance list allowing it act like a
 # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
@@ -1713,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
 
         Returns:
             The created LiteralEventChainVar instance.
+
+        Raises:
+            ValueError: If the invocation is not a FunctionVar.
         """
         arg_spec = (
             value.args_spec[0]
@@ -1740,6 +1744,11 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
         else:
             invocation = value.invocation
 
+        if invocation is not None and not isinstance(invocation, FunctionVar):
+            raise ValueError(
+                f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
+            )
+
         return cls(
             _js_expr="",
             _var_type=EventChain,

+ 1 - 2
reflex/experimental/client_state.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 import dataclasses
 import re
-import sys
 from typing import Any, Callable, Union
 
 from reflex import constants
@@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str:
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ClientStateVar(Var):
     """A Var that exists on the client via useState."""

+ 4 - 2
reflex/state.py

@@ -1637,9 +1637,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if not isinstance(var, Var):
             return var
 
+        unset = object()
+
         # Fast case: this is a literal var and the value is known.
-        if hasattr(var, "_var_value"):
-            return var._var_value
+        if (var_value := getattr(var, "_var_value", unset)) is not unset:
+            return var_value  # pyright: ignore [reportReturnType]
 
         var_data = var._get_all_var_data()
         if var_data is None or not var_data.state:

+ 4 - 0
reflex/utils/exceptions.py

@@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError):
     """Custom AttributeError for var related errors."""
 
 
+class UntypedVarError(ReflexError, TypeError):
+    """Custom TypeError for untyped var errors."""
+
+
 class UntypedComputedVarError(ReflexError, TypeError):
     """Custom TypeError for untyped computed var errors."""
 

+ 107 - 96
reflex/vars/base.py

@@ -12,7 +12,6 @@ import json
 import random
 import re
 import string
-import sys
 import warnings
 from types import CodeType, FunctionType
 from typing import (
@@ -82,6 +81,7 @@ if TYPE_CHECKING:
 VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
 OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
 STRING_T = TypeVar("STRING_T", bound=str)
+SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)
 
 warnings.filterwarnings("ignore", message="fields may not start with an underscore")
 
@@ -449,7 +449,7 @@ class Var(Generic[VAR_TYPE]):
             @dataclasses.dataclass(
                 eq=False,
                 frozen=True,
-                **{"slots": True} if sys.version_info >= (3, 10) else {},
+                slots=True,
             )
             class ToVarOperation(ToOperation, cls):
                 """Base class of converting a var to another var type."""
@@ -597,7 +597,7 @@ class Var(Generic[VAR_TYPE]):
 
     @overload
     @classmethod
-    def create(
+    def create(  # pyright: ignore [reportOverlappingOverload]
         cls,
         value: STRING_T,
         _var_data: VarData | None = None,
@@ -611,6 +611,22 @@ class Var(Generic[VAR_TYPE]):
         _var_data: VarData | None = None,
     ) -> NoneVar: ...
 
+    @overload
+    @classmethod
+    def create(
+        cls,
+        value: MAPPING_TYPE,
+        _var_data: VarData | None = None,
+    ) -> ObjectVar[MAPPING_TYPE]: ...
+
+    @overload
+    @classmethod
+    def create(
+        cls,
+        value: SEQUENCE_TYPE,
+        _var_data: VarData | None = None,
+    ) -> ArrayVar[SEQUENCE_TYPE]: ...
+
     @overload
     @classmethod
     def create(
@@ -692,8 +708,8 @@ class Var(Generic[VAR_TYPE]):
     @overload
     def to(
         self,
-        output: type[Mapping],
-    ) -> ObjectVar[Mapping]: ...
+        output: type[MAPPING_TYPE],
+    ) -> ObjectVar[MAPPING_TYPE]: ...
 
     @overload
     def to(
@@ -744,7 +760,7 @@ class Var(Generic[VAR_TYPE]):
             return get_to_operation(NoneVar).create(self)  # pyright: ignore [reportReturnType]
 
         # Handle fixed_output_type being Base or a dataclass.
-        if can_use_in_object_var(fixed_output_type):
+        if can_use_in_object_var(output):
             return self.to(ObjectVar, output)
 
         if inspect.isclass(output):
@@ -776,6 +792,9 @@ class Var(Generic[VAR_TYPE]):
 
         return self
 
+    @overload
+    def guess_type(self: Var[NoReturn]) -> Var[Any]: ...  # pyright: ignore [reportOverlappingOverload]
+
     @overload
     def guess_type(self: Var[str]) -> StringVar: ...
 
@@ -785,6 +804,9 @@ class Var(Generic[VAR_TYPE]):
     @overload
     def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
 
+    @overload
+    def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ...
+
     @overload
     def guess_type(self) -> Self: ...
 
@@ -933,7 +955,7 @@ class Var(Generic[VAR_TYPE]):
 
         return setter
 
-    def _var_set_state(self, state: type[BaseState] | str):
+    def _var_set_state(self, state: type[BaseState] | str) -> Self:
         """Set the state of the var.
 
         Args:
@@ -948,7 +970,7 @@ class Var(Generic[VAR_TYPE]):
             else format_state_name(state.get_full_name())
         )
 
-        return StateOperation.create(
+        return StateOperation.create(  # pyright: ignore [reportReturnType]
             formatted_state_name,
             self,
             _var_data=VarData.merge(
@@ -1127,43 +1149,6 @@ class Var(Generic[VAR_TYPE]):
         """
         return self
 
-    def __getattr__(self, name: str):
-        """Get an attribute of the var.
-
-        Args:
-            name: The name of the attribute.
-
-        Returns:
-            The attribute.
-
-        Raises:
-            VarAttributeError: If the attribute does not exist.
-            TypeError: If the var type is Any.
-        """
-        if name.startswith("_"):
-            return self.__getattribute__(name)
-
-        if name == "contains":
-            raise TypeError(
-                f"Var of type {self._var_type} does not support contains check."
-            )
-        if name == "reverse":
-            raise TypeError("Cannot reverse non-list var.")
-
-        if self._var_type is Any:
-            raise TypeError(
-                f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
-            )
-
-        if name in REPLACED_NAMES:
-            raise VarAttributeError(
-                f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}"
-            )
-
-        raise VarAttributeError(
-            f"The State var has no attribute '{name}' or may have been annotated wrongly.",
-        )
-
     def _decode(self) -> Any:
         """Decode Var as a python value.
 
@@ -1225,36 +1210,76 @@ class Var(Generic[VAR_TYPE]):
 
         return ArrayVar.range(first_endpoint, second_endpoint, step)
 
-    def __bool__(self) -> bool:
-        """Raise exception if using Var in a boolean context.
+    if not TYPE_CHECKING:
 
-        Raises:
-            VarTypeError: when attempting to bool-ify the Var.
-        """
-        raise VarTypeError(
-            f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
-            "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
-        )
+        def __getattr__(self, name: str):
+            """Get an attribute of the var.
 
-    def __iter__(self) -> Any:
-        """Raise exception if using Var in an iterable context.
+            Args:
+                name: The name of the attribute.
 
-        Raises:
-            VarTypeError: when attempting to iterate over the Var.
-        """
-        raise VarTypeError(
-            f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
-        )
+            Raises:
+                VarAttributeError: If the attribute does not exist.
+                UntypedVarError: If the var type is Any.
+                TypeError: If the var type is Any.
 
-    def __contains__(self, _: Any) -> Var:
-        """Override the 'in' operator to alert the user that it is not supported.
+            # noqa: DAR101 self
+            """
+            if name.startswith("_"):
+                raise VarAttributeError(f"Attribute {name} not found.")
 
-        Raises:
-            VarTypeError: the operation is not supported
-        """
-        raise VarTypeError(
-            "'in' operator not supported for Var types, use Var.contains() instead."
-        )
+            if name == "contains":
+                raise TypeError(
+                    f"Var of type {self._var_type} does not support contains check."
+                )
+            if name == "reverse":
+                raise TypeError("Cannot reverse non-list var.")
+
+            if self._var_type is Any:
+                raise exceptions.UntypedVarError(
+                    f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
+                )
+
+            raise VarAttributeError(
+                f"The State var has no attribute '{name}' or may have been annotated wrongly.",
+            )
+
+        def __bool__(self) -> bool:
+            """Raise exception if using Var in a boolean context.
+
+            Raises:
+                VarTypeError: when attempting to bool-ify the Var.
+
+            # noqa: DAR101 self
+            """
+            raise VarTypeError(
+                f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
+                "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
+            )
+
+        def __iter__(self) -> Any:
+            """Raise exception if using Var in an iterable context.
+
+            Raises:
+                VarTypeError: when attempting to iterate over the Var.
+
+            # noqa: DAR101 self
+            """
+            raise VarTypeError(
+                f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
+            )
+
+        def __contains__(self, _: Any) -> Var:
+            """Override the 'in' operator to alert the user that it is not supported.
+
+            Raises:
+                VarTypeError: the operation is not supported
+
+            # noqa: DAR101 self
+            """
+            raise VarTypeError(
+                "'in' operator not supported for Var types, use Var.contains() instead."
+            )
 
 
 OUTPUT = TypeVar("OUTPUT", bound=Var)
@@ -1471,6 +1496,12 @@ class LiteralVar(Var):
     def __post_init__(self):
         """Post-initialize the var."""
 
+    @property
+    def _var_value(self) -> Any:
+        raise NotImplementedError(
+            "LiteralVar subclasses must implement the _var_value property."
+        )
+
     def json(self) -> str:
         """Serialize the var to a JSON string.
 
@@ -1543,7 +1574,7 @@ def var_operation(
 ) -> Callable[P, StringVar]: ...
 
 
-LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
+LIST_T = TypeVar("LIST_T", bound=Sequence)
 
 
 @overload
@@ -1780,7 +1811,7 @@ def _or_operation(a: Var, b: Var):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class CallableVar(Var):
     """Decorate a Var-returning function to act as both a Var and a function.
@@ -1861,7 +1892,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]:
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ComputedVar(Var[RETURN_TYPE]):
     """A field with computed getters."""
@@ -2070,13 +2101,6 @@ class ComputedVar(Var[RETURN_TYPE]):
         owner: Type,
     ) -> ArrayVar[list[LIST_INSIDE]]: ...
 
-    @overload
-    def __get__(
-        self: ComputedVar[set[LIST_INSIDE]],
-        instance: None,
-        owner: Type,
-    ) -> ArrayVar[set[LIST_INSIDE]]: ...
-
     @overload
     def __get__(
         self: ComputedVar[tuple[LIST_INSIDE, ...]],
@@ -2436,7 +2460,7 @@ def var_operation_return(
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class CustomVarOperation(CachedVarOperation, Var[T]):
     """Base class for custom var operations."""
@@ -2507,7 +2531,7 @@ class NoneVar(Var[None], python_types=type(None)):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralNoneVar(LiteralVar, NoneVar):
     """A var representing None."""
@@ -2569,7 +2593,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]:
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class StateOperation(CachedVarOperation, Var):
     """A var operation that accesses a field on an object."""
@@ -2716,19 +2740,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]:
     return var_datas
 
 
-# These names were changed in reflex 0.3.0
-REPLACED_NAMES = {
-    "full_name": "_var_full_name",
-    "name": "_js_expr",
-    "state": "_var_data.state",
-    "type_": "_var_type",
-    "is_local": "_var_is_local",
-    "is_string": "_var_is_string",
-    "set_state": "_var_set_state",
-    "deps": "_deps",
-}
-
-
 dispatchers: Dict[GenericType, Callable[[Var], Var]] = {}
 
 

+ 1 - 2
reflex/vars/datetime.py

@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import dataclasses
-import sys
 from datetime import date, datetime
 from typing import Any, NoReturn, TypeVar, Union, overload
 
@@ -193,7 +192,7 @@ def date_compare_operation(
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralDatetimeVar(LiteralVar, DateTimeVar):
     """Base class for immutable datetime and date vars."""

+ 3 - 3
reflex/vars/function.py

@@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
     """Base class for immutable vars that are the result of a function call."""
@@ -350,7 +350,7 @@ def format_args_function_operation(
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     """Base class for immutable function defined via arguments and return expression."""
@@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
     """Base class for immutable function defined via arguments and return expression with the builder pattern."""

+ 4 - 5
reflex/vars/number.py

@@ -5,7 +5,6 @@ from __future__ import annotations
 import dataclasses
 import json
 import math
-import sys
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         from .sequence import ArrayVar, LiteralArrayVar
 
-        if isinstance(other, (list, tuple, set, ArrayVar)):
+        if isinstance(other, (list, tuple, ArrayVar)):
             if isinstance(other, ArrayVar):
                 return other * self
             return LiteralArrayVar.create(other) * self
@@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         from .sequence import ArrayVar, LiteralArrayVar
 
-        if isinstance(other, (list, tuple, set, ArrayVar)):
+        if isinstance(other, (list, tuple, ArrayVar)):
             if isinstance(other, ArrayVar):
                 return other * self
             return LiteralArrayVar.create(other) * self
@@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralNumberVar(LiteralVar, NumberVar):
     """Base class for immutable literal number vars."""
@@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralBooleanVar(LiteralVar, BooleanVar):
     """Base class for immutable literal boolean vars."""

+ 15 - 25
reflex/vars/object.py

@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import dataclasses
-import sys
 import typing
 from inspect import isclass
 from typing import (
@@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
         key: Var | Any,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
-    @overload
-    def __getitem__(
-        self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
-        key: Var | Any,
-    ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
-
     @overload
     def __getitem__(
         self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
         name: str,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
-    @overload
-    def __getattr__(
-        self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
-        name: str,
-    ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
-
     @overload
     def __getattr__(
         self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
     """Base class for immutable literal object vars."""
@@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
 
         Returns:
             The JSON representation of the object.
+
+        Raises:
+            TypeError: The keys and values of the object must be literal vars to get the JSON representation
         """
-        return (
-            "{"
-            + ", ".join(
-                [
-                    f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}"
-                    for key, value in self._var_value.items()
-                ]
-            )
-            + "}"
-        )
+        keys_and_values = []
+        for key, value in self._var_value.items():
+            key = LiteralVar.create(key)
+            value = LiteralVar.create(value)
+            if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar):
+                raise TypeError(
+                    "The keys and values of the object must be literal vars to get the JSON representation."
+                )
+            keys_and_values.append(f"{key.json()}:{value.json()}")
+        return "{" + ", ".join(keys_and_values) + "}"
 
     def __hash__(self) -> int:
         """Get the hash of the var.
@@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ObjectItemOperation(CachedVarOperation, Var):
     """Operation to get an item from an object."""

+ 33 - 41
reflex/vars/sequence.py

@@ -6,7 +6,6 @@ import dataclasses
 import inspect
 import json
 import re
-import sys
 import typing
 from typing import (
     TYPE_CHECKING,
@@ -15,7 +14,7 @@ from typing import (
     List,
     Literal,
     NoReturn,
-    Set,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralStringVar(LiteralVar, StringVar[str]):
     """Base class for immutable literal string vars."""
@@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ConcatVarOperation(CachedVarOperation, StringVar[str]):
     """Representing a concatenation of literal string vars."""
@@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]):
         )
 
 
-ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
+ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True)
+OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence)
 
 OTHER_TUPLE = TypeVar("OTHER_TUPLE")
 
@@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         i: Literal[0, -2],
     ) -> NumberVar: ...
 
+    @overload
+    def __getitem__(
+        self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
+    ) -> BooleanVar: ...
+
     @overload
     def __getitem__(
         self: (
@@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
 
     @overload
     def __getitem__(
-        self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
+        self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
     ) -> BooleanVar: ...
 
     @overload
@@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
     ) -> StringVar: ...
 
-    @overload
-    def __getitem__(
-        self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
-    ) -> BooleanVar: ...
-
     @overload
     def __getitem__(
         self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
         i: int | NumberVar,
     ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
 
-    @overload
-    def __getitem__(
-        self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
-        i: int | NumberVar,
-    ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
-
     @overload
     def __getitem__(
         self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]],
@@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
 
 LIST_ELEMENT = TypeVar("LIST_ELEMENT")
 
-ARRAY_VAR_OF_LIST_ELEMENT = Union[
-    ArrayVar[List[LIST_ELEMENT]],
-    ArrayVar[Set[LIST_ELEMENT]],
-    ArrayVar[Tuple[LIST_ELEMENT, ...]],
-]
+ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]]
 
 
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
     """Base class for immutable literal array vars."""
 
-    _var_value: Union[
-        List[Union[Var, Any]],
-        Set[Union[Var, Any]],
-        Tuple[Union[Var, Any], ...],
-    ] = dataclasses.field(default_factory=list)
+    _var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=())
 
     @cached_property_no_lock
     def _cached_var_name(self) -> str:
@@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
 
         Returns:
             The JSON representation of the var.
+
+        Raises:
+            TypeError: If the array elements are not of type LiteralVar.
         """
-        return (
-            "["
-            + ", ".join(
-                [LiteralVar.create(element).json() for element in self._var_value]
-            )
-            + "]"
-        )
+        elements = []
+        for element in self._var_value:
+            element_var = LiteralVar.create(element)
+            if not isinstance(element_var, LiteralVar):
+                raise TypeError(
+                    f"Array elements must be of type LiteralVar, not {type(element_var)}"
+                )
+            elements.append(element_var.json())
+
+        return "[" + ", ".join(elements) + "]"
 
     @classmethod
     def create(
         cls,
-        value: ARRAY_VAR_TYPE,
-        _var_type: Type[ARRAY_VAR_TYPE] | None = None,
+        value: OTHER_ARRAY_VAR_TYPE,
+        _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None,
         _var_data: VarData | None = None,
-    ) -> LiteralArrayVar[ARRAY_VAR_TYPE]:
+    ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]:
         """Create a var from a string value.
 
         Args:
@@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
         Returns:
             The var.
         """
-        return cls(
+        return LiteralArrayVar(
             _js_expr="",
             _var_type=figure_out_type(value) if _var_type is None else _var_type,
             _var_data=_var_data,
@@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ArraySliceOperation(CachedVarOperation, ArrayVar):
     """Base class for immutable string vars that are the result of a string slice operation."""
@@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
     """Base class for immutable literal color vars."""

+ 4 - 0
tests/units/components/core/test_match.py

@@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple
 import pytest
 
 import reflex as rx
+from reflex.components.component import Component
 from reflex.components.core.match import Match
 from reflex.state import BaseState
 from reflex.utils.exceptions import MatchTypeError
@@ -29,6 +30,8 @@ def test_match_components():
         rx.text("default value"),
     )
     match_comp = Match.create(MatchState.value, *match_case_tuples)
+
+    assert isinstance(match_comp, Component)
     match_dict = match_comp.render()
     assert match_dict["name"] == "Fragment"
 
@@ -151,6 +154,7 @@ def test_match_on_component_without_default():
     )
 
     match_comp = Match.create(MatchState.value, *match_case_tuples)
+    assert isinstance(match_comp, Component)
     default = match_comp.render()["children"][0]["default"]
 
     assert isinstance(default, dict) and default["name"] == Fragment.__name__

+ 5 - 4
tests/units/components/test_component.py

@@ -36,6 +36,7 @@ from reflex.utils.exceptions import (
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
+from reflex.vars.object import ObjectVar
 
 
 @pytest.fixture
@@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args():
     """Test that we can define arbitrary types for the args of an event trigger."""
 
     def on_foo_spec(
-        _e: Var[JavascriptInputEvent],
+        _e: ObjectVar[JavascriptInputEvent],
         alpha: Var[str],
         bravo: dict[str, Any],
-        charlie: Var[_Obj],
+        charlie: ObjectVar[_Obj],
     ):
-        return [_e.target.value, bravo["nested"], charlie.custom + 42]
+        return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42]
 
     class C1(Component):
         library = "/local"
@@ -1328,7 +1329,7 @@ class EventState(rx.State):
         ),
         pytest.param(
             rx.fragment(class_name=[TEST_VAR, "other-class"]),
-            [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
+            [Var.create([TEST_VAR, "other-class"]).join(" ")],
             id="fstring-dual-class_name",
         ),
         pytest.param(

+ 13 - 15
tests/units/test_app.py

@@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
     """
     state = test_state()  # pyright: ignore [reportCallIssue]
     state.add_var("int_val", int, 0)
-    result = await state._process(
+    async for result in state._process(
         Event(
             token=token,
             name=f"{test_state.get_name()}.set_int_val",
             router_data={"pathname": "/", "query": {}},
             payload={"value": 50},
         )
-    ).__anext__()
-    assert result.delta == {test_state.get_name(): {"int_val": 50}}
+    ):
+        assert result.delta == {test_state.get_name(): {"int_val": 50}}
 
 
 @pytest.mark.asyncio
@@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list(
         token: a Token.
     """
     for event_name, expected_delta in event_tuples:
-        result = await list_mutation_state._process(
+        async for result in list_mutation_state._process(
             Event(
                 token=token,
                 name=f"{list_mutation_state.get_name()}.{event_name}",
                 router_data={"pathname": "/", "query": {}},
                 payload={},
             )
-        ).__anext__()
-
-        # prefix keys in expected_delta with the state name
-        expected_delta = {list_mutation_state.get_name(): expected_delta}
-        assert result.delta == expected_delta
+        ):
+            # prefix keys in expected_delta with the state name
+            expected_delta = {list_mutation_state.get_name(): expected_delta}
+            assert result.delta == expected_delta
 
 
 @pytest.mark.asyncio
@@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list(
         token: a Token.
     """
     for event_name, expected_delta in event_tuples:
-        result = await dict_mutation_state._process(
+        async for result in dict_mutation_state._process(
             Event(
                 token=token,
                 name=f"{dict_mutation_state.get_name()}.{event_name}",
                 router_data={"pathname": "/", "query": {}},
                 payload={},
             )
-        ).__anext__()
-
-        # prefix keys in expected_delta with the state name
-        expected_delta = {dict_mutation_state.get_name(): expected_delta}
+        ):
+            # prefix keys in expected_delta with the state name
+            expected_delta = {dict_mutation_state.get_name(): expected_delta}
 
-        assert result.delta == expected_delta
+            assert result.delta == expected_delta
 
 
 @pytest.mark.asyncio

+ 36 - 38
tests/units/test_state.py

@@ -789,17 +789,16 @@ async def test_process_event_simple(test_state):
     assert test_state.num1 == 0
 
     event = Event(token="t", name="set_num1", payload={"value": 69})
-    update = await test_state._process(event).__anext__()
-
-    # The event should update the value.
-    assert test_state.num1 == 69
-
-    # The delta should contain the changes, including computed vars.
-    assert update.delta == {
-        TestState.get_full_name(): {"num1": 69, "sum": 72.14},
-        GrandchildState3.get_full_name(): {"computed": ""},
-    }
-    assert update.events == []
+    async for update in test_state._process(event):
+        # The event should update the value.
+        assert test_state.num1 == 69
+
+        # The delta should contain the changes, including computed vars.
+        assert update.delta == {
+            TestState.get_full_name(): {"num1": 69, "sum": 72.14},
+            GrandchildState3.get_full_name(): {"computed": ""},
+        }
+        assert update.events == []
 
 
 @pytest.mark.asyncio
@@ -819,15 +818,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
         name=f"{ChildState.get_name()}.change_both",
         payload={"value": "hi", "count": 12},
     )
-    update = await test_state._process(event).__anext__()
-    assert child_state.value == "HI"
-    assert child_state.count == 24
-    assert update.delta == {
-        # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
-        ChildState.get_full_name(): {"value": "HI", "count": 24},
-        GrandchildState3.get_full_name(): {"computed": ""},
-    }
-    test_state._clean()
+    async for update in test_state._process(event):
+        assert child_state.value == "HI"
+        assert child_state.count == 24
+        assert update.delta == {
+            # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+            ChildState.get_full_name(): {"value": "HI", "count": 24},
+            GrandchildState3.get_full_name(): {"computed": ""},
+        }
+        test_state._clean()
 
     # Test with the granchild state.
     assert grandchild_state.value2 == ""
@@ -836,13 +835,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
         name=f"{GrandchildState.get_full_name()}.set_value2",
         payload={"value": "new"},
     )
-    update = await test_state._process(event).__anext__()
-    assert grandchild_state.value2 == "new"
-    assert update.delta == {
-        # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
-        GrandchildState.get_full_name(): {"value2": "new"},
-        GrandchildState3.get_full_name(): {"computed": ""},
-    }
+    async for update in test_state._process(event):
+        assert grandchild_state.value2 == "new"
+        assert update.delta == {
+            # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+            GrandchildState.get_full_name(): {"value2": "new"},
+            GrandchildState3.get_full_name(): {"computed": ""},
+        }
 
 
 @pytest.mark.asyncio
@@ -2909,10 +2908,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
 
     events = updates[0].events
     assert len(events) == 2
-    assert (await state._process(events[0]).__anext__()).delta == {
-        test_state.get_full_name(): {"num": 1}
-    }
-    assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
+    async for update in state._process(events[0]):
+        assert update.delta == {test_state.get_full_name(): {"num": 1}}
+    async for update in state._process(events[1]):
+        assert update.delta == exp_is_hydrated(state)
 
     if isinstance(app.state_manager, StateManagerRedis):
         await app.state_manager.close()
@@ -2957,13 +2956,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
 
     events = updates[0].events
     assert len(events) == 3
-    assert (await state._process(events[0]).__anext__()).delta == {
-        OnLoadState.get_full_name(): {"num": 1}
-    }
-    assert (await state._process(events[1]).__anext__()).delta == {
-        OnLoadState.get_full_name(): {"num": 2}
-    }
-    assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
+    async for update in state._process(events[0]):
+        assert update.delta == {OnLoadState.get_full_name(): {"num": 1}}
+    async for update in state._process(events[1]):
+        assert update.delta == {OnLoadState.get_full_name(): {"num": 2}}
+    async for update in state._process(events[2]):
+        assert update.delta == exp_is_hydrated(state)
 
     if isinstance(app.state_manager, StateManagerRedis):
         await app.state_manager.close()

+ 8 - 15
tests/units/test_var.py

@@ -1,6 +1,5 @@
 import json
 import math
-import sys
 import typing
 from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
 
@@ -422,19 +421,13 @@ class Bar(rx.Base):
 
 @pytest.mark.parametrize(
     ("var", "var_type"),
-    (
-        [
-            (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
-            (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
-        ]
-        if sys.version_info >= (3, 10)
-        else []
-    )
-    + [
-        (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
-        (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
+    [
+        (Var(_js_expr="").to(Foo | Bar), Foo | Bar),
+        (Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]),
+        (Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]),
+        (Var(_js_expr="").to(Union[Foo, Bar]).baz, str),
         (
-            Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
+            Var(_js_expr="").to(Union[Foo, Bar]).foo,
             Union[int, None],
         ),
     ],
@@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var):
         var: The base var.
     """
     with pytest.raises(TypeError) as err:
-        assert var.contains(1)
+        assert var.contains(1)  # pyright: ignore [reportAttributeAccessIssue]
     assert (
         err.value.args[0]
         == f"Var of type {var._var_type} does not support contains check."
@@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other):
 
 def test_unsupported_default_contains():
     with pytest.raises(TypeError) as err:
-        assert 1 in Var(_js_expr="var", _var_type=str).guess_type()
+        assert 1 in Var(_js_expr="var", _var_type=str).guess_type()  # pyright: ignore [reportOperatorIssue]
     assert (
         err.value.args[0]
         == "'in' operator not supported for Var types, use Var.contains() instead."