Forráskód Böngészése

pyi_generator: Generate stubs for `SimpleNamespace` classes

If the namespace assigns `__call__` to an existing component `create` function,
generate args and docstring for IDE integration.
Masen Furer 1 éve
szülő
commit
763c1c1f07
1 módosított fájl, 104 hozzáadás és 10 törlés
  1. 104 10
      scripts/pyi_generator.py

+ 104 - 10
scripts/pyi_generator.py

@@ -13,7 +13,7 @@ import typing
 from inspect import getfullargspec
 from multiprocessing import Pool, cpu_count
 from pathlib import Path
-from types import ModuleType
+from types import ModuleType, SimpleNamespace
 from typing import Any, Callable, Iterable, Type, get_args
 
 import black
@@ -94,7 +94,9 @@ def _relative_to_pwd(path: Path) -> Path:
     Returns:
         The relative path.
     """
-    return path.relative_to(PWD)
+    if path.is_absolute():
+        return path.relative_to(PWD)
+    return path
 
 
 def _git_diff(args: list[str]) -> str:
@@ -403,7 +405,7 @@ def _get_parent_imports(func):
 
 def _generate_component_create_functiondef(
     node: ast.FunctionDef | None,
-    clz: type[Component],
+    clz: type[Component] | type[SimpleNamespace],
     type_hint_globals: dict[str, Any],
 ) -> ast.FunctionDef:
     """Generate the create function definition for a Component.
@@ -415,7 +417,13 @@ def _generate_component_create_functiondef(
 
     Returns:
         The create functiondef node for the ast.
+
+    Raises:
+        TypeError: If clz is not a subclass of Component.
     """
+    if not issubclass(clz, Component):
+        raise TypeError(f"clz must be a subclass of Component, not {clz!r}")
+
     # add the imports needed by get_type_hint later
     type_hint_globals.update(
         {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
@@ -484,10 +492,58 @@ def _generate_component_create_functiondef(
     return definition
 
 
+def _generate_namespace_call_functiondef(
+    clz_name: str,
+    classes: dict[str, type[Component] | type[SimpleNamespace]],
+    type_hint_globals: dict[str, Any],
+) -> ast.FunctionDef | None:
+    """Generate the __call__ function definition for a SimpleNamespace.
+
+    Args:
+        clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for.
+        classes: Map name to actual class definition.
+        type_hint_globals: The globals to use to resolving a type hint str.
+
+    Returns:
+        The create functiondef node for the ast.
+    """
+    # add the imports needed by get_type_hint later
+    type_hint_globals.update(
+        {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
+    )
+
+    clz = classes[clz_name]
+
+    # Determine which class is wrapped by the namespace __call__ method
+    component_class_name, dot, func_name = clz.__call__.__func__.__qualname__.partition(
+        "."
+    )
+    component_clz = classes[component_class_name]
+
+    # Only generate for create functions
+    if func_name != "create":
+        return None
+
+    definition = _generate_component_create_functiondef(
+        node=None,
+        clz=component_clz,
+        type_hint_globals=type_hint_globals,
+    )
+    definition.name = "__call__"
+
+    # Turn the definition into a staticmethod
+    del definition.args.args[0]  # remove `cls` arg
+    definition.decorator_list = [ast.Name(id="staticmethod")]
+
+    return definition
+
+
 class StubGenerator(ast.NodeTransformer):
     """A node transformer that will generate the stubs for a given module."""
 
-    def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]):
+    def __init__(
+        self, module: ModuleType, classes: dict[str, Type[Component | SimpleNamespace]]
+    ):
         """Initialize the stub generator.
 
         Args:
@@ -528,6 +584,18 @@ class StubGenerator(ast.NodeTransformer):
             node.body.pop(0)
         return node
 
+    def _current_class_is_component(self) -> bool:
+        """Check if the current class is a Component.
+
+        Returns:
+            Whether the current class is a Component.
+        """
+        return (
+            self.current_class is not None
+            and self.current_class in self.classes
+            and issubclass(self.classes[self.current_class], Component)
+        )
+
     def visit_Module(self, node: ast.Module) -> ast.Module:
         """Visit a Module node and remove docstring from body.
 
@@ -591,6 +659,27 @@ class StubGenerator(ast.NodeTransformer):
         exec("\n".join(self.import_statements), self.type_hint_globals)
         self.current_class = node.name
         self._remove_docstring(node)
+
+        # Define `__call__` as a real function so the docstring appears in the stub.
+        call_definition = None
+        for child in node.body[:]:
+            found_call = False
+            if isinstance(child, ast.Assign):
+                for target in child.targets[:]:
+                    if isinstance(target, ast.Name) and target.id == "__call__":
+                        child.targets.remove(target)
+                        found_call = True
+                if not found_call:
+                    continue
+                if not child.targets[:]:
+                    node.body.remove(child)
+                call_definition = _generate_namespace_call_functiondef(
+                    self.current_class,
+                    self.classes,
+                    type_hint_globals=self.type_hint_globals,
+                )
+                break
+
         self.generic_visit(node)  # Visit child nodes.
 
         if (
@@ -598,7 +687,7 @@ class StubGenerator(ast.NodeTransformer):
                 isinstance(child, ast.FunctionDef) and child.name == "create"
                 for child in node.body
             )
-            and self.current_class in self.classes
+            and self._current_class_is_component()
         ):
             # Add a new .create FunctionDef since one does not exist.
             node.body.append(
@@ -608,6 +697,8 @@ class StubGenerator(ast.NodeTransformer):
                     type_hint_globals=self.type_hint_globals,
                 )
             )
+        if call_definition is not None:
+            node.body.append(call_definition)
         if not node.body:
             # We should never return an empty body.
             node.body.append(ast.Expr(value=ast.Ellipsis()))
@@ -634,11 +725,12 @@ class StubGenerator(ast.NodeTransformer):
                 node, self.classes[self.current_class], self.type_hint_globals
             )
         else:
-            if node.name.startswith("_"):
+            if node.name.startswith("_") and node.name != "__call__":
                 return None  # remove private methods
 
-            # Blank out the function body for public functions.
-            node.body = [ast.Expr(value=ast.Ellipsis())]
+            if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
+                # Blank out the function body for public functions.
+                node.body = [ast.Expr(value=ast.Ellipsis())]
         return node
 
     def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
@@ -657,9 +749,11 @@ class StubGenerator(ast.NodeTransformer):
             and node.value.id == "Any"
         ):
             return node
-        if self.current_class in self.classes:
+
+        if self._current_class_is_component():
             # Remove annotated assignments in Component classes (props)
             return None
+
         return node
 
     def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
@@ -738,7 +832,7 @@ class PyiGenerator:
             name: obj
             for name, obj in vars(module).items()
             if inspect.isclass(obj)
-            and issubclass(obj, Component)
+            and (issubclass(obj, Component) or issubclass(obj, SimpleNamespace))
             and obj != Component
             and inspect.getmodule(obj) == module
         }