123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- """The pyi generator module."""
- import importlib
- import inspect
- import os
- import re
- import sys
- from inspect import getfullargspec
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Union, get_args # NOQA
- import black
- from reflex.components.component import Component
- from reflex.vars import Var
# Reference these imported names so linters do not report them as unused.
# NOTE(review): presumably they must stay importable because `_get_type_hint`
# eval()s string annotations in this module's namespace — confirm.
ruff_dont_remove = [
    Var,
    Optional,
    Dict,
    List,
]
# File names skipped while walking a folder in `_scan_folder`. Files passed
# explicitly to `scan_all` are not checked against this list.
EXCLUDED_FILES = [
    "__init__.py", "component.py", "bare.py",
    "foreach.py", "cond.py", "multiselect.py",
]

# `typing` names that every generated stub imports, whatever the scanned
# module itself imports. The generated `create` overloads use all three.
DEFAULT_TYPING_IMPORTS = {"Optional", "Union", "overload"}
- def _get_type_hint(value, top_level=True, no_union=False):
- res = ""
- args = get_args(value)
- if args:
- res = f"{value.__name__}[{', '.join([_get_type_hint(arg, top_level=False) for arg in args if arg is not type(None)])}]"
- if value.__name__ == "Var":
- types = [res] + [
- _get_type_hint(arg, top_level=False)
- for arg in args
- if arg is not type(None)
- ]
- if len(types) > 1 and not no_union:
- res = ", ".join(types)
- res = f"Union[{res}]"
- elif isinstance(value, str):
- ev = eval(value)
- res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
- else:
- res = value.__name__
- if top_level and not res.startswith("Optional"):
- res = f"Optional[{res}]"
- return res
- def _get_typing_import(_module):
- src = [
- line
- for line in inspect.getsource(_module).split("\n")
- if line.startswith("from typing")
- ]
- if len(src):
- return set(src[0].rpartition("from typing import ")[-1].split(", "))
- return set()
- def _get_var_definition(_module, _var_name):
- return [
- line.split(" = ")[0]
- for line in inspect.getsource(_module).splitlines()
- if line.startswith(_var_name)
- ]
class PyiGenerator:
    """A .pyi file generator that will scan all defined Component in Reflex and
    generate the appropriate stub.
    """

    # NOTE(review): not read anywhere in this class.
    modules: list = []
    # Directory of the file currently being scanned.
    root: str = ""
    # The imported module currently being scanned (set by `_scan_file`).
    current_module: Any = {}
    # File name of the current module without its extension (set by
    # `_scan_file`). Declared here so the attribute exists before a scan.
    current_module_path: str = ""
    # `typing` names imported by every generated stub.
    default_typing_imports: set = DEFAULT_TYPING_IMPORTS

    def _generate_imports(self, variables, classes):
        """Build the import lines at the top of the stub.

        Args:
            variables: ``(name, value)`` pairs of module-level variables.
            classes: ``(name, class)`` pairs of the components being stubbed.

        Returns:
            The import statements, one per list item.
        """
        # Classes of Component instances held in module variables are imported too.
        variables_imports = {
            type(_var) for _, _var in variables if isinstance(_var, Component)
        }
        # Base classes defined outside the current module.
        bases = {
            base
            for _, _class in classes
            for base in _class.__bases__
            if inspect.getmodule(base) != self.current_module
        } | variables_imports
        bases.add(Component)
        typing_imports = self.default_typing_imports | _get_typing_import(
            self.current_module
        )
        bases = sorted(bases, key=lambda base: base.__name__)
        return [
            f"from typing import {','.join(sorted(typing_imports))}",
            *[f"from {base.__module__} import {base.__name__}" for base in bases],
            "from reflex.vars import Var, BaseVar, ComputedVar",
            "from reflex.event import EventHandler, EventChain, EventSpec",
        ]

    def _generate_pyi_class(self, _class: type[Component]):
        """Generate the stub lines for one component class.

        The stub declares the class with its original bases and an
        ``@overload``-ed ``create`` classmethod. Its keyword arguments are:

        - the keyword-only args of ``_class.create``;
        - the class's annotated props;
        - its event triggers.

        Annotated arguments default to ``None``.

        Args:
            _class: The component class to stub.

        Returns:
            Stub chunks to be joined with newlines.
        """
        create_spec = getfullargspec(_class.create)
        base_names = ", ".join([base.__name__ for base in _class.__bases__])
        lines = ["", f"class {_class.__name__}({base_names}):"]
        # NOTE(review): these literals indent the stub with a single space,
        # while `_generate_docstrings` indents with 8. The trailing " ..." may
        # therefore land at class level rather than inside `create`. Black
        # reformats the file afterwards, but confirm the intended widths.
        definition = " @overload\n @classmethod\n def create(cls, *children, "

        for kwarg in create_spec.kwonlyargs:
            if kwarg in create_spec.annotations:
                definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
            else:
                definition += f"{kwarg}, "

        # Annotated props, skipping those already emitted from create()'s signature.
        for name, value in _class.__annotations__.items():
            if name in create_spec.kwonlyargs:
                continue
            definition += f"{name}: {_get_type_hint(value)} = None, "

        # NOTE: this instantiates the component to read its event triggers.
        for trigger in sorted(_class().get_event_triggers().keys()):
            definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "

        definition = definition.rstrip(", ")
        definition += f", **props) -> '{_class.__name__}': # type: ignore\n"

        definition += self._generate_docstrings(_class, _class.__annotations__.keys())
        lines.append(definition)
        lines.append(" ...")
        return lines

    def _generate_docstrings(self, _class, _props):
        """Build the docstring of the generated ``create`` overload.

        Each prop's description comes from the ``#`` comment lines directly
        above its declaration in the class source. Scanning stops at the first
        line containing ``"def "``. The descriptions are inserted into a copy
        of ``_class.create.__doc__``, right after its ``*children`` line.

        Args:
            _class: The component class being stubbed.
            _props: The names of the class's annotated props.

        Returns:
            The docstring including its triple quotes, or ``""`` when
            ``_class.create`` has no docstring.
        """
        props_comments = {}
        comments = []
        for line in inspect.getsource(_class).splitlines():
            if re.search("def ", line):
                # We've reached the functions, so stop.
                break

            stripped = line.strip()
            # Collect comment lines; they describe the next prop.
            if stripped.startswith("#"):
                comments.append(line)
                continue

            # Check if this line declares a prop.
            match = re.search(r"\w+:", line)
            if match is not None:
                prop = match.group(0).strip(":")
                if prop in _props:
                    # This is a prop: attach the comments gathered above it.
                    props_comments[prop] = "\n".join(
                        [comment.strip().strip("#") for comment in comments]
                    )

            # Any other non-blank line ends the comment run. Otherwise comments
            # above e.g. `tag = ...` would leak into the next prop's
            # description. Blank lines do not end the run.
            if stripped:
                comments.clear()

        doc = _class.create.__doc__
        if not doc:
            # Nothing to extend. Emitting a lone '"""' here would leave an
            # unterminated string in the stub.
            return ""

        new_docstring = []
        for i, line in enumerate(doc.splitlines()):
            if i == 0:
                new_docstring.append(" " * 8 + '"""' + line)
            else:
                new_docstring.append(line)
            if "*children" in line:
                # One entry per prop, reusing the *children line's indentation.
                new_docstring.extend(
                    f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
                )
        new_docstring.append('"""')
        return "\n".join(new_docstring)

    def _generate_pyi_variable(self, _name, _var):
        """Return the stub lines declaring the module-level variable ``_name``.

        ``_var`` is unused; the declaration is read from the module source.
        """
        return _get_var_definition(self.current_module, _name)

    def _generate_function(self, _name, _func):
        """Generate a ``def ...: ...`` stub for a module-level function.

        Args:
            _name: The function name (unused; the source provides it).
            _func: The function object.

        Returns:
            The stub lines, or ``[]`` when the function's source is indented.
        """
        import textwrap

        # Don't generate indented functions.
        source = inspect.getsource(_func)
        if textwrap.dedent(source) != source:
            return []

        # Everything before the first ":\n" is taken as the signature. Removing
        # its newlines puts a multi-line signature on one line.
        # NOTE(review): decorator lines get merged into the `def` line too
        # (e.g. "@decodef f()"), which is invalid stub syntax.
        definition = source.split(":\n")[0].replace("\n", "")
        return [f"{definition}:", " ..."]

    def _write_pyi_file(self, variables, functions, classes):
        """Assemble, write and black-format the stub for the current module.

        The stub is written to ``<current_module_path>.pyi`` inside
        ``self.root``.

        Args:
            variables: ``(name, value)`` pairs of module-level variables.
            functions: ``(name, function)`` pairs of module-level functions.
            classes: ``(name, class)`` pairs of the components to stub.
        """
        pyi_content = [
            f'"""Stub file for {self.current_module_path}.py"""',
            "# ------------------- DO NOT EDIT ----------------------",
            "# This file was generated by `scripts/pyi_generator.py`!",
            "# ------------------------------------------------------",
            "",
        ]

        pyi_content.extend(self._generate_imports(variables, classes))

        for _name, _var in variables:
            pyi_content.extend(self._generate_pyi_variable(_name, _var))

        for _fname, _func in functions:
            pyi_content.extend(self._generate_function(_fname, _func))

        for _, _class in classes:
            pyi_content.extend(self._generate_pyi_class(_class))

        pyi_filename = f"{self.current_module_path}.pyi"
        pyi_path = os.path.join(self.root, pyi_filename)

        # Write UTF-8 explicitly: docstrings copied into the stub may contain
        # non-ASCII text, and black decodes the file as UTF-8 when formatting.
        with open(pyi_path, "w", encoding="utf-8") as pyi_file:
            pyi_file.write("\n".join(pyi_content))
        black.format_file_in_place(
            src=Path(pyi_path),
            fast=True,
            mode=black.FileMode(),
            write_back=black.WriteBack.YES,
        )

    def _scan_file(self, file):
        """Import one module and write its stub if it defines components.

        Args:
            file: The file name, relative to ``self.root``.
        """
        self.current_module_path = os.path.splitext(file)[0]
        # Build the dotted import path from the path components instead of
        # replacing "/", so OS-specific separators and "./" prefixes work too.
        module_import = ".".join(Path(self.root, file).with_suffix("").parts)

        self.current_module = importlib.import_module(module_import)

        local_variables = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if not name.startswith("_")
            and not inspect.isclass(obj)
            and not inspect.isfunction(obj)
        ]

        # Functions defined in this module (or whose module can't be determined).
        functions = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if not name.startswith("__")
            and (
                not inspect.getmodule(obj)
                or inspect.getmodule(obj) == self.current_module
            )
            and inspect.isfunction(obj)
        ]

        # Component subclasses defined in this module.
        class_names = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if inspect.isclass(obj)
            and issubclass(obj, Component)
            and obj != Component
            and inspect.getmodule(obj) == self.current_module
        ]
        if not class_names:
            return
        print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
        self._write_pyi_file(local_variables, functions, class_names)

    def _scan_folder(self, folder):
        """Recursively scan every non-excluded ``.py`` file under ``folder``."""
        for root, _, files in os.walk(folder):
            self.root = root
            for file in files:
                if file in EXCLUDED_FILES:
                    continue
                if file.endswith(".py"):
                    self._scan_file(file)

    def scan_all(self, targets):
        """Scan all targets for class inheriting Component and generate the .pyi files.

        Args:
            targets: the list of file/folders to scan.
        """
        for target in targets:
            if target.endswith(".py"):
                self.root, _, file = target.rpartition("/")
                self._scan_file(file)
            else:
                self._scan_folder(target)
if __name__ == "__main__":
    # With no command-line paths, scan the whole components package.
    targets = sys.argv[1:] or ["reflex/components"]
    print(f"Running .pyi generator for {targets}")
    PyiGenerator().scan_all(targets)
|