Browse Source

precommit

Khaleel Al-Adhami 4 tháng trước cách đây
mục cha
commit
749577f0bc

+ 7 - 6
reflex/components/component.py

@@ -929,15 +929,16 @@ class Component(BaseComponent, ABC):
 
             valid_children = self._valid_children + allowed_components
 
-            def child_is_in_valid(child):
-                if type(child).__name__ in valid_children:
+            def child_is_in_valid(child_component: Any):
+                if type(child_component).__name__ in valid_children:
                     return True
 
                 if (
-                    not isinstance(child, Bare)
-                    or child.contents is None
-                    or not isinstance(child.contents, Var)
-                    or (var_data := child.contents._get_all_var_data()) is None
+                    not isinstance(child_component, Bare)
+                    or child_component.contents is None
+                    or not isinstance(child_component.contents, Var)
+                    or (var_data := child_component.contents._get_all_var_data())
+                    is None
                 ):
                     return False
 

+ 1 - 1
reflex/components/core/banner.py

@@ -4,8 +4,8 @@ from __future__ import annotations
 
 from typing import Optional
 
-from reflex.components.base.fragment import Fragment
 from reflex import constants
+from reflex.components.base.fragment import Fragment
 from reflex.components.component import Component
 from reflex.components.core.cond import cond
 from reflex.components.datadisplay.logo import svg_logo

+ 0 - 2
reflex/components/core/foreach.py

@@ -2,10 +2,8 @@
 
 from __future__ import annotations
 
-import functools
 from typing import Callable, Iterable
 
-from reflex.utils.exceptions import UntypedVarError
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.object import ObjectVar
 from reflex.vars.sequence import ArrayVar

+ 1 - 1
reflex/components/radix/themes/color_mode.py

@@ -139,7 +139,7 @@ class ColorModeIconButton(IconButton):
 
         if allow_system:
 
-            def color_mode_item(_color_mode: str):
+            def color_mode_item(_color_mode: Literal["light", "dark", "system"]):
                 return dropdown_menu.item(
                     _color_mode.title(), on_click=set_color_mode(_color_mode)
                 )

+ 3 - 1
reflex/utils/types.py

@@ -880,7 +880,9 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
     ):
         return all(
             typehint_issubclass(subclass, superclass)
-            for subclass, superclass in zip(possible_subclass, possible_superclass)
+            for subclass, superclass in zip(
+                possible_subclass, possible_superclass, strict=False
+            )
         )
     if possible_subclass is possible_superclass:
         return True

+ 18 - 24
reflex/vars/base.py

@@ -13,7 +13,7 @@ import random
 import re
 import string
 import warnings
-from types import CodeType, FunctionType
+from types import CodeType, EllipsisType, FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -54,7 +54,6 @@ from reflex.constants.compiler import Hooks
 from reflex.utils import console, exceptions, imports, serializers, types
 from reflex.utils.exceptions import (
     UntypedComputedVarError,
-    VarAttributeError,
     VarDependencyError,
     VarTypeError,
     VarValueError,
@@ -108,12 +107,7 @@ class ReflexCallable(Protocol[P, R]):
     __call__: Callable[P, R]
 
 
-if sys.version_info >= (3, 10):
-    from types import EllipsisType
-
-    ReflexCallableParams = Union[EllipsisType, Tuple[GenericType, ...]]
-else:
-    ReflexCallableParams = Union[Any, Tuple[GenericType, ...]]
+ReflexCallableParams = Union[EllipsisType, Tuple[GenericType, ...]]
 
 
 def unwrap_reflex_callalbe(
@@ -1336,10 +1330,15 @@ class Var(Generic[VAR_TYPE]):
         """
         from .sequence import ArrayVar
 
+        if step is None:
+            return ArrayVar.range(first_endpoint, second_endpoint)
+
         return ArrayVar.range(first_endpoint, second_endpoint, step)
 
-    def __bool__(self) -> bool:
-        """Raise exception if using Var in a boolean context.
+    if not TYPE_CHECKING:
+
+        def __bool__(self) -> bool:
+            """Raise exception if using Var in a boolean context.
 
             Raises:
                 VarTypeError: when attempting to bool-ify the Var.
@@ -1924,12 +1923,12 @@ def var_operation(
         _raw_js_function=custom_operation_return._raw_js_function,
         _original_var_operation=simplified_operation,
         _var_type=ReflexCallable[
-            tuple(
+            tuple(  # pyright: ignore [reportInvalidTypeArguments]
                 arg_python_type
                 if isinstance(arg_default_values[i], inspect.Parameter)
                 else VarWithDefault[arg_python_type]
                 for i, (_, arg_python_type) in enumerate(args_with_type_hints)
-            ),  # type: ignore
+            ),
             custom_operation_return._var_type,
         ],
     )
@@ -2049,11 +2048,6 @@ class CachedVarOperation:
 
 RETURN_TYPE = TypeVar("RETURN_TYPE")
 
-DICT_KEY = TypeVar("DICT_KEY")
-DICT_VAL = TypeVar("DICT_VAL")
-
-LIST_INSIDE = TypeVar("LIST_INSIDE")
-
 
 class FakeComputedVarBaseClass(property):
     """A fake base class for ComputedVar to avoid inheriting from property."""
@@ -2273,17 +2267,17 @@ class ComputedVar(Var[RETURN_TYPE]):
 
     @overload
     def __get__(
-        self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
+        self: ComputedVar[MAPPING_TYPE],
         instance: None,
         owner: Type,
-    ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
+    ) -> ObjectVar[MAPPING_TYPE]: ...
 
     @overload
     def __get__(
-        self: ComputedVar[Sequence[LIST_INSIDE]],
+        self: ComputedVar[SEQUENCE_TYPE],
         instance: None,
         owner: Type,
-    ) -> ArrayVar[Sequence[LIST_INSIDE]]: ...
+    ) -> ArrayVar[SEQUENCE_TYPE]: ...
 
     @overload
     def __get__(self, instance: None, owner: Type) -> ComputedVar[RETURN_TYPE]: ...
@@ -2588,7 +2582,7 @@ RETURN = TypeVar("RETURN")
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class CustomVarOperationReturn(Var[RETURN]):
     """Base class for custom var operations."""
@@ -3202,7 +3196,7 @@ class Field(Generic[FIELD_TYPE]):
     def __get__(self: Field[int], instance: None, owner: Any) -> NumberVar[int]: ...
 
     @overload
-    def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
+    def __get__(self: Field[float], instance: None, owner: Any) -> NumberVar[float]: ...
 
     @overload
     def __get__(self: Field[str], instance: None, owner: Any) -> StringVar[str]: ...
@@ -3251,7 +3245,7 @@ def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
     Returns:
         The Field.
     """
-    return value  # type: ignore
+    return value  # pyright: ignore [reportReturnType]
 
 
 def and_operation(a: Var | Any, b: Var | Any) -> Var:

+ 6 - 7
reflex/vars/function.py

@@ -1193,7 +1193,6 @@ class FunctionVar(
     @overload
     def call(self: FunctionVar[NoReturn], *args: Var | Any) -> Var: ...
 
-    def call(self, *args: Var | Any) -> Var:  # pyright: ignore [reportInconsistentOverload]
     def call(self, *args: Var | Any) -> Var:  # pyright: ignore [reportInconsistentOverload]
         """Call the function with the given arguments.
 
@@ -1299,7 +1298,7 @@ class FunctionVar(
         """
         args_types, return_type = unwrap_reflex_callalbe(self._var_type)
         if isinstance(args_types, tuple):
-            return ReflexCallable[[*args_types[len(args) :]], return_type], None  # type: ignore
+            return ReflexCallable[[*args_types[len(args) :]], return_type], None
         return ReflexCallable[..., return_type], None
 
     def _arg_len(self) -> int | None:
@@ -1637,7 +1636,7 @@ def pre_check_args(
     Raises:
         VarTypeError: If the arguments are invalid.
     """
-    for i, (validator, arg) in enumerate(zip(self._validators, args)):
+    for i, (validator, arg) in enumerate(zip(self._validators, args, strict=False)):
         if (validation_message := validator(arg)) is not None:
             arg_name = self._args.args[i] if i < len(self._args.args) else None
             if arg_name is not None:
@@ -1694,9 +1693,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
 
     _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
-    _pre_check = pre_check_args  # type: ignore
+    _pre_check = pre_check_args
 
-    _partial_type = figure_partial_type  # type: ignore
+    _partial_type = figure_partial_type
 
     @classmethod
     def create(
@@ -1776,9 +1775,9 @@ class ArgsFunctionOperationBuilder(
 
     _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
-    _pre_check = pre_check_args  # type: ignore
+    _pre_check = pre_check_args
 
-    _partial_type = figure_partial_type  # type: ignore
+    _partial_type = figure_partial_type
 
     @classmethod
     def create(

+ 1 - 1
reflex/vars/number.py

@@ -1080,7 +1080,7 @@ TUPLE_ENDS_IN_VAR_RELAXED = tuple[
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     """Base class for immutable match operations."""

+ 0 - 1
reflex/vars/object.py

@@ -142,7 +142,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
 
     # NoReturn is used here to catch when key value is Any
     @overload
-    def __getitem__(  # pyright: ignore [reportOverlappingOverload]
     def __getitem__(  # pyright: ignore [reportOverlappingOverload]
         self: ObjectVar[Mapping[Any, NoReturn]],
         key: Var | Any,

+ 4 - 4
reflex/vars/sequence.py

@@ -773,7 +773,7 @@ def map_array_operation(
         type_computer=nary_type_computer(
             ReflexCallable[[List[Any], ReflexCallable], List[Any]],
             ReflexCallable[[ReflexCallable], List[Any]],
-            computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],  # type: ignore
+            computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
         ),
     )
 
@@ -846,7 +846,7 @@ class SliceVar(Var[slice], python_types=slice):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralSliceVar(CachedVarOperation, LiteralVar, SliceVar):
     """Base class for immutable literal slice vars."""
@@ -1245,7 +1245,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class LiteralStringVar(LiteralVar, StringVar[str]):
     """Base class for immutable literal string vars."""
@@ -1367,7 +1367,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
-    **{"slots": True} if sys.version_info >= (3, 10) else {},
+    slots=True,
 )
 class ConcatVarOperation(CachedVarOperation, StringVar[str]):
     """Representing a concatenation of literal string vars."""

+ 2 - 2
tests/integration/test_lifespan.py

@@ -136,7 +136,7 @@ async def test_lifespan(lifespan_app: AppHarness):
     task_global = driver.find_element(By.ID, "task_global")
 
     assert context_global.text == "2"
-    assert lifespan_app.app_module.lifespan_context_global_getter() == 2  # type: ignore
+    assert lifespan_app.app_module.lifespan_context_global_getter() == 2
 
     original_task_global_text = task_global.text
     original_task_global_value = int(original_task_global_text)
@@ -145,7 +145,7 @@ async def test_lifespan(lifespan_app: AppHarness):
     assert (
         lifespan_app.app_module.lifespan_task_global_getter()
         > original_task_global_value
-    )  # type: ignore
+    )
     assert int(task_global.text) > original_task_global_value
 
     # Kill the backend

+ 2 - 2
tests/units/test_var.py

@@ -1249,11 +1249,11 @@ def test_type_chains():
         List[int],
     )
     assert (
-        str(object_var.keys()[0].upper())
+        str(object_var.keys()[0].upper())  # pyright: ignore [reportAttributeAccessIssue]
         == '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())'
     )
     assert (
-        str(object_var.entries()[1][1] - 1)
+        str(object_var.entries()[1][1] - 1)  # pyright: ignore [reportCallIssue, reportOperatorIssue]
         == '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)'
     )
     assert (