Browse Source

down to only two pyright error

Khaleel Al-Adhami 4 months ago
parent
commit
57d8ea02e9

+ 4 - 4
poetry.lock

@@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
 
 [[package]]
 name = "virtualenv"
-version = "20.28.1"
+version = "20.29.1"
 description = "Virtual Python Environment builder"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
-    {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
+    {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
+    {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
 ]
 
 [package.dependencies]
@@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee"
+content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"

+ 2 - 1
pyproject.toml

@@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
 gunicorn = ">=20.1.0,<24.0"
 jinja2 = ">=3.1.2,<4.0"
 psutil = ">=5.9.4,<7.0"
-pydantic = ">=1.10.2,<3.0"
+pydantic = ">=1.10.15,<3.0"
 python-multipart = ">=0.0.5,<0.1"
 python-socketio = ">=5.7.0,<6.0"
 redis = ">=4.3.5,<6.0"
@@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.pyright]
 reportIncompatibleMethodOverride = false
+reportIncompatibleVariableOverride = false
 
 [tool.ruff]
 target-version = "py39"

+ 3 - 9
reflex/base.py

@@ -5,15 +5,9 @@ from __future__ import annotations
 import os
 from typing import TYPE_CHECKING, Any, List, Type
 
-try:
-    import pydantic.v1.main as pydantic_main
-    from pydantic.v1 import BaseModel
-    from pydantic.v1.fields import ModelField
-except ModuleNotFoundError:
-    if not TYPE_CHECKING:
-        import pydantic.main as pydantic_main
-        from pydantic import BaseModel
-        from pydantic.fields import ModelField  # type: ignore
+import pydantic.v1.main as pydantic_main
+from pydantic.v1 import BaseModel
+from pydantic.v1.fields import ModelField
 
 
 def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:

+ 1 - 5
reflex/components/core/cond.py

@@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):
 
 
 @overload
-def cond(condition: Any, c1: Component, c2: Any) -> Component: ...
-
-
-@overload
-def cond(condition: Any, c1: Component) -> Component: ...
+def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...
 
 
 @overload

+ 40 - 42
reflex/components/core/match.py

@@ -1,15 +1,17 @@
 """rx.match."""
 
 import textwrap
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, cast
 
 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.utils import types
 from reflex.utils.exceptions import MatchTypeError
-from reflex.vars.base import Var
+from reflex.vars.base import VAR_TYPE, Var
 from reflex.vars.number import MatchOperation
 
+CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
+
 
 class Match(MemoizationLeaf):
     """Match cases based on a condition."""
@@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
     default: Any
 
     @classmethod
-    def create(cls, cond: Any, *cases) -> Union[Component, Var]:
+    def create(
+        cls,
+        cond: Any,
+        *cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create a Match Component.
 
         Args:
@@ -37,10 +43,24 @@ class Match(MemoizationLeaf):
         Raises:
             ValueError: When a default case is not provided for cases with Var return types.
         """
-        cases, default = cls._process_cases(cases)
-        cls._process_match_cases(cases)
+        default = None
 
-        cls._validate_return_types(cases)
+        if len([case for case in cases if not isinstance(case, tuple)]) > 1:
+            raise ValueError("rx.match can only have one default case.")
+
+        if not cases:
+            raise ValueError("rx.match should have at least one case.")
+
+        # Get the default case which should be the last non-tuple arg
+        if not isinstance(cases[-1], tuple):
+            default = cases[-1]
+            actual_cases = cases[:-1]
+        else:
+            actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
+
+        cls._process_match_cases(actual_cases)
+
+        cls._validate_return_types(actual_cases)
 
         if default is None and any(
             not (
@@ -50,7 +70,7 @@ class Match(MemoizationLeaf):
                     and types.typehint_issubclass(return_type._var_type, Component)
                 )
             )
-            for case in cases
+            for case in actual_cases
         ):
             raise ValueError(
                 "For cases with return types as Vars, a default case must be provided"
@@ -58,40 +78,16 @@ class Match(MemoizationLeaf):
         elif default is None:
             default = Fragment.create()
 
-        return cls._create_match_cond_var_or_component(cond, cases, default)
-
-    @classmethod
-    def _process_cases(
-        cls, cases: Sequence
-    ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
-        """Process the list of match cases and the catchall default case.
-
-        Args:
-            cases: The list of match cases.
-
-        Returns:
-            The default case and the list of match case tuples.
-
-        Raises:
-            ValueError: If there are multiple default cases.
-        """
-        default = None
-
-        if len([case for case in cases if not isinstance(case, tuple)]) > 1:
-            raise ValueError("rx.match can only have one default case.")
-
-        if not cases:
-            raise ValueError("rx.match should have at least one case.")
-
-        # Get the default case which should be the last non-tuple arg
-        if not isinstance(cases[-1], tuple):
-            default = cases[-1]
-            cases = cases[:-1]
+        default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
 
-        return list(cases), default
+        return cls._create_match_cond_var_or_component(
+            cond,
+            actual_cases,
+            default,
+        )
 
     @classmethod
-    def _process_match_cases(cls, cases: Sequence):
+    def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
         """Process the individual match cases.
 
         Args:
@@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
                 )
 
     @classmethod
-    def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
+    def _validate_return_types(
+        cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
+    ) -> None:
         """Validate that match cases have the same return types.
 
         Args:
@@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
     def _create_match_cond_var_or_component(
         cls,
         match_cond_var: Var,
-        match_cases: List[List[Var]],
-        default: Union[Var, BaseComponent],
-    ) -> Union[Component, Var]:
+        match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
+        default: VAR_TYPE | Var[VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create and return the match condition var or component.
 
         Args:

+ 1 - 1
reflex/components/datadisplay/dataeditor.py

@@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
 
     # Fired when editing is finished.
     on_finished_editing: EventHandler[
-        passthrough_event_spec(Union[GridCell, None], tuple[int, int])
+        passthrough_event_spec(Union[GridCell, None], tuple[int, int])  # pyright: ignore[reportArgumentType]
     ]
 
     # Fired when a row is appended.

+ 1 - 1
reflex/components/radix/primitives/accordion.py

@@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
     # The header of the accordion item.
     header: Var[Union[Component, str]]
     # The content of the accordion item.
-    content: Var[Union[Component, str]] = Var.create(None)
+    content: Var[Union[Component, str]] = Var.create("")
 
     _valid_children: List[str] = [
         "AccordionHeader",

+ 5 - 4
reflex/components/tags/iter_tag.py

@@ -4,9 +4,10 @@ from __future__ import annotations
 
 import dataclasses
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
 
 from reflex.components.tags.tag import Tag
+from reflex.utils import types
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
 
 if TYPE_CHECKING:
@@ -31,7 +32,7 @@ class IterTag(Tag):
     # The name of the index var.
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
-    def get_iterable_var_type(self) -> Type:
+    def get_iterable_var_type(self) -> types.GenericType:
         """Get the type of the iterable var.
 
         Returns:
@@ -41,10 +42,10 @@ class IterTag(Tag):
         try:
             if iterable._var_type.mro()[0] is dict:
                 # Arg is a tuple of (key, value).
-                return Tuple[get_args(iterable._var_type)]  # type: ignore
+                return Tuple[get_args(iterable._var_type)]
             elif iterable._var_type.mro()[0] is tuple:
                 # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]  # type: ignore
+                return Union[get_args(iterable._var_type)]
             else:
                 return get_args(iterable._var_type)[0]
         except Exception:

+ 52 - 14
reflex/event.py

@@ -25,7 +25,6 @@ from typing import (
     overload,
 )
 
-import typing_extensions
 from typing_extensions import (
     Concatenate,
     ParamSpec,
@@ -33,6 +32,8 @@ from typing_extensions import (
     TypeAliasType,
     TypedDict,
     TypeVar,
+    TypeVarTuple,
+    deprecated,
     get_args,
     get_origin,
 )
@@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
 prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
 
 
-T = TypeVar("T")
-U = TypeVar("U")
+EVENT_T = TypeVar("EVENT_T")
+EVENT_U = TypeVar("EVENT_U")
+
+Ts = TypeVarTuple("Ts")
 
 
-class IdentityEventReturn(Generic[T], Protocol):
+class IdentityEventReturn(Generic[*Ts], Protocol):
     """Protocol for an identity event return."""
 
-    def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
+    def __call__(self, *values: *Ts) -> tuple[*Ts]:
         """Return the input values.
 
         Args:
@@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):
 
 @overload
 def passthrough_event_spec(
-    event_type: Type[T], /
-) -> Callable[[Var[T]], Tuple[Var[T]]]: ...  # type: ignore
+    event_type: Type[EVENT_T], /
+) -> IdentityEventReturn[Var[EVENT_T]]: ...
 
 
 @overload
 def passthrough_event_spec(
-    event_type_1: Type[T], event_type2: Type[U], /
-) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
+    event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
+) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
 
 
 @overload
-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
+def passthrough_event_spec(
+    *event_types: *tuple[Type[EVENT_T]],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...
 
 
-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]:  # type: ignore
+def passthrough_event_spec(  # pyright: ignore[reportInconsistentOverload]
+    *event_types: Type[EVENT_T],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
     """A helper function that returns the input event as output.
 
     Args:
@@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]:  #
         A function that returns the input event as output.
     """
 
-    def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
+    def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
         return values
 
     inner_type = tuple(Var[event_type] for event_type in event_types)
@@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
         return None
 
     fn.__qualname__ = name
-    fn.__signature__ = sig
+    fn.__signature__ = sig  # pyright: ignore[reportFunctionMemberAccess]
     return EventSpec(
         handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         args=tuple(
@@ -822,7 +829,7 @@ def redirect(
 
 
 @overload
-@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
+@deprecated("`external` is deprecated use `is_external` instead")
 def redirect(
     path: str | Var[str],
     is_external: Optional[bool] = None,
@@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
         """
         self.func = func
 
+    def throttle(self, limit_ms: int):
+        """Throttle the event handler.
+
+        Args:
+            limit_ms: The time in milliseconds to throttle the event handler.
+
+        Returns:
+            New EventHandler-like with throttle set to limit_ms.
+        """
+        return self
+
+    def debounce(self, delay_ms: int):
+        """Debounce the event handler.
+
+        Args:
+            delay_ms: The time in milliseconds to debounce the event handler.
+
+        Returns:
+            New EventHandler-like with debounce set to delay_ms.
+        """
+        return self
+
+    @property
+    def temporal(self):
+        """Do not queue the event if the backend is down.
+
+        Returns:
+            New EventHandler-like with temporal set to True.
+        """
+        return self
+
     @property
     def prevent_default(self):
         """Prevent default behavior.

+ 9 - 11
reflex/state.py

@@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if cls._item_is_event_handler(name, fn)
         }
 
-        for mixin in cls._mixins():
-            for name, value in mixin.__dict__.items():
+        for mixin_class in cls._mixins():
+            for name, value in mixin_class.__dict__.items():
                 if name in cls.inherited_vars:
                     continue
                 if is_computed_var(value):
@@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._js_expr] = newcv
                     cls.vars[newcv._js_expr] = newcv
                     continue
-                if types.is_backend_base_variable(name, mixin):
+                if types.is_backend_base_variable(name, mixin_class):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                 if events.get(name) is not None:
@@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
     return app.state_manager
 
 
+DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
+
+
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""
 
@@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
                 cls.__dataclass_proxies__[wrapper_cls_name] = type(
                     wrapper_cls_name,
                     (cls,),
-                    {
-                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportGeneralTypeIssues]
-                            wrapped_cls,
-                            dataclasses._FIELDS,  # pyright: ignore [reportGeneralTypeIssues]
-                        ),
-                    },
+                    {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
                 )
             cls = cls.__dataclass_proxies__[wrapper_cls_name]
         return super().__new__(cls)
@@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
             if (
                 isinstance(self.__wrapped__, Base)
                 and __name not in self.__never_wrap_base_attrs__
-                and hasattr(value, "__func__")
+                and (value_func := getattr(value, "__func__", None))
             ):
                 # Wrap methods called on Base subclasses, which might do _anything_
                 return wrapt.FunctionWrapper(
-                    functools.partial(value.__func__, self),
+                    functools.partial(value_func, self),
                     self._wrap_recursive_decorator,
                 )
 

+ 6 - 4
reflex/testing.py

@@ -67,10 +67,8 @@ try:
         from selenium.webdriver.remote.webelement import (  # pyright: ignore [reportMissingImports]
             WebElement,
         )
-
-    has_selenium = True
 except ImportError:
-    has_selenium = False
+    webdriver = None
 
 # The timeout (minutes) to check for the port.
 DEFAULT_TIMEOUT = 15
@@ -293,8 +291,12 @@ class AppHarness:
                 if p not in before_decorated_pages
             ]
         self.app_instance = self.app_module.app
+        if self.app_instance is None:
+            raise RuntimeError("App was not initialized.")
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
+            if self.app_instance.state is None:
+                raise RuntimeError("App state is not initialized.")
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
         else:
             self.state_manager = self.app_instance._state_manager
@@ -608,7 +610,7 @@ class AppHarness:
         Raises:
             RuntimeError: when selenium is not importable or frontend is not running
         """
-        if not has_selenium:
+        if webdriver is None:
             raise RuntimeError(
                 "Frontend functionality requires `selenium` to be installed, "
                 "and it could not be imported."

+ 6 - 3
reflex/utils/console.py

@@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
     # Exclude utility modules that should never be the source of deprecated reflex usage.
     exclude_modules = [click, rx, typer, typing_extensions]
     exclude_roots = [
-        p.parent.resolve()
-        if (p := Path(m.__file__)).name == "__init__.py"
-        else p.resolve()
+        (
+            p.parent.resolve()
+            if (p := Path(m.__file__)).name == "__init__.py"
+            else p.resolve()
+        )
         for m in exclude_modules
+        if m.__file__
     ]
     # Specifically exclude the reflex cli module.
     if reflex_bin := shutil.which(b"reflex"):

+ 6 - 3
reflex/vars/base.py

@@ -3197,17 +3197,20 @@ class Field(Generic[T]):
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
 
     @overload
-    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+    def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...
 
     @overload
-    def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
+    def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
+
+    @overload
+    def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...
 
     @overload
     def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
 
     @overload
     def __get__(
-        self: Field[Sequence[V]] | Field[Set[V]],
+        self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
         instance: None,
         owner,
     ) -> ArrayVar[Sequence[V]]: ...

+ 7 - 19
reflex/vars/number.py

@@ -1069,24 +1069,11 @@ def ternary_operation(
     return value
 
 
-TUPLE_ENDS_IN_VAR = (
-    tuple[Var[VAR_TYPE]]
-    | tuple[Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[
-        Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
-    ]
-)
+X = tuple[*tuple[Var, ...], str]
+
+TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
+
+TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]
 
 
 @dataclasses.dataclass(
@@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     def create(
         cls,
         cond: Any,
-        cases: Sequence[Sequence[Any | Var[VAR_TYPE]]],
+        cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
         default: Var[VAR_TYPE] | VAR_TYPE,
         _var_data: VarData | None = None,
         _var_type: type[VAR_TYPE] | None = None,
@@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
             tuple(tuple(Var.create(c) for c in case) for case in cases),
         )
+
         _default = cast(Var[VAR_TYPE], Var.create(default))
         var_type = _var_type or unionize(
             *(case[-1]._var_type for case in cases),

+ 3 - 2
reflex/vars/object.py

@@ -45,7 +45,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
 
 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
+        | ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
 

+ 2 - 0
tests/integration/test_event_actions.py

@@ -28,9 +28,11 @@ def TestEventAction():
         def on_click2(self):
             self.order.append("on_click2")
 
+        @rx.event
         def on_click_throttle(self):
             self.order.append("on_click_throttle")
 
+        @rx.event
         def on_click_debounce(self):
             self.order.append("on_click_debounce")
 

+ 4 - 4
tests/integration/test_lifespan.py

@@ -22,9 +22,9 @@ def LifespanApp():
 
     @asynccontextmanager
     async def lifespan_context(app, inc: int = 1):
-        global lifespan_context_global
+        nonlocal lifespan_context_global
         print(f"Lifespan context entered: {app}.")
-        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
+        lifespan_context_global += inc
         try:
             yield
         finally:
@@ -32,11 +32,11 @@ def LifespanApp():
             lifespan_context_global += inc
 
     async def lifespan_task(inc: int = 1):
-        global lifespan_task_global
+        nonlocal lifespan_task_global
         print("Lifespan global started.")
         try:
             while True:
-                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
+                lifespan_task_global += inc
                 await asyncio.sleep(0.1)
         except asyncio.CancelledError as ce:
             print(f"Lifespan global cancelled: {ce}.")

+ 19 - 17
tests/integration/test_var_operations.py

@@ -19,25 +19,27 @@ def VarOperations():
     from reflex.vars.sequence import ArrayVar
 
     class Object(rx.Base):
-        str: str = "hello"
+        name: str = "hello"
 
     class VarOperationState(rx.State):
-        int_var1: int = 10
-        int_var2: int = 5
-        int_var3: int = 7
-        float_var1: float = 10.5
-        float_var2: float = 5.5
-        list1: List = [1, 2]
-        list2: List = [3, 4]
-        list3: List = ["first", "second", "third"]
-        list4: List = [Object(name="obj_1"), Object(name="obj_2")]
-        str_var1: str = "first"
-        str_var2: str = "second"
-        str_var3: str = "ThIrD"
-        str_var4: str = "a long string"
-        dict1: Dict[int, int] = {1: 2}
-        dict2: Dict[int, int] = {3: 4}
-        html_str: str = "<div>hello</div>"
+        int_var1: rx.Field[int] = rx.field(10)
+        int_var2: rx.Field[int] = rx.field(5)
+        int_var3: rx.Field[int] = rx.field(7)
+        float_var1: rx.Field[float] = rx.field(10.5)
+        float_var2: rx.Field[float] = rx.field(5.5)
+        list1: rx.Field[List[int]] = rx.field([1, 2])
+        list2: rx.Field[List[int]] = rx.field([3, 4])
+        list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
+        list4: rx.Field[List[Object]] = rx.field(
+            [Object(name="obj_1"), Object(name="obj_2")]
+        )
+        str_var1: rx.Field[str] = rx.field("first")
+        str_var2: rx.Field[str] = rx.field("second")
+        str_var3: rx.Field[str] = rx.field("ThIrD")
+        str_var4: rx.Field[str] = rx.field("a long string")
+        dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
+        dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
+        html_str: rx.Field[str] = rx.field("<div>hello</div>")
 
     app = rx.App(state=rx.State)
 

+ 1 - 1
tests/units/components/core/test_cond.py

@@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
 @pytest.fixture
 def cond_state(request):
     class CondState(BaseState):
-        value: request.param["value_type"] = request.param["value"]  # noqa
+        value: request.param["value_type"] = request.param["value"]  # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable]  # noqa: F821
 
     return CondState
 

+ 6 - 6
tests/units/components/core/test_match.py

@@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
         cases: The match cases.
         expected: The expected var full name.
     """
-    match_comp = Match.create(MatchState.value, *cases)
+    match_comp = Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]
     assert isinstance(match_comp, Var)
     assert str(match_comp) == expected
 
@@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
         ValueError,
         match="rx.match should have tuples of cases and a default case as the last argument.",
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 @pytest.mark.parametrize(
@@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
         ValueError,
         match="A case tuple should have at least a match case element and a return value.",
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 @pytest.mark.parametrize(
@@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
         error_msg: Expected error message.
     """
     with pytest.raises(MatchTypeError, match=error_msg):
-        Match.create(MatchState.value, *cases)
+        Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]
 
 
 @pytest.mark.parametrize(
@@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
         match_case: the cases to match.
     """
     with pytest.raises(ValueError, match="rx.match can only have one default case."):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]
 
 
 def test_match_no_cond():
     with pytest.raises(ValueError):
-        _ = Match.create(None)
+        _ = Match.create(None)  # pyright: ignore[reportCallIssue]

+ 4 - 2
tests/units/components/datadisplay/test_datatable.py

@@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
         pytest.param(
             {
                 "data": pd.DataFrame(
-                    [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+                    [["foo", "bar"], ["foo1", "bar1"]],
+                    columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
                 )
             },
             "data",
@@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
 def test_serialize_dataframe():
     """Test if dataframe is serialized correctly."""
     df = pd.DataFrame(
-        [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+        [["foo", "bar"], ["foo1", "bar1"]],
+        columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
     )
     value = serialize(df)
     assert value == serialize_dataframe(df)

+ 5 - 5
tests/units/test_app.py

@@ -9,7 +9,7 @@ import unittest.mock
 import uuid
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
-from typing import Generator, List, Tuple, Type
+from typing import Generator, List, Tuple, Type, cast
 from unittest.mock import AsyncMock
 
 import pytest
@@ -33,7 +33,7 @@ from reflex.components import Component
 from reflex.components.base.fragment import Fragment
 from reflex.components.core.cond import Cond
 from reflex.components.radix.themes.typography.text import Text
-from reflex.event import Event
+from reflex.event import Event, EventHandler
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.state import (
@@ -917,7 +917,7 @@ class DynamicState(BaseState):
         """
         return self.dynamic
 
-    on_load_internal = OnLoadInternalState.on_load_internal.fn
+    on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
 
 
 def test_dynamic_arg_shadow(
@@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
         pass
 
     assert (await app.state_manager.get_state(event.substate_token)).value == 5
-    assert app._postprocess.call_count == 6
+    assert getattr(app._postprocess, "call_count", None) == 6
 
     if isinstance(app.state_manager, StateManagerRedis):
         await app.state_manager.close()
@@ -1247,7 +1247,7 @@ def test_overlay_component(
 
     if exp_page_child is not None:
         assert len(page.children) == 3
-        children_types = (type(child) for child in page.children)
+        children_types = [type(child) for child in page.children]
         assert exp_page_child in children_types
     else:
         assert len(page.children) == 2

+ 3 - 0
tests/units/test_event.py

@@ -5,6 +5,7 @@ import pytest
 import reflex as rx
 from reflex.event import (
     Event,
+    EventActionsMixin,
     EventChain,
     EventHandler,
     EventSpec,
@@ -410,6 +411,7 @@ def test_event_actions():
 
 def test_event_actions_on_state():
     class EventActionState(BaseState):
+        @rx.event
         def handler(self):
             pass
 
@@ -418,6 +420,7 @@ def test_event_actions_on_state():
     assert not handler.event_actions
 
     sp_handler = EventActionState.handler.stop_propagation
+    assert isinstance(sp_handler, EventActionsMixin)
     assert sp_handler.event_actions == {"stopPropagation": True}
     # should NOT affect other references to the handler
     assert not handler.event_actions

+ 5 - 2
tests/units/test_health_endpoint.py

@@ -122,9 +122,12 @@ async def test_health(
     # Call the async health function
     response = await health()
 
-    print(json.loads(response.body))
+    body = response.body
+    assert isinstance(body, bytes)
+
+    print(json.loads(body))
     print(expected_status)
 
     # Verify the response content and status code
     assert response.status_code == expected_code
-    assert json.loads(response.body) == expected_status
+    assert json.loads(body) == expected_status

+ 4 - 4
tests/units/test_sqlalchemy.py

@@ -59,7 +59,7 @@ def test_automigration(
         id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
 
     # initial table
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[str] = mapped_column(default="")
 
     with Model.get_db_engine().connect() as connection:
@@ -78,7 +78,7 @@ def test_automigration(
     model_registry.get_metadata().clear()
 
     # Create column t2, mark t1 as optional with default
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[Optional[str]] = mapped_column(default="default")
         t2: Mapped[str] = mapped_column(default="bar")
 
@@ -98,7 +98,7 @@ def test_automigration(
     model_registry.get_metadata().clear()
 
     # Drop column t1
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")
 
     assert Model.migrate(autogenerate=True)
@@ -133,7 +133,7 @@ def test_automigration(
     # drop table (AlembicSecond)
     model_registry.get_metadata().clear()
 
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")
 
     assert Model.migrate(autogenerate=True)

+ 20 - 11
tests/units/test_state.py

@@ -17,6 +17,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -120,8 +121,8 @@ class TestState(BaseState):
     num2: float = 3.14
     key: str
     map_key: str = "a"
-    array: List[float] = [1, 2, 3.14]
-    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
+    mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
     obj: Object = Object()
     complex: Dict[int, Object] = {1: Object(), 2: Object()}
     fig: Figure = Figure()
@@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     class HandlerState(BaseState):
         x: int = 42
 
+        @rx.event
         def handler(self):
             self.x = self.x + 1
 
@@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
             counter += 1
             return counter
 
+    assert isinstance(HandlerState.handler, EventHandler)
     if use_partial:
-        HandlerState.handler = functools.partial(HandlerState.handler.fn)
+        partial_guy = functools.partial(HandlerState.handler.fn)
+        HandlerState.handler = partial_guy  # pyright: ignore[reportAttributeAccessIssue]
         assert isinstance(HandlerState.handler, functools.partial)
-    else:
-        assert isinstance(HandlerState.handler, EventHandler)
 
     s = HandlerState()
     assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
 
     # ensure state update was emitted
     assert mock_app.event_namespace is not None
-    mock_app.event_namespace.emit.assert_called_once()
-    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    mock_app.event_namespace.emit.assert_called_once()  # pyright: ignore[reportFunctionMemberAccess]
+    mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+    mcall = mock_calls[0]
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
         delta={
@@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
 
-    first_ws_message = emit_mock.mock_calls[0].args[1]
+    mock_calls = getattr(emit_mock, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+
+    first_ws_message = mock_calls[0].args[1]
     assert (
         first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
         is not None
@@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         events=[],
         final=True,
     )
-    for call in emit_mock.mock_calls[1:5]:
+    for call in mock_calls[1:5]:
         assert call.args[1] == StateUpdate(
             delta={
                 BackgroundTaskState.get_full_name(): {
@@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
             events=[],
             final=True,
         )
-    assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+    assert mock_calls[-2].args[1] == StateUpdate(
         delta={
             BackgroundTaskState.get_full_name(): {
                 "order": exp_order,
@@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         events=[],
         final=True,
     )
-    assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+    assert mock_calls[-1].args[1] == StateUpdate(
         delta={
             BackgroundTaskState.get_full_name(): {
                 "computed_order": exp_order,