"""The pyi generator module.""" import importlib import inspect import os import sys from inspect import getfullargspec from pathlib import Path from typing import Any, Dict, List, Optional, get_args import black from reflex.components.component import Component from reflex.vars import Var ruff_dont_remove = [Var, Optional, Dict, List] EXCLUDED_FILES = [ "__init__.py", "component.py", "bare.py", "foreach.py", "cond.py", "multiselect.py", ] DEFAULT_TYPING_IMPORTS = {"overload", "Optional", "Union"} def _get_type_hint(value, top_level=True, no_union=False): res = "" args = get_args(value) if args: res = f"{value.__name__}[{', '.join([_get_type_hint(arg, top_level=False) for arg in args if arg is not type(None)])}]" if value.__name__ == "Var": types = [res] + [ _get_type_hint(arg, top_level=False) for arg in args if arg is not type(None) ] if len(types) > 1 and not no_union: res = ", ".join(types) res = f"Union[{res}]" elif isinstance(value, str): ev = eval(value) res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value else: res = value.__name__ if top_level and not res.startswith("Optional"): res = f"Optional[{res}]" return res def _get_typing_import(_module): src = [ line for line in inspect.getsource(_module).split("\n") if line.startswith("from typing") ] if len(src): return set(src[0].rpartition("from typing import ")[-1].split(", ")) return set() def _get_var_definition(_module, _var_name): return [ line.split(" = ")[0] for line in inspect.getsource(_module).splitlines() if line.startswith(_var_name) ] class PyiGenerator: """A .pyi file generator that will scan all defined Component in Reflex and generate the approriate stub. """ modules: list = [] root: str = "" current_module: Any = {} default_typing_imports: set = DEFAULT_TYPING_IMPORTS def _generate_imports(self, variables, classes): variables_imports = { type(_var) for _, _var in variables if isinstance(_var, Component) } bases = { base for _, _class in classes for base in _class.__bases__ if inspect.getmodule(base) != self.current_module } | variables_imports bases.add(Component) typing_imports = self.default_typing_imports | _get_typing_import( self.current_module ) return [ f"from typing import {','.join(typing_imports)}", *[f"from {base.__module__} import {base.__name__}" for base in bases], "from reflex.vars import Var, BaseVar, ComputedVar", "from reflex.event import EventChain", ] def _generate_pyi_class(self, _class): create_spec = getfullargspec(_class.create) lines = [ "", f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):", ] definition = f" @overload\n @classmethod\n def create(cls, *children, " for kwarg in create_spec.kwonlyargs: if kwarg in create_spec.annotations: definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, " else: definition += f"{kwarg}, " for name, value in _class.__annotations__.items(): if name in create_spec.kwonlyargs: continue definition += f"{name}: {_get_type_hint(value)} = None, " definition = definition.rstrip(", ") definition += f", **props) -> '{_class.__name__}': ... # type: ignore" lines.append(definition) return lines def _generate_pyi_variable(self, _name, _var): return _get_var_definition(self.current_module, _name) def _generate_function(self, _name, _func): definition = "".join(inspect.getsource(_func).split(":\n")[0].split("\n")) return [f"{definition}:", " ..."] def _write_pyi_file(self, variables, functions, classes): pyi_content = [ f'"""Stub file for {self.current_module_path}.py"""', "# ------------------- DO NOT EDIT ----------------------", "# This file was generated by `scripts/pyi_generator.py`!", "# ------------------------------------------------------", "", ] pyi_content.extend(self._generate_imports(variables, classes)) for _name, _var in variables: pyi_content.extend(self._generate_pyi_variable(_name, _var)) for _fname, _func in functions: pyi_content.extend(self._generate_function(_fname, _func)) for _, _class in classes: pyi_content.extend(self._generate_pyi_class(_class)) pyi_filename = f"{self.current_module_path}.pyi" pyi_path = os.path.join(self.root, pyi_filename) with open(pyi_path, "w") as pyi_file: pyi_file.write("\n".join(pyi_content)) black.format_file_in_place( src=Path(pyi_path), fast=True, mode=black.FileMode(), write_back=black.WriteBack.YES, ) def _scan_file(self, file): self.current_module_path = os.path.splitext(file)[0] module_import = os.path.splitext(os.path.join(self.root, file))[0].replace( "/", "." ) self.current_module = importlib.import_module(module_import) local_variables = [ (name, obj) for name, obj in vars(self.current_module).items() if not name.startswith("__") # and ( # not inspect.getmodule(obj) # or inspect.getmodule(obj) == self.current_module # ) and not inspect.isclass(obj) and not inspect.isfunction(obj) ] functions = [ (name, obj) for name, obj in vars(self.current_module).items() if not name.startswith("__") and ( not inspect.getmodule(obj) or inspect.getmodule(obj) == self.current_module ) and inspect.isfunction(obj) ] class_names = [ (name, obj) for name, obj in vars(self.current_module).items() if inspect.isclass(obj) and issubclass(obj, Component) and obj != Component and inspect.getmodule(obj) == self.current_module ] if not class_names: return print(f"Parsed {file}: Found {[n for n,_ in class_names]}") self._write_pyi_file(local_variables, functions, class_names) def _scan_folder(self, folder): for root, _, files in os.walk(folder): self.root = root for file in files: if file in EXCLUDED_FILES: continue if file.endswith(".py"): self._scan_file(file) def scan_all(self, targets): """Scan all targets for class inheriting Component and generate the .pyi files. Args: targets: the list of file/folders to scan. """ for target in targets: if target.endswith(".py"): self.root, _, file = target.rpartition("/") self._scan_file(file) else: self._scan_folder(target) if __name__ == "__main__": targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"] print(f"Running .pyi generator for {targets}") gen = PyiGenerator() gen.scan_all(targets)