pyi_generator.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. """The pyi generator module."""
  2. import ast
  3. import contextlib
  4. import importlib
  5. import inspect
  6. import logging
  7. import os
  8. import re
  9. import sys
  10. import textwrap
  11. from inspect import getfullargspec
  12. from pathlib import Path
  13. from types import ModuleType
  14. from typing import Any, Callable, Iterable, Type, get_args
  15. import black
  16. import black.mode
  17. from reflex.components.component import Component
  18. from reflex.utils import types as rx_types
  19. from reflex.vars import Var
  20. logger = logging.getLogger("pyi_generator")
  21. EXCLUDED_FILES = [
  22. "__init__.py",
  23. "component.py",
  24. "bare.py",
  25. "foreach.py",
  26. "cond.py",
  27. "multiselect.py",
  28. "literals.py",
  29. ]
  30. # These props exist on the base component, but should not be exposed in create methods.
  31. EXCLUDED_PROPS = [
  32. "alias",
  33. "children",
  34. "event_triggers",
  35. "library",
  36. "lib_dependencies",
  37. "tag",
  38. "is_default",
  39. "special_props",
  40. "_invalid_children",
  41. "_memoization_mode",
  42. "_valid_children",
  43. ]
  44. DEFAULT_TYPING_IMPORTS = {
  45. "overload",
  46. "Any",
  47. "Dict",
  48. # "List",
  49. "Literal",
  50. "Optional",
  51. "Union",
  52. }
  53. def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
  54. """Resolve the type hint for value.
  55. Args:
  56. value: The type annotation as a str or actual types/aliases.
  57. type_hint_globals: The globals to use to resolving a type hint str.
  58. is_optional: Whether the type hint should be wrapped in Optional.
  59. Returns:
  60. The resolved type hint as a str.
  61. """
  62. res = ""
  63. args = get_args(value)
  64. if args:
  65. inner_container_type_args = (
  66. [repr(arg) for arg in args]
  67. if rx_types.is_literal(value)
  68. else [
  69. _get_type_hint(arg, type_hint_globals, is_optional=False)
  70. for arg in args
  71. if arg is not type(None)
  72. ]
  73. )
  74. res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"
  75. if value.__name__ == "Var":
  76. # For Var types, Union with the inner args so they can be passed directly.
  77. types = [res] + [
  78. _get_type_hint(arg, type_hint_globals, is_optional=False)
  79. for arg in args
  80. if arg is not type(None)
  81. ]
  82. if len(types) > 1:
  83. res = ", ".join(types)
  84. res = f"Union[{res}]"
  85. elif isinstance(value, str):
  86. ev = eval(value, type_hint_globals)
  87. res = (
  88. _get_type_hint(ev, type_hint_globals, is_optional=False)
  89. if ev.__name__ == "Var"
  90. else value
  91. )
  92. else:
  93. res = value.__name__
  94. if is_optional and not res.startswith("Optional"):
  95. res = f"Optional[{res}]"
  96. return res
  97. def _generate_imports(typing_imports: Iterable[str]) -> list[ast.ImportFrom]:
  98. """Generate the import statements for the stub file.
  99. Args:
  100. typing_imports: The typing imports to include.
  101. Returns:
  102. The list of import statements.
  103. """
  104. return [
  105. ast.ImportFrom(
  106. module="typing",
  107. names=[ast.alias(name=imp) for imp in sorted(typing_imports)],
  108. ),
  109. *ast.parse( # type: ignore
  110. textwrap.dedent(
  111. """
  112. from reflex.vars import Var, BaseVar, ComputedVar
  113. from reflex.event import EventChain, EventHandler, EventSpec
  114. from reflex.style import Style"""
  115. )
  116. ).body,
  117. ]
  118. def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str:
  119. """Generate the docstrings for the create method.
  120. Args:
  121. clzs: The classes to generate docstrings for.
  122. props: The props to generate docstrings for.
  123. Returns:
  124. The docstring for the create method.
  125. """
  126. props_comments = {}
  127. comments = []
  128. for clz in clzs:
  129. for line in inspect.getsource(clz).splitlines():
  130. reached_functions = re.search("def ", line)
  131. if reached_functions:
  132. # We've reached the functions, so stop.
  133. break
  134. # Get comments for prop
  135. if line.strip().startswith("#"):
  136. comments.append(line)
  137. continue
  138. # Check if this line has a prop.
  139. match = re.search("\\w+:", line)
  140. if match is None:
  141. # This line doesn't have a var, so continue.
  142. continue
  143. # Get the prop.
  144. prop = match.group(0).strip(":")
  145. if prop in props:
  146. if not comments: # do not include undocumented props
  147. continue
  148. props_comments[prop] = [
  149. comment.strip().strip("#") for comment in comments
  150. ]
  151. comments.clear()
  152. clz = clzs[0]
  153. new_docstring = []
  154. for line in (clz.create.__doc__ or "").splitlines():
  155. if "**" in line:
  156. indent = line.split("**")[0]
  157. for nline in [
  158. f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()
  159. ]:
  160. new_docstring.append(nline)
  161. new_docstring.append(line)
  162. return "\n".join(new_docstring)
  163. def _extract_func_kwargs_as_ast_nodes(
  164. func: Callable,
  165. type_hint_globals: dict[str, Any],
  166. ) -> list[tuple[ast.arg, ast.Constant | None]]:
  167. """Get the kwargs already defined on the function.
  168. Args:
  169. func: The function to extract kwargs from.
  170. type_hint_globals: The globals to use to resolving a type hint str.
  171. Returns:
  172. The list of kwargs as ast arg nodes.
  173. """
  174. spec = getfullargspec(func)
  175. kwargs = []
  176. for kwarg in spec.kwonlyargs:
  177. arg = ast.arg(arg=kwarg)
  178. if kwarg in spec.annotations:
  179. arg.annotation = ast.Name(
  180. id=_get_type_hint(spec.annotations[kwarg], type_hint_globals)
  181. )
  182. default = None
  183. if spec.kwonlydefaults is not None and kwarg in spec.kwonlydefaults:
  184. default = ast.Constant(value=spec.kwonlydefaults[kwarg])
  185. kwargs.append((arg, default))
  186. return kwargs
  187. def _extract_class_props_as_ast_nodes(
  188. func: Callable,
  189. clzs: list[Type],
  190. type_hint_globals: dict[str, Any],
  191. extract_real_default: bool = False,
  192. ) -> list[tuple[ast.arg, ast.Constant | None]]:
  193. """Get the props defined on the class and all parents.
  194. Args:
  195. func: The function that kwargs will be added to.
  196. clzs: The classes to extract props from.
  197. type_hint_globals: The globals to use to resolving a type hint str.
  198. extract_real_default: Whether to extract the real default value from the
  199. pydantic field definition.
  200. Returns:
  201. The list of props as ast arg nodes
  202. """
  203. spec = getfullargspec(func)
  204. all_props = []
  205. kwargs = []
  206. for target_class in clzs:
  207. # Import from the target class to ensure type hints are resolvable.
  208. exec(f"from {target_class.__module__} import *", type_hint_globals)
  209. for name, value in target_class.__annotations__.items():
  210. if name in spec.kwonlyargs or name in EXCLUDED_PROPS or name in all_props:
  211. continue
  212. all_props.append(name)
  213. default = None
  214. if extract_real_default:
  215. # TODO: This is not currently working since the default is not type compatible
  216. # with the annotation in some cases.
  217. with contextlib.suppress(AttributeError, KeyError):
  218. # Try to get default from pydantic field definition.
  219. default = target_class.__fields__[name].default
  220. if isinstance(default, Var):
  221. default = default._decode() # type: ignore
  222. kwargs.append(
  223. (
  224. ast.arg(
  225. arg=name,
  226. annotation=ast.Name(
  227. id=_get_type_hint(value, type_hint_globals)
  228. ),
  229. ),
  230. ast.Constant(value=default),
  231. )
  232. )
  233. return kwargs
  234. def _generate_component_create_functiondef(
  235. node: ast.FunctionDef | None,
  236. clz: type[Component],
  237. type_hint_globals: dict[str, Any],
  238. ) -> ast.FunctionDef:
  239. """Generate the create function definition for a Component.
  240. Args:
  241. node: The existing create functiondef node from the ast
  242. clz: The Component class to generate the create functiondef for.
  243. type_hint_globals: The globals to use to resolving a type hint str.
  244. Returns:
  245. The create functiondef node for the ast.
  246. """
  247. # kwargs defined on the actual create function
  248. kwargs = _extract_func_kwargs_as_ast_nodes(clz.create, type_hint_globals)
  249. # kwargs associated with props defined in the class and its parents
  250. all_classes = [c for c in clz.__mro__ if issubclass(c, Component)]
  251. prop_kwargs = _extract_class_props_as_ast_nodes(
  252. clz.create, all_classes, type_hint_globals
  253. )
  254. all_props = [arg[0].arg for arg in prop_kwargs]
  255. kwargs.extend(prop_kwargs)
  256. # event handler kwargs
  257. kwargs.extend(
  258. (
  259. ast.arg(
  260. arg=trigger,
  261. annotation=ast.Name(
  262. id="Optional[Union[EventHandler, EventSpec, list, function, BaseVar]]"
  263. ),
  264. ),
  265. ast.Constant(value=None),
  266. )
  267. for trigger in sorted(clz().get_event_triggers().keys())
  268. )
  269. logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
  270. create_args = ast.arguments(
  271. args=[ast.arg(arg="cls")],
  272. posonlyargs=[],
  273. vararg=ast.arg(arg="children"),
  274. kwonlyargs=[arg[0] for arg in kwargs],
  275. kw_defaults=[arg[1] for arg in kwargs],
  276. kwarg=ast.arg(arg="props"),
  277. defaults=[],
  278. )
  279. definition = ast.FunctionDef(
  280. name="create",
  281. args=create_args,
  282. body=[
  283. ast.Expr(
  284. value=ast.Constant(value=_generate_docstrings(all_classes, all_props))
  285. ),
  286. ast.Expr(
  287. value=ast.Ellipsis(),
  288. ),
  289. ],
  290. decorator_list=[
  291. ast.Name(id="overload"),
  292. *(
  293. node.decorator_list
  294. if node is not None
  295. else [ast.Name(id="classmethod")]
  296. ),
  297. ],
  298. lineno=node.lineno if node is not None else None,
  299. returns=ast.Constant(value=clz.__name__),
  300. )
  301. return definition
  302. class StubGenerator(ast.NodeTransformer):
  303. """A node transformer that will generate the stubs for a given module."""
  304. def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]):
  305. """Initialize the stub generator.
  306. Args:
  307. module: The actual module object module to generate stubs for.
  308. classes: The actual Component class objects to generate stubs for.
  309. """
  310. super().__init__()
  311. # Dict mapping class name to actual class object.
  312. self.classes = classes
  313. # Track the last class node that was visited.
  314. self.current_class = None
  315. # These imports will be included in the AST of stub files.
  316. self.typing_imports = DEFAULT_TYPING_IMPORTS
  317. # Whether those typing imports have been inserted yet.
  318. self.inserted_imports = False
  319. # Collected import statements from the module.
  320. self.import_statements: list[str] = []
  321. # This dict is used when evaluating type hints.
  322. self.type_hint_globals = module.__dict__.copy()
  323. @staticmethod
  324. def _remove_docstring(
  325. node: ast.Module | ast.ClassDef | ast.FunctionDef,
  326. ) -> ast.Module | ast.ClassDef | ast.FunctionDef:
  327. """Removes any docstring in place.
  328. Args:
  329. node: The node to remove the docstring from.
  330. Returns:
  331. The modified node.
  332. """
  333. if (
  334. node.body
  335. and isinstance(node.body[0], ast.Expr)
  336. and isinstance(node.body[0].value, ast.Constant)
  337. ):
  338. node.body.pop(0)
  339. return node
  340. def visit_Module(self, node: ast.Module) -> ast.Module:
  341. """Visit a Module node and remove docstring from body.
  342. Args:
  343. node: The Module node to visit.
  344. Returns:
  345. The modified Module node.
  346. """
  347. self.generic_visit(node)
  348. return self._remove_docstring(node) # type: ignore
  349. def visit_Import(
  350. self, node: ast.Import | ast.ImportFrom
  351. ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
  352. """Collect import statements from the module.
  353. If this is the first import statement, insert the typing imports before it.
  354. Args:
  355. node: The import node to visit.
  356. Returns:
  357. The modified import node(s).
  358. """
  359. self.import_statements.append(ast.unparse(node))
  360. if not self.inserted_imports:
  361. self.inserted_imports = True
  362. return _generate_imports(self.typing_imports) + [node]
  363. return node
  364. def visit_ImportFrom(
  365. self, node: ast.ImportFrom
  366. ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom] | None:
  367. """Visit an ImportFrom node.
  368. Remove any `from __future__ import *` statements, and hand off to visit_Import.
  369. Args:
  370. node: The ImportFrom node to visit.
  371. Returns:
  372. The modified ImportFrom node.
  373. """
  374. if node.module == "__future__":
  375. return None # ignore __future__ imports
  376. return self.visit_Import(node)
  377. def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
  378. """Visit a ClassDef node.
  379. Remove all assignments in the class body, and add a create functiondef
  380. if one does not exist.
  381. Args:
  382. node: The ClassDef node to visit.
  383. Returns:
  384. The modified ClassDef node.
  385. """
  386. exec("\n".join(self.import_statements), self.type_hint_globals)
  387. self.current_class = node.name
  388. self._remove_docstring(node)
  389. self.generic_visit(node) # Visit child nodes.
  390. if (
  391. not any(
  392. isinstance(child, ast.FunctionDef) and child.name == "create"
  393. for child in node.body
  394. )
  395. and self.current_class in self.classes
  396. ):
  397. # Add a new .create FunctionDef since one does not exist.
  398. node.body.append(
  399. _generate_component_create_functiondef(
  400. node=None,
  401. clz=self.classes[self.current_class],
  402. type_hint_globals=self.type_hint_globals,
  403. )
  404. )
  405. if not node.body:
  406. # We should never return an empty body.
  407. node.body.append(ast.Expr(value=ast.Ellipsis()))
  408. self.current_class = None
  409. return node
  410. def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
  411. """Visit a FunctionDef node.
  412. Special handling for `.create` functions to add type hints for all props
  413. defined on the component class.
  414. Remove all private functions and blank out the function body of the
  415. remaining public functions.
  416. Args:
  417. node: The FunctionDef node to visit.
  418. Returns:
  419. The modified FunctionDef node (or None).
  420. """
  421. if node.name == "create" and self.current_class in self.classes:
  422. node = _generate_component_create_functiondef(
  423. node, self.classes[self.current_class], self.type_hint_globals
  424. )
  425. else:
  426. if node.name.startswith("_"):
  427. return None # remove private methods
  428. # Blank out the function body for public functions.
  429. node.body = [ast.Expr(value=ast.Ellipsis())]
  430. return node
  431. def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
  432. """Remove non-annotated assignment statements.
  433. Args:
  434. node: The Assign node to visit.
  435. Returns:
  436. The modified Assign node (or None).
  437. """
  438. # Special case for assignments to `typing.Any` as fallback.
  439. if (
  440. node.value is not None
  441. and isinstance(node.value, ast.Name)
  442. and node.value.id == "Any"
  443. ):
  444. return node
  445. if self.current_class in self.classes:
  446. # Remove annotated assignments in Component classes (props)
  447. return None
  448. return node
  449. def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
  450. """Visit an AnnAssign node (Annotated assignment).
  451. Remove private target and remove the assignment value in the stub.
  452. Args:
  453. node: The AnnAssign node to visit.
  454. Returns:
  455. The modified AnnAssign node (or None).
  456. """
  457. if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
  458. return None
  459. if self.current_class in self.classes:
  460. # Remove annotated assignments in Component classes (props)
  461. return None
  462. # Blank out assignments in type stubs.
  463. node.value = None
  464. return node
  465. class PyiGenerator:
  466. """A .pyi file generator that will scan all defined Component in Reflex and
  467. generate the approriate stub.
  468. """
  469. modules: list = []
  470. root: str = ""
  471. current_module: Any = {}
  472. default_typing_imports: set = DEFAULT_TYPING_IMPORTS
  473. def _write_pyi_file(self, module_path: Path, source: str):
  474. pyi_content = [
  475. f'"""Stub file for {module_path}"""',
  476. "# ------------------- DO NOT EDIT ----------------------",
  477. "# This file was generated by `scripts/pyi_generator.py`!",
  478. "# ------------------------------------------------------",
  479. "",
  480. ]
  481. for formatted_line in black.format_file_contents(
  482. src_contents=source,
  483. fast=True,
  484. mode=black.mode.Mode(is_pyi=True),
  485. ).splitlines():
  486. # Bit of a hack here, since the AST cannot represent comments.
  487. if "def create(" in formatted_line:
  488. pyi_content.append(formatted_line + " # type: ignore")
  489. elif "Figure" in formatted_line:
  490. pyi_content.append(formatted_line + " # type: ignore")
  491. else:
  492. pyi_content.append(formatted_line)
  493. pyi_content.append("") # add empty line at the end for formatting
  494. pyi_path = module_path.with_suffix(".pyi")
  495. pyi_path.write_text("\n".join(pyi_content))
  496. logger.info(f"Wrote {pyi_path}")
  497. def _scan_file(self, module_path: Path):
  498. module_import = str(module_path.with_suffix("")).replace("/", ".")
  499. module = importlib.import_module(module_import)
  500. class_names = {
  501. name: obj
  502. for name, obj in vars(module).items()
  503. if inspect.isclass(obj)
  504. and issubclass(obj, Component)
  505. and obj != Component
  506. and inspect.getmodule(obj) == module
  507. }
  508. if not class_names:
  509. return
  510. new_tree = StubGenerator(module, class_names).visit(
  511. ast.parse(inspect.getsource(module))
  512. )
  513. self._write_pyi_file(module_path, ast.unparse(new_tree))
  514. def _scan_folder(self, folder):
  515. for root, _, files in os.walk(folder):
  516. for file in files:
  517. if file in EXCLUDED_FILES:
  518. continue
  519. if file.endswith(".py"):
  520. self._scan_file(Path(root) / file)
  521. def scan_all(self, targets):
  522. """Scan all targets for class inheriting Component and generate the .pyi files.
  523. Args:
  524. targets: the list of file/folders to scan.
  525. """
  526. for target in targets:
  527. if target.endswith(".py"):
  528. self._scan_file(Path(target))
  529. else:
  530. self._scan_folder(target)
  531. def generate_init():
  532. """Generate a pyi file for the main __init__.py."""
  533. from reflex import _MAPPING # type: ignore
  534. imports = [
  535. f"from {path if mod != path.rsplit('.')[-1] or mod == 'page' else '.'.join(path.rsplit('.')[:-1])} import {mod} as {mod}"
  536. for mod, path in _MAPPING.items()
  537. ]
  538. imports.append("")
  539. with open("reflex/__init__.pyi", "w") as pyi_file:
  540. pyi_file.writelines("\n".join(imports))
  541. if __name__ == "__main__":
  542. logging.basicConfig(level=logging.DEBUG)
  543. logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.INFO)
  544. targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
  545. logger.info(f"Running .pyi generator for {targets}")
  546. gen = PyiGenerator()
  547. gen.scan_all(targets)
  548. generate_init()