|
@@ -13,7 +13,7 @@ import typing
|
|
|
from inspect import getfullargspec
|
|
|
from multiprocessing import Pool, cpu_count
|
|
|
from pathlib import Path
|
|
|
-from types import ModuleType
|
|
|
+from types import ModuleType, SimpleNamespace
|
|
|
from typing import Any, Callable, Iterable, Type, get_args
|
|
|
|
|
|
import black
|
|
@@ -94,7 +94,9 @@ def _relative_to_pwd(path: Path) -> Path:
|
|
|
Returns:
|
|
|
The relative path.
|
|
|
"""
|
|
|
- return path.relative_to(PWD)
|
|
|
+ if path.is_absolute():
|
|
|
+ return path.relative_to(PWD)
|
|
|
+ return path
|
|
|
|
|
|
|
|
|
def _git_diff(args: list[str]) -> str:
|
|
@@ -403,7 +405,7 @@ def _get_parent_imports(func):
|
|
|
|
|
|
def _generate_component_create_functiondef(
|
|
|
node: ast.FunctionDef | None,
|
|
|
- clz: type[Component],
|
|
|
+ clz: type[Component] | type[SimpleNamespace],
|
|
|
type_hint_globals: dict[str, Any],
|
|
|
) -> ast.FunctionDef:
|
|
|
"""Generate the create function definition for a Component.
|
|
@@ -415,7 +417,13 @@ def _generate_component_create_functiondef(
|
|
|
|
|
|
Returns:
|
|
|
The create functiondef node for the ast.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ TypeError: If clz is not a subclass of Component.
|
|
|
"""
|
|
|
+ if not issubclass(clz, Component):
|
|
|
+ raise TypeError(f"clz must be a subclass of Component, not {clz!r}")
|
|
|
+
|
|
|
# add the imports needed by get_type_hint later
|
|
|
type_hint_globals.update(
|
|
|
{name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
|
|
@@ -484,10 +492,58 @@ def _generate_component_create_functiondef(
|
|
|
return definition
|
|
|
|
|
|
|
|
|
+def _generate_namespace_call_functiondef(
|
|
|
+ clz_name: str,
|
|
|
+ classes: dict[str, type[Component] | type[SimpleNamespace]],
|
|
|
+ type_hint_globals: dict[str, Any],
|
|
|
+) -> ast.FunctionDef | None:
|
|
|
+ """Generate the __call__ function definition for a SimpleNamespace.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for.
|
|
|
+ classes: Map name to actual class definition.
|
|
|
+ type_hint_globals: The globals to use to resolving a type hint str.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The create functiondef node for the ast.
|
|
|
+ """
|
|
|
+ # add the imports needed by get_type_hint later
|
|
|
+ type_hint_globals.update(
|
|
|
+ {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
|
|
|
+ )
|
|
|
+
|
|
|
+ clz = classes[clz_name]
|
|
|
+
|
|
|
+ # Determine which class is wrapped by the namespace __call__ method
|
|
|
+ component_class_name, dot, func_name = clz.__call__.__func__.__qualname__.partition(
|
|
|
+ "."
|
|
|
+ )
|
|
|
+ component_clz = classes[component_class_name]
|
|
|
+
|
|
|
+ # Only generate for create functions
|
|
|
+ if func_name != "create":
|
|
|
+ return None
|
|
|
+
|
|
|
+ definition = _generate_component_create_functiondef(
|
|
|
+ node=None,
|
|
|
+ clz=component_clz,
|
|
|
+ type_hint_globals=type_hint_globals,
|
|
|
+ )
|
|
|
+ definition.name = "__call__"
|
|
|
+
|
|
|
+ # Turn the definition into a staticmethod
|
|
|
+ del definition.args.args[0] # remove `cls` arg
|
|
|
+ definition.decorator_list = [ast.Name(id="staticmethod")]
|
|
|
+
|
|
|
+ return definition
|
|
|
+
|
|
|
+
|
|
|
class StubGenerator(ast.NodeTransformer):
|
|
|
"""A node transformer that will generate the stubs for a given module."""
|
|
|
|
|
|
- def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]):
|
|
|
+ def __init__(
|
|
|
+ self, module: ModuleType, classes: dict[str, Type[Component | SimpleNamespace]]
|
|
|
+ ):
|
|
|
"""Initialize the stub generator.
|
|
|
|
|
|
Args:
|
|
@@ -528,6 +584,18 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
node.body.pop(0)
|
|
|
return node
|
|
|
|
|
|
+ def _current_class_is_component(self) -> bool:
|
|
|
+ """Check if the current class is a Component.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Whether the current class is a Component.
|
|
|
+ """
|
|
|
+ return (
|
|
|
+ self.current_class is not None
|
|
|
+ and self.current_class in self.classes
|
|
|
+ and issubclass(self.classes[self.current_class], Component)
|
|
|
+ )
|
|
|
+
|
|
|
def visit_Module(self, node: ast.Module) -> ast.Module:
|
|
|
"""Visit a Module node and remove docstring from body.
|
|
|
|
|
@@ -591,6 +659,27 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
exec("\n".join(self.import_statements), self.type_hint_globals)
|
|
|
self.current_class = node.name
|
|
|
self._remove_docstring(node)
|
|
|
+
|
|
|
+ # Define `__call__` as a real function so the docstring appears in the stub.
|
|
|
+ call_definition = None
|
|
|
+ for child in node.body[:]:
|
|
|
+ found_call = False
|
|
|
+ if isinstance(child, ast.Assign):
|
|
|
+ for target in child.targets[:]:
|
|
|
+ if isinstance(target, ast.Name) and target.id == "__call__":
|
|
|
+ child.targets.remove(target)
|
|
|
+ found_call = True
|
|
|
+ if not found_call:
|
|
|
+ continue
|
|
|
+ if not child.targets[:]:
|
|
|
+ node.body.remove(child)
|
|
|
+ call_definition = _generate_namespace_call_functiondef(
|
|
|
+ self.current_class,
|
|
|
+ self.classes,
|
|
|
+ type_hint_globals=self.type_hint_globals,
|
|
|
+ )
|
|
|
+ break
|
|
|
+
|
|
|
self.generic_visit(node) # Visit child nodes.
|
|
|
|
|
|
if (
|
|
@@ -598,7 +687,7 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
isinstance(child, ast.FunctionDef) and child.name == "create"
|
|
|
for child in node.body
|
|
|
)
|
|
|
- and self.current_class in self.classes
|
|
|
+ and self._current_class_is_component()
|
|
|
):
|
|
|
# Add a new .create FunctionDef since one does not exist.
|
|
|
node.body.append(
|
|
@@ -608,6 +697,8 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
type_hint_globals=self.type_hint_globals,
|
|
|
)
|
|
|
)
|
|
|
+ if call_definition is not None:
|
|
|
+ node.body.append(call_definition)
|
|
|
if not node.body:
|
|
|
# We should never return an empty body.
|
|
|
node.body.append(ast.Expr(value=ast.Ellipsis()))
|
|
@@ -634,11 +725,12 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
node, self.classes[self.current_class], self.type_hint_globals
|
|
|
)
|
|
|
else:
|
|
|
- if node.name.startswith("_"):
|
|
|
+ if node.name.startswith("_") and node.name != "__call__":
|
|
|
return None # remove private methods
|
|
|
|
|
|
- # Blank out the function body for public functions.
|
|
|
- node.body = [ast.Expr(value=ast.Ellipsis())]
|
|
|
+ if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
|
|
|
+ # Blank out the function body for public functions.
|
|
|
+ node.body = [ast.Expr(value=ast.Ellipsis())]
|
|
|
return node
|
|
|
|
|
|
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
|
|
@@ -657,9 +749,11 @@ class StubGenerator(ast.NodeTransformer):
|
|
|
and node.value.id == "Any"
|
|
|
):
|
|
|
return node
|
|
|
- if self.current_class in self.classes:
|
|
|
+
|
|
|
+ if self._current_class_is_component():
|
|
|
# Remove annotated assignments in Component classes (props)
|
|
|
return None
|
|
|
+
|
|
|
return node
|
|
|
|
|
|
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
|
|
@@ -738,7 +832,7 @@ class PyiGenerator:
|
|
|
name: obj
|
|
|
for name, obj in vars(module).items()
|
|
|
if inspect.isclass(obj)
|
|
|
- and issubclass(obj, Component)
|
|
|
+ and (issubclass(obj, Component) or issubclass(obj, SimpleNamespace))
|
|
|
and obj != Component
|
|
|
and inspect.getmodule(obj) == module
|
|
|
}
|