Browse Source

down to only two pyright error

Khaleel Al-Adhami 4 months ago
parent
commit
57d8ea02e9

+ 4 - 4
poetry.lock

@@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
 
 
 [[package]]
 [[package]]
 name = "virtualenv"
 name = "virtualenv"
-version = "20.28.1"
+version = "20.29.1"
 description = "Virtual Python Environment builder"
 description = "Virtual Python Environment builder"
 optional = false
 optional = false
 python-versions = ">=3.8"
 python-versions = ">=3.8"
 files = [
 files = [
-    {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
-    {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
+    {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
+    {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
 ]
 ]
 
 
 [package.dependencies]
 [package.dependencies]
@@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
 [metadata]
 [metadata]
 lock-version = "2.0"
 lock-version = "2.0"
 python-versions = "^3.9"
 python-versions = "^3.9"
-content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee"
+content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"

+ 2 - 1
pyproject.toml

@@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
 gunicorn = ">=20.1.0,<24.0"
 gunicorn = ">=20.1.0,<24.0"
 jinja2 = ">=3.1.2,<4.0"
 jinja2 = ">=3.1.2,<4.0"
 psutil = ">=5.9.4,<7.0"
 psutil = ">=5.9.4,<7.0"
-pydantic = ">=1.10.2,<3.0"
+pydantic = ">=1.10.15,<3.0"
 python-multipart = ">=0.0.5,<0.1"
 python-multipart = ">=0.0.5,<0.1"
 python-socketio = ">=5.7.0,<6.0"
 python-socketio = ">=5.7.0,<6.0"
 redis = ">=4.3.5,<6.0"
 redis = ">=4.3.5,<6.0"
@@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"
 
 
 [tool.pyright]
 [tool.pyright]
 reportIncompatibleMethodOverride = false
 reportIncompatibleMethodOverride = false
+reportIncompatibleVariableOverride = false
 
 
 [tool.ruff]
 [tool.ruff]
 target-version = "py39"
 target-version = "py39"

+ 3 - 9
reflex/base.py

@@ -5,15 +5,9 @@ from __future__ import annotations
 import os
 import os
 from typing import TYPE_CHECKING, Any, List, Type
 from typing import TYPE_CHECKING, Any, List, Type
 
 
-try:
-    import pydantic.v1.main as pydantic_main
-    from pydantic.v1 import BaseModel
-    from pydantic.v1.fields import ModelField
-except ModuleNotFoundError:
-    if not TYPE_CHECKING:
-        import pydantic.main as pydantic_main
-        from pydantic import BaseModel
-        from pydantic.fields import ModelField  # type: ignore
+import pydantic.v1.main as pydantic_main
+from pydantic.v1 import BaseModel
+from pydantic.v1.fields import ModelField
 
 
 
 
 def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
 def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:

+ 1 - 5
reflex/components/core/cond.py

@@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):
 
 
 
 
 @overload
 @overload
-def cond(condition: Any, c1: Component, c2: Any) -> Component: ...
-
-
-@overload
-def cond(condition: Any, c1: Component) -> Component: ...
+def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...
 
 
 
 
 @overload
 @overload

+ 40 - 42
reflex/components/core/match.py

@@ -1,15 +1,17 @@
 """rx.match."""
 """rx.match."""
 
 
 import textwrap
 import textwrap
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, cast
 
 
 from reflex.components.base import Fragment
 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.utils import types
 from reflex.utils import types
 from reflex.utils.exceptions import MatchTypeError
 from reflex.utils.exceptions import MatchTypeError
-from reflex.vars.base import Var
+from reflex.vars.base import VAR_TYPE, Var
 from reflex.vars.number import MatchOperation
 from reflex.vars.number import MatchOperation
 
 
+CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
+
 
 
 class Match(MemoizationLeaf):
 class Match(MemoizationLeaf):
     """Match cases based on a condition."""
     """Match cases based on a condition."""
@@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
     default: Any
     default: Any
 
 
     @classmethod
     @classmethod
-    def create(cls, cond: Any, *cases) -> Union[Component, Var]:
+    def create(
+        cls,
+        cond: Any,
+        *cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create a Match Component.
         """Create a Match Component.
 
 
         Args:
         Args:
@@ -37,10 +43,24 @@ class Match(MemoizationLeaf):
         Raises:
         Raises:
             ValueError: When a default case is not provided for cases with Var return types.
             ValueError: When a default case is not provided for cases with Var return types.
         """
         """
-        cases, default = cls._process_cases(cases)
-        cls._process_match_cases(cases)
+        default = None
 
 
-        cls._validate_return_types(cases)
+        if len([case for case in cases if not isinstance(case, tuple)]) > 1:
+            raise ValueError("rx.match can only have one default case.")
+
+        if not cases:
+            raise ValueError("rx.match should have at least one case.")
+
+        # Get the default case which should be the last non-tuple arg
+        if not isinstance(cases[-1], tuple):
+            default = cases[-1]
+            actual_cases = cases[:-1]
+        else:
+            actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
+
+        cls._process_match_cases(actual_cases)
+
+        cls._validate_return_types(actual_cases)
 
 
         if default is None and any(
         if default is None and any(
             not (
             not (
@@ -50,7 +70,7 @@ class Match(MemoizationLeaf):
                     and types.typehint_issubclass(return_type._var_type, Component)
                     and types.typehint_issubclass(return_type._var_type, Component)
                 )
                 )
             )
             )
-            for case in cases
+            for case in actual_cases
         ):
         ):
             raise ValueError(
             raise ValueError(
                 "For cases with return types as Vars, a default case must be provided"
                 "For cases with return types as Vars, a default case must be provided"
@@ -58,40 +78,16 @@ class Match(MemoizationLeaf):
         elif default is None:
         elif default is None:
             default = Fragment.create()
             default = Fragment.create()
 
 
-        return cls._create_match_cond_var_or_component(cond, cases, default)
-
-    @classmethod
-    def _process_cases(
-        cls, cases: Sequence
-    ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
-        """Process the list of match cases and the catchall default case.
-
-        Args:
-            cases: The list of match cases.
-
-        Returns:
-            The default case and the list of match case tuples.
-
-        Raises:
-            ValueError: If there are multiple default cases.
-        """
-        default = None
-
-        if len([case for case in cases if not isinstance(case, tuple)]) > 1:
-            raise ValueError("rx.match can only have one default case.")
-
-        if not cases:
-            raise ValueError("rx.match should have at least one case.")
-
-        # Get the default case which should be the last non-tuple arg
-        if not isinstance(cases[-1], tuple):
-            default = cases[-1]
-            cases = cases[:-1]
+        default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
 
 
-        return list(cases), default
+        return cls._create_match_cond_var_or_component(
+            cond,
+            actual_cases,
+            default,
+        )
 
 
     @classmethod
     @classmethod
-    def _process_match_cases(cls, cases: Sequence):
+    def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
         """Process the individual match cases.
         """Process the individual match cases.
 
 
         Args:
         Args:
@@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
                 )
                 )
 
 
     @classmethod
     @classmethod
-    def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
+    def _validate_return_types(
+        cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
+    ) -> None:
         """Validate that match cases have the same return types.
         """Validate that match cases have the same return types.
 
 
         Args:
         Args:
@@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
     def _create_match_cond_var_or_component(
     def _create_match_cond_var_or_component(
         cls,
         cls,
         match_cond_var: Var,
         match_cond_var: Var,
-        match_cases: List[List[Var]],
-        default: Union[Var, BaseComponent],
-    ) -> Union[Component, Var]:
+        match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
+        default: VAR_TYPE | Var[VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create and return the match condition var or component.
         """Create and return the match condition var or component.
 
 
         Args:
         Args:

+ 1 - 1
reflex/components/datadisplay/dataeditor.py

@@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
 
 
     # Fired when editing is finished.
     # Fired when editing is finished.
     on_finished_editing: EventHandler[
     on_finished_editing: EventHandler[
-        passthrough_event_spec(Union[GridCell, None], tuple[int, int])
+        passthrough_event_spec(Union[GridCell, None], tuple[int, int])  # pyright: ignore[reportArgumentType]
     ]
     ]
 
 
     # Fired when a row is appended.
     # Fired when a row is appended.

+ 1 - 1
reflex/components/radix/primitives/accordion.py

@@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
     # The header of the accordion item.
     # The header of the accordion item.
     header: Var[Union[Component, str]]
     header: Var[Union[Component, str]]
     # The content of the accordion item.
     # The content of the accordion item.
-    content: Var[Union[Component, str]] = Var.create(None)
+    content: Var[Union[Component, str]] = Var.create("")
 
 
     _valid_children: List[str] = [
     _valid_children: List[str] = [
         "AccordionHeader",
         "AccordionHeader",

+ 5 - 4
reflex/components/tags/iter_tag.py

@@ -4,9 +4,10 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import inspect
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
 
 
 from reflex.components.tags.tag import Tag
 from reflex.components.tags.tag import Tag
+from reflex.utils import types
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -31,7 +32,7 @@ class IterTag(Tag):
     # The name of the index var.
     # The name of the index var.
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
 
-    def get_iterable_var_type(self) -> Type:
+    def get_iterable_var_type(self) -> types.GenericType:
         """Get the type of the iterable var.
         """Get the type of the iterable var.
 
 
         Returns:
         Returns:
@@ -41,10 +42,10 @@ class IterTag(Tag):
         try:
         try:
             if iterable._var_type.mro()[0] is dict:
             if iterable._var_type.mro()[0] is dict:
                 # Arg is a tuple of (key, value).
                 # Arg is a tuple of (key, value).
-                return Tuple[get_args(iterable._var_type)]  # type: ignore
+                return Tuple[get_args(iterable._var_type)]
             elif iterable._var_type.mro()[0] is tuple:
             elif iterable._var_type.mro()[0] is tuple:
                 # Arg is a union of any possible values in the tuple.
                 # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]  # type: ignore
+                return Union[get_args(iterable._var_type)]
             else:
             else:
                 return get_args(iterable._var_type)[0]
                 return get_args(iterable._var_type)[0]
         except Exception:
         except Exception:

+ 52 - 14
reflex/event.py

@@ -25,7 +25,6 @@ from typing import (
     overload,
     overload,
 )
 )
 
 
-import typing_extensions
 from typing_extensions import (
 from typing_extensions import (
     Concatenate,
     Concatenate,
     ParamSpec,
     ParamSpec,
@@ -33,6 +32,8 @@ from typing_extensions import (
     TypeAliasType,
     TypeAliasType,
     TypedDict,
     TypedDict,
     TypeVar,
     TypeVar,
+    TypeVarTuple,
+    deprecated,
     get_args,
     get_args,
     get_origin,
     get_origin,
 )
 )
@@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
 prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
 prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
 
 
 
 
-T = TypeVar("T")
-U = TypeVar("U")
+EVENT_T = TypeVar("EVENT_T")
+EVENT_U = TypeVar("EVENT_U")
+
+Ts = TypeVarTuple("Ts")
 
 
 
 
-class IdentityEventReturn(Generic[T], Protocol):
+class IdentityEventReturn(Generic[*Ts], Protocol):
     """Protocol for an identity event return."""
     """Protocol for an identity event return."""
 
 
-    def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
+    def __call__(self, *values: *Ts) -> tuple[*Ts]:
         """Return the input values.
         """Return the input values.
 
 
         Args:
         Args:
@@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):
 
 
 @overload
 @overload
 def passthrough_event_spec(
 def passthrough_event_spec(
-    event_type: Type[T], /
-) -> Callable[[Var[T]], Tuple[Var[T]]]: ...  # type: ignore
+    event_type: Type[EVENT_T], /
+) -> IdentityEventReturn[Var[EVENT_T]]: ...
 
 
 
 
 @overload
 @overload
 def passthrough_event_spec(
 def passthrough_event_spec(
-    event_type_1: Type[T], event_type2: Type[U], /
-) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
+    event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
+) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
 
 
 
 
 @overload
 @overload
-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
+def passthrough_event_spec(
+    *event_types: *tuple[Type[EVENT_T]],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...
 
 
 
 
-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]:  # type: ignore
+def passthrough_event_spec(  # pyright: ignore[reportInconsistentOverload]
+    *event_types: Type[EVENT_T],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
     """A helper function that returns the input event as output.
     """A helper function that returns the input event as output.
 
 
     Args:
     Args:
@@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]:  #
         A function that returns the input event as output.
         A function that returns the input event as output.
     """
     """
 
 
-    def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
+    def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
         return values
         return values
 
 
     inner_type = tuple(Var[event_type] for event_type in event_types)
     inner_type = tuple(Var[event_type] for event_type in event_types)
@@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
         return None
         return None
 
 
     fn.__qualname__ = name
     fn.__qualname__ = name
-    fn.__signature__ = sig
+    fn.__signature__ = sig  # pyright: ignore[reportFunctionMemberAccess]
     return EventSpec(
     return EventSpec(
         handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         args=tuple(
         args=tuple(
@@ -822,7 +829,7 @@ def redirect(
 
 
 
 
 @overload
 @overload
-@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
+@deprecated("`external` is deprecated use `is_external` instead")
 def redirect(
 def redirect(
     path: str | Var[str],
     path: str | Var[str],
     is_external: Optional[bool] = None,
     is_external: Optional[bool] = None,
@@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
         """
         """
         self.func = func
         self.func = func
 
 
+    def throttle(self, limit_ms: int):
+        """Throttle the event handler.
+
+        Args:
+            limit_ms: The time in milliseconds to throttle the event handler.
+
+        Returns:
+            New EventHandler-like with throttle set to limit_ms.
+        """
+        return self
+
+    def debounce(self, delay_ms: int):
+        """Debounce the event handler.
+
+        Args:
+            delay_ms: The time in milliseconds to debounce the event handler.
+
+        Returns:
+            New EventHandler-like with debounce set to delay_ms.
+        """
+        return self
+
+    @property
+    def temporal(self):
+        """Do not queue the event if the backend is down.
+
+        Returns:
+            New EventHandler-like with temporal set to True.
+        """
+        return self
+
     @property
     @property
     def prevent_default(self):
     def prevent_default(self):
         """Prevent default behavior.
         """Prevent default behavior.

+ 9 - 11
reflex/state.py

@@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if cls._item_is_event_handler(name, fn)
             if cls._item_is_event_handler(name, fn)
         }
         }
 
 
-        for mixin in cls._mixins():
-            for name, value in mixin.__dict__.items():
+        for mixin_class in cls._mixins():
+            for name, value in mixin_class.__dict__.items():
                 if name in cls.inherited_vars:
                 if name in cls.inherited_vars:
                     continue
                     continue
                 if is_computed_var(value):
                 if is_computed_var(value):
@@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._js_expr] = newcv
                     cls.computed_vars[newcv._js_expr] = newcv
                     cls.vars[newcv._js_expr] = newcv
                     cls.vars[newcv._js_expr] = newcv
                     continue
                     continue
-                if types.is_backend_base_variable(name, mixin):
+                if types.is_backend_base_variable(name, mixin_class):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                     continue
                 if events.get(name) is not None:
                 if events.get(name) is not None:
@@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
     return app.state_manager
     return app.state_manager
 
 
 
 
+DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
+
+
 class MutableProxy(wrapt.ObjectProxy):
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""
     """A proxy for a mutable object that tracks changes."""
 
 
@@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
                 cls.__dataclass_proxies__[wrapper_cls_name] = type(
                 cls.__dataclass_proxies__[wrapper_cls_name] = type(
                     wrapper_cls_name,
                     wrapper_cls_name,
                     (cls,),
                     (cls,),
-                    {
-                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportGeneralTypeIssues]
-                            wrapped_cls,
-                            dataclasses._FIELDS,  # pyright: ignore [reportGeneralTypeIssues]
-                        ),
-                    },
+                    {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
                 )
                 )
             cls = cls.__dataclass_proxies__[wrapper_cls_name]
             cls = cls.__dataclass_proxies__[wrapper_cls_name]
         return super().__new__(cls)
         return super().__new__(cls)
@@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
             if (
             if (
                 isinstance(self.__wrapped__, Base)
                 isinstance(self.__wrapped__, Base)
                 and __name not in self.__never_wrap_base_attrs__
                 and __name not in self.__never_wrap_base_attrs__
-                and hasattr(value, "__func__")
+                and (value_func := getattr(value, "__func__", None))
             ):
             ):
                 # Wrap methods called on Base subclasses, which might do _anything_
                 # Wrap methods called on Base subclasses, which might do _anything_
                 return wrapt.FunctionWrapper(
                 return wrapt.FunctionWrapper(
-                    functools.partial(value.__func__, self),
+                    functools.partial(value_func, self),
                     self._wrap_recursive_decorator,
                     self._wrap_recursive_decorator,
                 )
                 )
 
 

+ 6 - 4
reflex/testing.py

@@ -67,10 +67,8 @@ try:
         from selenium.webdriver.remote.webelement import (  # pyright: ignore [reportMissingImports]
         from selenium.webdriver.remote.webelement import (  # pyright: ignore [reportMissingImports]
             WebElement,
             WebElement,
         )
         )
-
-    has_selenium = True
 except ImportError:
 except ImportError:
-    has_selenium = False
+    webdriver = None
 
 
 # The timeout (minutes) to check for the port.
 # The timeout (minutes) to check for the port.
 DEFAULT_TIMEOUT = 15
 DEFAULT_TIMEOUT = 15
@@ -293,8 +291,12 @@ class AppHarness:
                 if p not in before_decorated_pages
                 if p not in before_decorated_pages
             ]
             ]
         self.app_instance = self.app_module.app
         self.app_instance = self.app_module.app
+        if self.app_instance is None:
+            raise RuntimeError("App was not initialized.")
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
             # Create our own redis connection for testing.
+            if self.app_instance.state is None:
+                raise RuntimeError("App state is not initialized.")
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
         else:
         else:
             self.state_manager = self.app_instance._state_manager
             self.state_manager = self.app_instance._state_manager
@@ -608,7 +610,7 @@ class AppHarness:
         Raises:
         Raises:
             RuntimeError: when selenium is not importable or frontend is not running
             RuntimeError: when selenium is not importable or frontend is not running
         """
         """
-        if not has_selenium:
+        if webdriver is None:
             raise RuntimeError(
             raise RuntimeError(
                 "Frontend functionality requires `selenium` to be installed, "
                 "Frontend functionality requires `selenium` to be installed, "
                 "and it could not be imported."
                 "and it could not be imported."

+ 6 - 3
reflex/utils/console.py

@@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
     # Exclude utility modules that should never be the source of deprecated reflex usage.
     # Exclude utility modules that should never be the source of deprecated reflex usage.
     exclude_modules = [click, rx, typer, typing_extensions]
     exclude_modules = [click, rx, typer, typing_extensions]
     exclude_roots = [
     exclude_roots = [
-        p.parent.resolve()
-        if (p := Path(m.__file__)).name == "__init__.py"
-        else p.resolve()
+        (
+            p.parent.resolve()
+            if (p := Path(m.__file__)).name == "__init__.py"
+            else p.resolve()
+        )
         for m in exclude_modules
         for m in exclude_modules
+        if m.__file__
     ]
     ]
     # Specifically exclude the reflex cli module.
     # Specifically exclude the reflex cli module.
     if reflex_bin := shutil.which(b"reflex"):
     if reflex_bin := shutil.which(b"reflex"):

+ 6 - 3
reflex/vars/base.py

@@ -3197,17 +3197,20 @@ class Field(Generic[T]):
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
 
 
     @overload
     @overload
-    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+    def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...
 
 
     @overload
     @overload
-    def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
+    def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
+
+    @overload
+    def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...
 
 
     @overload
     @overload
     def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
     def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[Sequence[V]] | Field[Set[V]],
+        self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
         instance: None,
         instance: None,
         owner,
         owner,
     ) -> ArrayVar[Sequence[V]]: ...
     ) -> ArrayVar[Sequence[V]]: ...

+ 7 - 19
reflex/vars/number.py

@@ -1069,24 +1069,11 @@ def ternary_operation(
     return value
     return value
 
 
 
 
-TUPLE_ENDS_IN_VAR = (
-    tuple[Var[VAR_TYPE]]
-    | tuple[Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[
-        Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
-    ]
-)
+X = tuple[*tuple[Var, ...], str]
+
+TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
+
+TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]
 
 
 
 
 @dataclasses.dataclass(
 @dataclasses.dataclass(
@@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     def create(
     def create(
         cls,
         cls,
         cond: Any,
         cond: Any,
-        cases: Sequence[Sequence[Any | Var[VAR_TYPE]]],
+        cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
         default: Var[VAR_TYPE] | VAR_TYPE,
         default: Var[VAR_TYPE] | VAR_TYPE,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
         _var_type: type[VAR_TYPE] | None = None,
         _var_type: type[VAR_TYPE] | None = None,
@@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
             tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
             tuple(tuple(Var.create(c) for c in case) for case in cases),
             tuple(tuple(Var.create(c) for c in case) for case in cases),
         )
         )
+
         _default = cast(Var[VAR_TYPE], Var.create(default))
         _default = cast(Var[VAR_TYPE], Var.create(default))
         var_type = _var_type or unionize(
         var_type = _var_type or unionize(
             *(case[-1]._var_type for case in cases),
             *(case[-1]._var_type for case in cases),

+ 3 - 2
reflex/vars/object.py

@@ -45,7 +45,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar
 from .sequence import ArrayVar, StringVar
 
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
 
 
 KEY_TYPE = TypeVar("KEY_TYPE")
 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
 
     @overload
     @overload
     def __getitem__(
     def __getitem__(
-        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
+        | ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
         key: Var | Any,
         key: Var | Any,
     ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
     ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
 
 

+ 2 - 0
tests/integration/test_event_actions.py

@@ -28,9 +28,11 @@ def TestEventAction():
         def on_click2(self):
         def on_click2(self):
             self.order.append("on_click2")
             self.order.append("on_click2")
 
 
+        @rx.event
         def on_click_throttle(self):
         def on_click_throttle(self):
             self.order.append("on_click_throttle")
             self.order.append("on_click_throttle")
 
 
+        @rx.event
         def on_click_debounce(self):
         def on_click_debounce(self):
             self.order.append("on_click_debounce")
             self.order.append("on_click_debounce")
 
 

+ 4 - 4
tests/integration/test_lifespan.py

@@ -22,9 +22,9 @@ def LifespanApp():
 
 
     @asynccontextmanager
     @asynccontextmanager
     async def lifespan_context(app, inc: int = 1):
     async def lifespan_context(app, inc: int = 1):
-        global lifespan_context_global
+        nonlocal lifespan_context_global
         print(f"Lifespan context entered: {app}.")
         print(f"Lifespan context entered: {app}.")
-        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
+        lifespan_context_global += inc
         try:
         try:
             yield
             yield
         finally:
         finally:
@@ -32,11 +32,11 @@ def LifespanApp():
             lifespan_context_global += inc
             lifespan_context_global += inc
 
 
     async def lifespan_task(inc: int = 1):
     async def lifespan_task(inc: int = 1):
-        global lifespan_task_global
+        nonlocal lifespan_task_global
         print("Lifespan global started.")
         print("Lifespan global started.")
         try:
         try:
             while True:
             while True:
-                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
+                lifespan_task_global += inc
                 await asyncio.sleep(0.1)
                 await asyncio.sleep(0.1)
         except asyncio.CancelledError as ce:
         except asyncio.CancelledError as ce:
             print(f"Lifespan global cancelled: {ce}.")
             print(f"Lifespan global cancelled: {ce}.")

+ 19 - 17
tests/integration/test_var_operations.py

@@ -19,25 +19,27 @@ def VarOperations():
     from reflex.vars.sequence import ArrayVar
     from reflex.vars.sequence import ArrayVar
 
 
     class Object(rx.Base):
     class Object(rx.Base):
-        str: str = "hello"
+        name: str = "hello"
 
 
     class VarOperationState(rx.State):
     class VarOperationState(rx.State):
-        int_var1: int = 10
-        int_var2: int = 5
-        int_var3: int = 7
-        float_var1: float = 10.5
-        float_var2: float = 5.5
-        list1: List = [1, 2]
-        list2: List = [3, 4]
-        list3: List = ["first", "second", "third"]
-        list4: List = [Object(name="obj_1"), Object(name="obj_2")]
-        str_var1: str = "first"
-        str_var2: str = "second"
-        str_var3: str = "ThIrD"
-        str_var4: str = "a long string"
-        dict1: Dict[int, int] = {1: 2}
-        dict2: Dict[int, int] = {3: 4}
-        html_str: str = "<div>hello</div>"
+        int_var1: rx.Field[int] = rx.field(10)
+        int_var2: rx.Field[int] = rx.field(5)
+        int_var3: rx.Field[int] = rx.field(7)
+        float_var1: rx.Field[float] = rx.field(10.5)
+        float_var2: rx.Field[float] = rx.field(5.5)
+        list1: rx.Field[List[int]] = rx.field([1, 2])
+        list2: rx.Field[List[int]] = rx.field([3, 4])
+        list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
+        list4: rx.Field[List[Object]] = rx.field(
+            [Object(name="obj_1"), Object(name="obj_2")]
+        )
+        str_var1: rx.Field[str] = rx.field("first")
+        str_var2: rx.Field[str] = rx.field("second")
+        str_var3: rx.Field[str] = rx.field("ThIrD")
+        str_var4: rx.Field[str] = rx.field("a long string")
+        dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
+        dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
+        html_str: rx.Field[str] = rx.field("<div>hello</div>")
 
 
     app = rx.App(state=rx.State)
     app = rx.App(state=rx.State)
 
 

+ 1 - 1
tests/units/components/core/test_cond.py

@@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
 @pytest.fixture
 @pytest.fixture
 def cond_state(request):
 def cond_state(request):
     class CondState(BaseState):
     class CondState(BaseState):
-        value: request.param["value_type"] = request.param["value"]  # noqa
+        value: request.param["value_type"] = request.param["value"]  # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable]  # noqa: F821
 
 
     return CondState
     return CondState
 
 

+ 6 - 6
tests/units/components/core/test_match.py

@@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
         cases: The match cases.
         cases: The match cases.
         expected: The expected var full name.
         expected: The expected var full name.
     """
     """
-    match_comp = Match.create(MatchState.value, *cases)
+    match_comp = Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]
     assert isinstance(match_comp, Var)
     assert isinstance(match_comp, Var)
     assert str(match_comp) == expected
     assert str(match_comp) == expected
 
 
@@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
         ValueError,
         ValueError,
         match="rx.match should have tuples of cases and a default case as the last argument.",
         match="rx.match should have tuples of cases and a default case as the last argument.",
     ):
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
         ValueError,
         ValueError,
         match="A case tuple should have at least a match case element and a return value.",
         match="A case tuple should have at least a match case element and a return value.",
     ):
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
         error_msg: Expected error message.
         error_msg: Expected error message.
     """
     """
     with pytest.raises(MatchTypeError, match=error_msg):
     with pytest.raises(MatchTypeError, match=error_msg):
-        Match.create(MatchState.value, *cases)
+        Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
         match_case: the cases to match.
         match_case: the cases to match.
     """
     """
     with pytest.raises(ValueError, match="rx.match can only have one default case."):
     with pytest.raises(ValueError, match="rx.match can only have one default case."):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 
 
 def test_match_no_cond():
 def test_match_no_cond():
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
-        _ = Match.create(None)
+        _ = Match.create(None)  # pyright: ignore[reportCallIssue]

+ 4 - 2
tests/units/components/datadisplay/test_datatable.py

@@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
         pytest.param(
         pytest.param(
             {
             {
                 "data": pd.DataFrame(
                 "data": pd.DataFrame(
-                    [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+                    [["foo", "bar"], ["foo1", "bar1"]],
+                    columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
                 )
                 )
             },
             },
             "data",
             "data",
@@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
 def test_serialize_dataframe():
 def test_serialize_dataframe():
     """Test if dataframe is serialized correctly."""
     """Test if dataframe is serialized correctly."""
     df = pd.DataFrame(
     df = pd.DataFrame(
-        [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+        [["foo", "bar"], ["foo1", "bar1"]],
+        columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
     )
     )
     value = serialize(df)
     value = serialize(df)
     assert value == serialize_dataframe(df)
     assert value == serialize_dataframe(df)

+ 5 - 5
tests/units/test_app.py

@@ -9,7 +9,7 @@ import unittest.mock
 import uuid
 import uuid
 from contextlib import nullcontext as does_not_raise
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
 from pathlib import Path
-from typing import Generator, List, Tuple, Type
+from typing import Generator, List, Tuple, Type, cast
 from unittest.mock import AsyncMock
 from unittest.mock import AsyncMock
 
 
 import pytest
 import pytest
@@ -33,7 +33,7 @@ from reflex.components import Component
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.core.cond import Cond
 from reflex.components.core.cond import Cond
 from reflex.components.radix.themes.typography.text import Text
 from reflex.components.radix.themes.typography.text import Text
-from reflex.event import Event
+from reflex.event import Event, EventHandler
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
 from reflex.state import (
 from reflex.state import (
@@ -917,7 +917,7 @@ class DynamicState(BaseState):
         """
         """
         return self.dynamic
         return self.dynamic
 
 
-    on_load_internal = OnLoadInternalState.on_load_internal.fn
+    on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
 
 
 
 
 def test_dynamic_arg_shadow(
 def test_dynamic_arg_shadow(
@@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
         pass
         pass
 
 
     assert (await app.state_manager.get_state(event.substate_token)).value == 5
     assert (await app.state_manager.get_state(event.substate_token)).value == 5
-    assert app._postprocess.call_count == 6
+    assert getattr(app._postprocess, "call_count", None) == 6
 
 
     if isinstance(app.state_manager, StateManagerRedis):
     if isinstance(app.state_manager, StateManagerRedis):
         await app.state_manager.close()
         await app.state_manager.close()
@@ -1247,7 +1247,7 @@ def test_overlay_component(
 
 
     if exp_page_child is not None:
     if exp_page_child is not None:
         assert len(page.children) == 3
         assert len(page.children) == 3
-        children_types = (type(child) for child in page.children)
+        children_types = [type(child) for child in page.children]
         assert exp_page_child in children_types
         assert exp_page_child in children_types
     else:
     else:
         assert len(page.children) == 2
         assert len(page.children) == 2

+ 3 - 0
tests/units/test_event.py

@@ -5,6 +5,7 @@ import pytest
 import reflex as rx
 import reflex as rx
 from reflex.event import (
 from reflex.event import (
     Event,
     Event,
+    EventActionsMixin,
     EventChain,
     EventChain,
     EventHandler,
     EventHandler,
     EventSpec,
     EventSpec,
@@ -410,6 +411,7 @@ def test_event_actions():
 
 
 def test_event_actions_on_state():
 def test_event_actions_on_state():
     class EventActionState(BaseState):
     class EventActionState(BaseState):
+        @rx.event
         def handler(self):
         def handler(self):
             pass
             pass
 
 
@@ -418,6 +420,7 @@ def test_event_actions_on_state():
     assert not handler.event_actions
     assert not handler.event_actions
 
 
     sp_handler = EventActionState.handler.stop_propagation
     sp_handler = EventActionState.handler.stop_propagation
+    assert isinstance(sp_handler, EventActionsMixin)
     assert sp_handler.event_actions == {"stopPropagation": True}
     assert sp_handler.event_actions == {"stopPropagation": True}
     # should NOT affect other references to the handler
     # should NOT affect other references to the handler
     assert not handler.event_actions
     assert not handler.event_actions

+ 5 - 2
tests/units/test_health_endpoint.py

@@ -122,9 +122,12 @@ async def test_health(
     # Call the async health function
     # Call the async health function
     response = await health()
     response = await health()
 
 
-    print(json.loads(response.body))
+    body = response.body
+    assert isinstance(body, bytes)
+
+    print(json.loads(body))
     print(expected_status)
     print(expected_status)
 
 
     # Verify the response content and status code
     # Verify the response content and status code
     assert response.status_code == expected_code
     assert response.status_code == expected_code
-    assert json.loads(response.body) == expected_status
+    assert json.loads(body) == expected_status

+ 4 - 4
tests/units/test_sqlalchemy.py

@@ -59,7 +59,7 @@ def test_automigration(
         id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
         id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
 
 
     # initial table
     # initial table
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[str] = mapped_column(default="")
         t1: Mapped[str] = mapped_column(default="")
 
 
     with Model.get_db_engine().connect() as connection:
     with Model.get_db_engine().connect() as connection:
@@ -78,7 +78,7 @@ def test_automigration(
     model_registry.get_metadata().clear()
     model_registry.get_metadata().clear()
 
 
     # Create column t2, mark t1 as optional with default
     # Create column t2, mark t1 as optional with default
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[Optional[str]] = mapped_column(default="default")
         t1: Mapped[Optional[str]] = mapped_column(default="default")
         t2: Mapped[str] = mapped_column(default="bar")
         t2: Mapped[str] = mapped_column(default="bar")
 
 
@@ -98,7 +98,7 @@ def test_automigration(
     model_registry.get_metadata().clear()
     model_registry.get_metadata().clear()
 
 
     # Drop column t1
     # Drop column t1
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")
         t2: Mapped[str] = mapped_column(default="bar")
 
 
     assert Model.migrate(autogenerate=True)
     assert Model.migrate(autogenerate=True)
@@ -133,7 +133,7 @@ def test_automigration(
     # drop table (AlembicSecond)
     # drop table (AlembicSecond)
     model_registry.get_metadata().clear()
     model_registry.get_metadata().clear()
 
 
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")
         t2: Mapped[str] = mapped_column(default="bar")
 
 
     assert Model.migrate(autogenerate=True)
     assert Model.migrate(autogenerate=True)

+ 20 - 11
tests/units/test_state.py

@@ -17,6 +17,7 @@ from typing import (
     Dict,
     Dict,
     List,
     List,
     Optional,
     Optional,
+    Sequence,
     Set,
     Set,
     Tuple,
     Tuple,
     Union,
     Union,
@@ -120,8 +121,8 @@ class TestState(BaseState):
     num2: float = 3.14
     num2: float = 3.14
     key: str
     key: str
     map_key: str = "a"
     map_key: str = "a"
-    array: List[float] = [1, 2, 3.14]
-    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
+    mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
     obj: Object = Object()
     obj: Object = Object()
     complex: Dict[int, Object] = {1: Object(), 2: Object()}
     complex: Dict[int, Object] = {1: Object(), 2: Object()}
     fig: Figure = Figure()
     fig: Figure = Figure()
@@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     class HandlerState(BaseState):
     class HandlerState(BaseState):
         x: int = 42
         x: int = 42
 
 
+        @rx.event
         def handler(self):
         def handler(self):
             self.x = self.x + 1
             self.x = self.x + 1
 
 
@@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
             counter += 1
             counter += 1
             return counter
             return counter
 
 
+    assert isinstance(HandlerState.handler, EventHandler)
     if use_partial:
     if use_partial:
-        HandlerState.handler = functools.partial(HandlerState.handler.fn)
+        partial_guy = functools.partial(HandlerState.handler.fn)
+        HandlerState.handler = partial_guy  # pyright: ignore[reportAttributeAccessIssue]
         assert isinstance(HandlerState.handler, functools.partial)
         assert isinstance(HandlerState.handler, functools.partial)
-    else:
-        assert isinstance(HandlerState.handler, EventHandler)
 
 
     s = HandlerState()
     s = HandlerState()
     assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
     assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
 
 
     # ensure state update was emitted
     # ensure state update was emitted
     assert mock_app.event_namespace is not None
     assert mock_app.event_namespace is not None
-    mock_app.event_namespace.emit.assert_called_once()
-    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    mock_app.event_namespace.emit.assert_called_once()  # pyright: ignore[reportFunctionMemberAccess]
+    mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+    mcall = mock_calls[0]
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
     assert mcall.args[1] == StateUpdate(
         delta={
         delta={
@@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     assert mock_app.event_namespace is not None
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
     emit_mock = mock_app.event_namespace.emit
 
 
-    first_ws_message = emit_mock.mock_calls[0].args[1]
+    mock_calls = getattr(emit_mock, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+
+    first_ws_message = mock_calls[0].args[1]
     assert (
     assert (
         first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
         first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
         is not None
         is not None
@@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         events=[],
         events=[],
         final=True,
         final=True,
     )
     )
-    for call in emit_mock.mock_calls[1:5]:
+    for call in mock_calls[1:5]:
         assert call.args[1] == StateUpdate(
         assert call.args[1] == StateUpdate(
             delta={
             delta={
                 BackgroundTaskState.get_full_name(): {
                 BackgroundTaskState.get_full_name(): {
@@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
             events=[],
             events=[],
             final=True,
             final=True,
         )
         )
-    assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+    assert mock_calls[-2].args[1] == StateUpdate(
         delta={
         delta={
             BackgroundTaskState.get_full_name(): {
             BackgroundTaskState.get_full_name(): {
                 "order": exp_order,
                 "order": exp_order,
@@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         events=[],
         events=[],
         final=True,
         final=True,
     )
     )
-    assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+    assert mock_calls[-1].args[1] == StateUpdate(
         delta={
         delta={
             BackgroundTaskState.get_full_name(): {
             BackgroundTaskState.get_full_name(): {
                 "computed_order": exp_order,
                 "computed_order": exp_order,