浏览代码

solve some but not all pyright issues

Khaleel Al-Adhami 4 月之前
父节点
当前提交
112b2ed948
共有 3 个文件被更改,包括 53 次插入24 次删除
  1. 30 12
      reflex/utils/pyi_generator.py
  2. 14 7
      reflex/utils/types.py
  3. 9 5
      reflex/vars/number.py

+ 30 - 12
reflex/utils/pyi_generator.py

@@ -16,7 +16,7 @@ from itertools import chain
 from multiprocessing import Pool, cpu_count
 from pathlib import Path
 from types import ModuleType, SimpleNamespace
-from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
+from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
 
 from reflex.components.component import Component
 from reflex.utils import types as rx_types
@@ -229,7 +229,9 @@ def _generate_imports(
     """
     return [
         *[
-            ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values])
+            ast.ImportFrom(
+                module=name, names=[ast.alias(name=val) for val in values], level=0
+            )
             for name, values in DEFAULT_IMPORTS.items()
         ],
         ast.Import([ast.alias("reflex")]),
@@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST:
         return ast.Name(id=base_name)
 
     # Convert all type arguments recursively
-    arg_nodes = [type_to_ast(arg, cls) for arg in args]
+    arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
 
     # Special case for single-argument types (like List[T] or Optional[T])
     if len(arg_nodes) == 1:
         slice_value = arg_nodes[0]
     else:
         slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
-
     return ast.Subscript(
-        value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
+        value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
     )
 
 
@@ -630,7 +631,7 @@ def _generate_component_create_functiondef(
                 ),
             ),
             ast.Expr(
-                value=ast.Ellipsis(),
+                value=ast.Constant(...),
             ),
         ],
         decorator_list=[
@@ -641,8 +642,14 @@ def _generate_component_create_functiondef(
                 else [ast.Name(id="classmethod")]
             ),
         ],
-        lineno=node.lineno if node is not None else None,
         returns=ast.Constant(value=clz.__name__),
+        **(
+            {
+                "lineno": node.lineno,
+            }
+            if node is not None
+            else {}
+        ),
     )
     return definition
 
@@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef(
             ),
         ],
         decorator_list=[ast.Name(id="staticmethod")],
-        lineno=node.lineno if node is not None else None,
         returns=ast.Constant(
             value=_get_type_hint(
                 typing.get_type_hints(clz.__call__).get("return", None),
                 type_hint_globals,
             )
         ),
+        **(
+            {
+                "lineno": node.lineno,
+            }
+            if node is not None
+            else {}
+        ),
     )
     return definition
 
@@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef(
     # Determine which class is wrapped by the namespace __call__ method
     component_clz = clz.__call__.__self__
 
-    if clz.__call__.__func__.__name__ != "create":
+    func = getattr(clz.__call__, "__func__", None)
+
+    if func is None:
+        raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
+
+    if func.__name__ != "create":
         return None
 
     definition = _generate_component_create_functiondef(
@@ -914,7 +932,7 @@ class StubGenerator(ast.NodeTransformer):
             node.body.append(call_definition)
         if not node.body:
             # We should never return an empty body.
-            node.body.append(ast.Expr(value=ast.Ellipsis()))
+            node.body.append(ast.Expr(value=ast.Constant(...)))
         self.current_class = None
         return node
 
@@ -941,9 +959,9 @@ class StubGenerator(ast.NodeTransformer):
             if node.name.startswith("_") and node.name != "__call__":
                 return None  # remove private methods
 
-            if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
+            if node.body[-1] != ast.Expr(value=ast.Constant(...)):
                 # Blank out the function body for public functions.
-                node.body = [ast.Expr(value=ast.Ellipsis())]
+                node.body = [ast.Expr(value=ast.Constant(...))]
         return node
 
     def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:

+ 14 - 7
reflex/utils/types.py

@@ -69,21 +69,21 @@ else:
 
 
 # Potential GenericAlias types for isinstance checks.
-GenericAliasTypes = [_GenericAlias]
+_GenericAliasTypes: list[type] = [_GenericAlias]
 
 with contextlib.suppress(ImportError):
     # For newer versions of Python.
     from types import GenericAlias  # type: ignore
 
-    GenericAliasTypes.append(GenericAlias)
+    _GenericAliasTypes.append(GenericAlias)
 
 with contextlib.suppress(ImportError):
     # For older versions of Python.
     from typing import _SpecialGenericAlias  # type: ignore
 
-    GenericAliasTypes.append(_SpecialGenericAlias)
+    _GenericAliasTypes.append(_SpecialGenericAlias)
 
-GenericAliasTypes = tuple(GenericAliasTypes)
+GenericAliasTypes = tuple(_GenericAliasTypes)
 
 # Potential Union types for isinstance checks (UnionType added in py3.10).
 UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
@@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool:
     return isinstance(cls, GenericAliasTypes)
 
 
-def unionize(*args: GenericType) -> Type:
+def unionize(*args: GenericType) -> GenericType:
     """Unionize the types.
 
     Args:
@@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
 
 
 @lru_cache()
-def get_base_class(cls: GenericType) -> Type:
+def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
     """Get the base class of a class.
 
     Args:
@@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type:
         return type(get_args(cls)[0])
 
     if is_union(cls):
-        return tuple(get_base_class(arg) for arg in get_args(cls))
+        base_classes = []
+        for arg in get_args(cls):
+            sub_base_classes = get_base_class(arg)
+            if isinstance(sub_base_classes, tuple):
+                base_classes.extend(sub_base_classes)
+            else:
+                base_classes.append(sub_base_classes)
+        return tuple(base_classes)
 
     return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
 

+ 9 - 5
reflex/vars/number.py

@@ -15,6 +15,7 @@ from typing import (
     Sequence,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     _cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
         default_factory=tuple
     )
-    _default: Var[VAR_TYPE] = dataclasses.field(
+    _default: Var[VAR_TYPE] = dataclasses.field(  # pyright: ignore[reportAssignmentType]
         default_factory=lambda: Var.create(None)
     )
 
@@ -1170,11 +1171,14 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             The match operation.
         """
         cond = Var.create(cond)
-        cases = tuple(tuple(Var.create(c) for c in case) for case in cases)
-        default = Var.create(default)
+        cases = cast(
+            tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
+            tuple(tuple(Var.create(c) for c in case) for case in cases),
+        )
+        _default = cast(Var[VAR_TYPE], Var.create(default))
         var_type = _var_type or unionize(
             *(case[-1]._var_type for case in cases),
-            default._var_type,
+            _default._var_type,
         )
         return cls(
             _js_expr="",
@@ -1182,7 +1186,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             _var_type=var_type,
             _cond=cond,
             _cases=cases,
-            _default=default,
+            _default=_default,
         )