pyi_generator.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348
  1. """The pyi generator module."""
  2. from __future__ import annotations
  3. import ast
  4. import contextlib
  5. import importlib
  6. import inspect
  7. import json
  8. import logging
  9. import re
  10. import subprocess
  11. import sys
  12. import typing
  13. from collections.abc import Callable, Iterable, Sequence
  14. from fileinput import FileInput
  15. from hashlib import md5
  16. from inspect import getfullargspec
  17. from itertools import chain
  18. from multiprocessing import Pool, cpu_count
  19. from pathlib import Path
  20. from types import ModuleType, SimpleNamespace, UnionType
  21. from typing import Any, get_args, get_origin
  22. from reflex.components.component import Component
  23. from reflex.utils import types as rx_types
  24. from reflex.vars.base import Var
  25. logger = logging.getLogger("pyi_generator")
  26. PWD = Path.cwd()
  27. PYI_HASHES = "pyi_hashes.json"
  28. EXCLUDED_FILES = [
  29. "app.py",
  30. "component.py",
  31. "bare.py",
  32. "foreach.py",
  33. "cond.py",
  34. "match.py",
  35. "multiselect.py",
  36. "literals.py",
  37. ]
  38. # These props exist on the base component, but should not be exposed in create methods.
  39. EXCLUDED_PROPS = [
  40. "alias",
  41. "children",
  42. "event_triggers",
  43. "library",
  44. "lib_dependencies",
  45. "tag",
  46. "is_default",
  47. "special_props",
  48. "_is_tag_in_global_scope",
  49. "_invalid_children",
  50. "_memoization_mode",
  51. "_rename_props",
  52. "_valid_children",
  53. "_valid_parents",
  54. "State",
  55. ]
  56. OVERWRITE_TYPES = {
  57. "style": "Sequence[Mapping[str, Any]] | Mapping[str, Any] | Var[Mapping[str, Any]] | Breakpoints | None",
  58. }
  59. DEFAULT_TYPING_IMPORTS = {
  60. "overload",
  61. "Any",
  62. "Callable",
  63. "Dict",
  64. # "List",
  65. "Sequence",
  66. "Mapping",
  67. "Literal",
  68. "Optional",
  69. "Union",
  70. "Annotated",
  71. }
  72. # TODO: fix import ordering and unused imports with ruff later
  73. DEFAULT_IMPORTS = {
  74. "typing": sorted(DEFAULT_TYPING_IMPORTS),
  75. "reflex.components.core.breakpoints": ["Breakpoints"],
  76. "reflex.event": [
  77. "EventChain",
  78. "EventHandler",
  79. "EventSpec",
  80. "EventType",
  81. "KeyInputInfo",
  82. ],
  83. "reflex.style": ["Style"],
  84. "reflex.vars.base": ["Var"],
  85. }
  86. def _walk_files(path: str | Path):
  87. """Walk all files in a path.
  88. This can be replaced with Path.walk() in python3.12.
  89. Args:
  90. path: The path to walk.
  91. Yields:
  92. The next file in the path.
  93. """
  94. for p in Path(path).iterdir():
  95. if p.is_dir():
  96. yield from _walk_files(p)
  97. continue
  98. yield p.resolve()
  99. def _relative_to_pwd(path: Path) -> Path:
  100. """Get the relative path of a path to the current working directory.
  101. Args:
  102. path: The path to get the relative path for.
  103. Returns:
  104. The relative path.
  105. """
  106. if path.is_absolute():
  107. return path.relative_to(PWD)
  108. return path
def _get_type_hint(
    value: Any, type_hint_globals: dict, is_optional: bool = True
) -> str:
    """Resolve the type hint for value.

    Handled forms, in order: ``NoneType``; unions (members rendered
    recursively, sorted, with ``None`` moved to a trailing ``| None``);
    subscripted generics (``Literal`` args rendered with ``repr``, ``Var[...]``
    additionally unioned with its inner types); annotation strings (``eval``'d
    in ``type_hint_globals``); lists (rendered as ``[a, b]``); and anything
    else by its ``__name__``.

    Args:
        value: The type annotation as a str or actual types/aliases.
        type_hint_globals: The globals to use to resolving a type hint str.
        is_optional: Whether the type hint should be wrapped in Optional.
            Note that the union and list branches return early and do not
            consult this flag for the outer result.

    Returns:
        The resolved type hint as a str.

    Raises:
        TypeError: If the value name is not visible in the type hint globals.
    """
    res = ""
    args = get_args(value)

    if value is type(None):
        return "None"

    if rx_types.is_union(value):
        if type(None) in value.__args__:
            # Render the non-None members, then append a single `| None`.
            res_args = [
                _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
                for arg in value.__args__
                if arg is not type(None)
            ]
            res_args.sort()
            if len(res_args) == 1:
                return f"{res_args[0]} | None"
            res = f"{' | '.join(res_args)}"
            return f"{res} | None"

        res_args = [
            _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
            for arg in value.__args__
        ]
        # Sorted so the generated stub text does not depend on member order.
        res_args.sort()
        return f"{' | '.join(res_args)}"

    if args:
        inner_container_type_args = (
            # Literal arguments are values, so render them with repr().
            sorted(repr(arg) for arg in args)
            if rx_types.is_literal(value)
            else [
                _get_type_hint(arg, type_hint_globals, is_optional=False)
                for arg in args
                if arg is not type(None)
            ]
        )

        # Non-builtin generics must be importable by name in the stub.
        if (
            value.__module__ not in ["builtins", "__builtins__"]
            and value.__name__ not in type_hint_globals
        ):
            msg = (
                f"{value.__module__ + '.' + value.__name__} is not a default import, "
                "add it to DEFAULT_IMPORTS in pyi_generator.py"
            )
            raise TypeError(msg)

        res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"

        if value.__name__ == "Var":
            # Flatten union arguments: Var[A | B] contributes both A and B.
            args = list(
                chain.from_iterable(
                    [get_args(arg) if rx_types.is_union(arg) else [arg] for arg in args]
                )
            )

            # For Var types, Union with the inner args so they can be passed directly.
            types = [res] + [
                _get_type_hint(arg, type_hint_globals, is_optional=False)
                for arg in args
                if arg is not type(None)
            ]
            if len(types) > 1:
                res = " | ".join(sorted(types))

    elif isinstance(value, str):
        # NOTE(review): eval() runs the annotation string; it comes from
        # component source code and must not be fed untrusted input.
        ev = eval(value, type_hint_globals)
        if rx_types.is_optional(ev):
            return _get_type_hint(ev, type_hint_globals, is_optional=False)

        if rx_types.is_union(ev):
            # NOTE(review): unlike the union branch above, these members are not
            # sorted, so their order follows the annotation string.
            res = [
                _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
                for arg in ev.__args__
            ]
            return f"{' | '.join(res)}"
        # Only Var strings are re-rendered; other strings are kept verbatim.
        res = (
            _get_type_hint(ev, type_hint_globals, is_optional=False)
            if ev.__name__ == "Var"
            else value
        )
    elif isinstance(value, list):
        # A list of types, e.g. the parameter list of Callable[[A, B], R].
        res = [
            _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
            for arg in value
        ]
        return f"[{', '.join(res)}]"
    else:
        res = value.__name__
    # Add a trailing `| None` unless the hint already expresses optionality.
    if is_optional and not res.startswith("Optional") and not res.endswith("| None"):
        res = f"{res} | None"
    return res
  204. def _generate_imports(
  205. typing_imports: Iterable[str],
  206. ) -> list[ast.ImportFrom | ast.Import]:
  207. """Generate the import statements for the stub file.
  208. Args:
  209. typing_imports: The typing imports to include.
  210. Returns:
  211. The list of import statements.
  212. """
  213. return [
  214. *[
  215. ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) # pyright: ignore [reportCallIssue]
  216. for name, values in DEFAULT_IMPORTS.items()
  217. ],
  218. ast.Import([ast.alias("reflex")]),
  219. ]
  220. def _generate_docstrings(clzs: list[type[Component]], props: list[str]) -> str:
  221. """Generate the docstrings for the create method.
  222. Args:
  223. clzs: The classes to generate docstrings for.
  224. props: The props to generate docstrings for.
  225. Returns:
  226. The docstring for the create method.
  227. """
  228. props_comments = {}
  229. comments = []
  230. for clz in clzs:
  231. for line in inspect.getsource(clz).splitlines():
  232. reached_functions = re.search("def ", line)
  233. if reached_functions:
  234. # We've reached the functions, so stop.
  235. break
  236. if line == "":
  237. # We hit a blank line, so clear comments to avoid commented out prop appearing in next prop docs.
  238. comments.clear()
  239. continue
  240. # Get comments for prop
  241. if line.strip().startswith("#"):
  242. # Remove noqa from the comments.
  243. line = line.partition(" # noqa")[0]
  244. comments.append(line)
  245. continue
  246. # Check if this line has a prop.
  247. match = re.search("\\w+:", line)
  248. if match is None:
  249. # This line doesn't have a var, so continue.
  250. continue
  251. # Get the prop.
  252. prop = match.group(0).strip(":")
  253. if prop in props:
  254. if not comments: # do not include undocumented props
  255. continue
  256. props_comments[prop] = [
  257. comment.strip().strip("#") for comment in comments
  258. ]
  259. comments.clear()
  260. clz = clzs[0]
  261. new_docstring = []
  262. for line in (clz.create.__doc__ or "").splitlines():
  263. if "**" in line:
  264. indent = line.split("**")[0]
  265. new_docstring.extend(
  266. [f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()]
  267. )
  268. new_docstring.append(line)
  269. return "\n".join(new_docstring)
  270. def _extract_func_kwargs_as_ast_nodes(
  271. func: Callable,
  272. type_hint_globals: dict[str, Any],
  273. ) -> list[tuple[ast.arg, ast.Constant | None]]:
  274. """Get the kwargs already defined on the function.
  275. Args:
  276. func: The function to extract kwargs from.
  277. type_hint_globals: The globals to use to resolving a type hint str.
  278. Returns:
  279. The list of kwargs as ast arg nodes.
  280. """
  281. spec = getfullargspec(func)
  282. kwargs = []
  283. for kwarg in spec.kwonlyargs:
  284. arg = ast.arg(arg=kwarg)
  285. if kwarg in spec.annotations:
  286. arg.annotation = ast.Name(
  287. id=_get_type_hint(spec.annotations[kwarg], type_hint_globals)
  288. )
  289. default = None
  290. if spec.kwonlydefaults is not None and kwarg in spec.kwonlydefaults:
  291. default = ast.Constant(value=spec.kwonlydefaults[kwarg])
  292. kwargs.append((arg, default))
  293. return kwargs
  294. def _extract_class_props_as_ast_nodes(
  295. func: Callable,
  296. clzs: list[type],
  297. type_hint_globals: dict[str, Any],
  298. extract_real_default: bool = False,
  299. ) -> list[tuple[ast.arg, ast.Constant | None]]:
  300. """Get the props defined on the class and all parents.
  301. Args:
  302. func: The function that kwargs will be added to.
  303. clzs: The classes to extract props from.
  304. type_hint_globals: The globals to use to resolving a type hint str.
  305. extract_real_default: Whether to extract the real default value from the
  306. pydantic field definition.
  307. Returns:
  308. The list of props as ast arg nodes
  309. """
  310. spec = getfullargspec(func)
  311. all_props = []
  312. kwargs = []
  313. for target_class in clzs:
  314. event_triggers = target_class._create([]).get_event_triggers()
  315. # Import from the target class to ensure type hints are resolvable.
  316. exec(f"from {target_class.__module__} import *", type_hint_globals)
  317. for name, value in target_class.__annotations__.items():
  318. if (
  319. name in spec.kwonlyargs
  320. or name in EXCLUDED_PROPS
  321. or name in all_props
  322. or name in event_triggers
  323. or (isinstance(value, str) and "ClassVar" in value)
  324. ):
  325. continue
  326. all_props.append(name)
  327. default = None
  328. if extract_real_default:
  329. # TODO: This is not currently working since the default is not type compatible
  330. # with the annotation in some cases.
  331. with contextlib.suppress(AttributeError, KeyError):
  332. # Try to get default from pydantic field definition.
  333. default = target_class.__fields__[name].default
  334. if isinstance(default, Var):
  335. default = default._decode()
  336. modules = {cls.__module__ for cls in target_class.__mro__}
  337. available_vars = {}
  338. for module in modules:
  339. available_vars.update(sys.modules[module].__dict__)
  340. kwargs.append(
  341. (
  342. ast.arg(
  343. arg=name,
  344. annotation=ast.Name(
  345. id=OVERWRITE_TYPES.get(
  346. name,
  347. _get_type_hint(
  348. value,
  349. type_hint_globals | available_vars,
  350. ),
  351. )
  352. ),
  353. ),
  354. ast.Constant(value=default),
  355. )
  356. )
  357. return kwargs
  358. def type_to_ast(typ: Any, cls: type) -> ast.expr:
  359. """Converts any type annotation into its AST representation.
  360. Handles nested generic types, unions, etc.
  361. Args:
  362. typ: The type annotation to convert.
  363. cls: The class where the type annotation is used.
  364. Returns:
  365. The AST representation of the type annotation.
  366. """
  367. if typ is type(None):
  368. return ast.Name(id="None")
  369. origin = get_origin(typ)
  370. if origin is UnionType:
  371. origin = typing.Union
  372. # Handle plain types (int, str, custom classes, etc.)
  373. if origin is None:
  374. if hasattr(typ, "__name__"):
  375. if typ.__module__.startswith("reflex."):
  376. typ_parts = typ.__module__.split(".")
  377. cls_parts = cls.__module__.split(".")
  378. zipped = list(zip(typ_parts, cls_parts, strict=False))
  379. if all(a == b for a, b in zipped) and len(typ_parts) == len(cls_parts):
  380. return ast.Name(id=typ.__name__)
  381. return ast.Name(id=typ.__module__ + "." + typ.__name__)
  382. return ast.Name(id=typ.__name__)
  383. if hasattr(typ, "_name"):
  384. return ast.Name(id=typ._name)
  385. return ast.Name(id=str(typ))
  386. # Get the base type name (List, Dict, Optional, etc.)
  387. base_name = getattr(origin, "_name", origin.__name__)
  388. # Get type arguments
  389. args = get_args(typ)
  390. # Handle empty type arguments
  391. if not args:
  392. return ast.Name(id=base_name)
  393. # Convert all type arguments recursively
  394. arg_nodes = [type_to_ast(arg, cls) for arg in args]
  395. # Special case for single-argument types (like list[T] or Optional[T])
  396. if len(arg_nodes) == 1:
  397. slice_value = arg_nodes[0]
  398. else:
  399. slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
  400. return ast.Subscript(
  401. value=ast.Name(id=base_name),
  402. slice=slice_value,
  403. ctx=ast.Load(),
  404. )
  405. def _get_parent_imports(func: Callable):
  406. _imports = {"reflex.vars": ["Var"]}
  407. for type_hint in inspect.get_annotations(func).values():
  408. try:
  409. match = re.match(r"\w+\[([\w\d]+)\]", type_hint)
  410. except TypeError:
  411. continue
  412. if match:
  413. type_hint = match.group(1)
  414. if type_hint in importlib.import_module(func.__module__).__dir__():
  415. _imports.setdefault(func.__module__, []).append(type_hint)
  416. return _imports
def _generate_component_create_functiondef(
    clz: type[Component],
    type_hint_globals: dict[str, Any],
    lineno: int,
    decorator_list: Sequence[ast.expr] = (ast.Name(id="classmethod"),),
) -> ast.FunctionDef:
    """Generate the create function definition for a Component.

    The generated ``create`` takes ``cls``, ``*children``, then as keyword-only
    args: the kwonly args already declared on ``clz.create``, every prop
    annotated on ``clz`` and its Component parents, and one ``Optional[...]``
    arg per event trigger (sorted by name, default ``None``). It ends with
    ``**props``. It is decorated with ``overload`` followed by
    ``decorator_list``, and its return annotation is the string ``clz.__name__``.

    Args:
        clz: The Component class to generate the create functiondef for.
        type_hint_globals: The globals to use to resolving a type hint str.
            Updated in place with the default typing names and imports.
        lineno: The line number to use for the ast nodes.
        decorator_list: The list of decorators to apply to the create functiondef.
            NOTE(review): the default tuple holds a single ast.Name node that
            is shared by every call relying on the default.

    Returns:
        The create functiondef node for the ast.

    Raises:
        TypeError: If clz is not a subclass of Component.
    """
    if not issubclass(clz, Component):
        msg = f"clz must be a subclass of Component, not {clz!r}"
        raise TypeError(msg)

    # add the imports needed by get_type_hint later
    type_hint_globals.update(
        {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
    )

    # `create` is inherited from a class in another module: import the names
    # its annotations refer to so they can be resolved here.
    if clz.__module__ != clz.create.__module__:
        _imports = _get_parent_imports(clz.create)
        for name, values in _imports.items():
            exec(f"from {name} import {','.join(values)}", type_hint_globals)

    kwargs = _extract_func_kwargs_as_ast_nodes(clz.create, type_hint_globals)

    # kwargs associated with props defined in the class and its parents
    all_classes = [c for c in clz.__mro__ if issubclass(c, Component)]
    prop_kwargs = _extract_class_props_as_ast_nodes(
        clz.create, all_classes, type_hint_globals
    )
    all_props = [arg[0].arg for arg in prop_kwargs]
    kwargs.extend(prop_kwargs)

    def figure_out_return_type(annotation: Any):
        # Map an event spec's return annotation to an EventType[...] hint.
        # Missing annotation -> EventType[Any].
        if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
            return ast.Name(id="EventType[Any]")

        # Real tuple type: unwrap Var[X] -> X for each element.
        if not isinstance(annotation, str) and get_origin(annotation) is tuple:
            arguments = get_args(annotation)

            arguments_without_var = [
                get_args(argument)[0] if get_origin(argument) == Var else argument
                for argument in arguments
            ]

            # Convert each argument type to its AST representation
            type_args = [type_to_ast(arg, cls=clz) for arg in arguments_without_var]

            # Get all prefixes of the type arguments: EventType[()] | EventType[A]
            # | EventType[A, B] ... (presumably so handlers accepting fewer
            # args also type-check — confirm).
            all_count_args_type = [
                ast.Name(
                    f"EventType[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}]"
                )
                if i > 0
                else ast.Name("EventType[()]")
                for i in range(len(type_args) + 1)
            ]

            # Create EventType using the joined string
            return ast.Name(id=f"{' | '.join(map(ast.unparse, all_count_args_type))}")

        # String annotation "tuple[...]": split on top-level commas by hand.
        # NOTE(review): the check is case-insensitive, but only the "tuple[" and
        # "Tuple[" prefixes are actually removed.
        if isinstance(annotation, str) and annotation.lower().startswith("tuple["):
            inside_of_tuple = (
                annotation.removeprefix("tuple[")
                .removeprefix("Tuple[")
                .removesuffix("]")
            )

            if inside_of_tuple == "()":
                return ast.Name(id="EventType[()]")

            arguments = [""]

            # Depth of nested brackets; commas inside brackets are not separators.
            bracket_count = 0

            for char in inside_of_tuple:
                if char == "[":
                    bracket_count += 1
                elif char == "]":
                    bracket_count -= 1

                if char == "," and bracket_count == 0:
                    arguments.append("")
                else:
                    arguments[-1] += char

            arguments = [argument.strip() for argument in arguments]

            arguments_without_var = [
                argument.removeprefix("Var[").removesuffix("]")
                if argument.startswith("Var[")
                else argument
                for argument in arguments
            ]

            # Same prefix union as the non-string branch above.
            all_count_args_type = [
                ast.Name(f"EventType[{', '.join(arguments_without_var[:i])}]")
                if i > 0
                else ast.Name("EventType[()]")
                for i in range(len(arguments) + 1)
            ]

            return ast.Name(id=f"{' | '.join(map(ast.unparse, all_count_args_type))}")
        return ast.Name(id="EventType[Any]")

    event_triggers = clz._create([]).get_event_triggers()

    # event handler kwargs: a single spec maps to one EventType union; a
    # sequence of specs is wrapped in Union[...] of each spec's hint.
    kwargs.extend(
        (
            ast.arg(
                arg=trigger,
                annotation=ast.Subscript(
                    ast.Name("Optional"),
                    ast.Name(
                        id=ast.unparse(
                            figure_out_return_type(
                                inspect.signature(event_specs).return_annotation
                            )
                            if not isinstance(
                                event_specs := event_triggers[trigger], Sequence
                            )
                            else ast.Subscript(
                                ast.Name("Union"),
                                ast.Tuple(
                                    [
                                        figure_out_return_type(
                                            inspect.signature(
                                                event_spec
                                            ).return_annotation
                                        )
                                        for event_spec in event_specs
                                    ]
                                ),
                            )
                        )
                    ),
                ),
            ),
            ast.Constant(value=None),
        )
        for trigger in sorted(event_triggers)
    )

    logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
    create_args = ast.arguments(
        args=[ast.arg(arg="cls")],
        posonlyargs=[],
        vararg=ast.arg(arg="children"),
        kwonlyargs=[arg[0] for arg in kwargs],
        kw_defaults=[arg[1] for arg in kwargs],
        kwarg=ast.arg(arg="props"),
        defaults=[],
    )

    return ast.FunctionDef(  # pyright: ignore [reportCallIssue]
        name="create",
        args=create_args,
        body=[
            # Docstring built from prop and event-trigger comments.
            ast.Expr(
                value=ast.Constant(
                    value=_generate_docstrings(
                        all_classes, [*all_props, *event_triggers]
                    )
                ),
            ),
            ast.Expr(
                value=ast.Constant(value=Ellipsis),
            ),
        ],
        decorator_list=[
            ast.Name(id="overload"),
            *decorator_list,
        ],
        lineno=lineno,
        returns=ast.Constant(value=clz.__name__),
    )
  578. def _generate_staticmethod_call_functiondef(
  579. node: ast.ClassDef,
  580. clz: type[Component] | type[SimpleNamespace],
  581. type_hint_globals: dict[str, Any],
  582. ) -> ast.FunctionDef | None:
  583. fullspec = getfullargspec(clz.__call__)
  584. call_args = ast.arguments(
  585. args=[
  586. ast.arg(
  587. name,
  588. annotation=ast.Name(
  589. id=_get_type_hint(
  590. anno := fullspec.annotations[name],
  591. type_hint_globals,
  592. is_optional=rx_types.is_optional(anno),
  593. )
  594. ),
  595. )
  596. for name in fullspec.args
  597. ],
  598. posonlyargs=[],
  599. kwonlyargs=[],
  600. kw_defaults=[],
  601. kwarg=ast.arg(arg="props"),
  602. defaults=(
  603. [ast.Constant(value=default) for default in fullspec.defaults]
  604. if fullspec.defaults
  605. else []
  606. ),
  607. )
  608. return ast.FunctionDef( # pyright: ignore [reportCallIssue]
  609. name="__call__",
  610. args=call_args,
  611. body=[
  612. ast.Expr(value=ast.Constant(value=clz.__call__.__doc__)),
  613. ast.Expr(
  614. value=ast.Constant(...),
  615. ),
  616. ],
  617. decorator_list=[ast.Name(id="staticmethod")],
  618. lineno=node.lineno,
  619. returns=ast.Constant(
  620. value=_get_type_hint(
  621. typing.get_type_hints(clz.__call__).get("return", None),
  622. type_hint_globals,
  623. is_optional=False,
  624. )
  625. ),
  626. )
  627. def _generate_namespace_call_functiondef(
  628. node: ast.ClassDef,
  629. clz_name: str,
  630. classes: dict[str, type[Component] | type[SimpleNamespace]],
  631. type_hint_globals: dict[str, Any],
  632. ) -> ast.FunctionDef | None:
  633. """Generate the __call__ function definition for a SimpleNamespace.
  634. Args:
  635. node: The existing __call__ classdef parent node from the ast
  636. clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for.
  637. classes: Map name to actual class definition.
  638. type_hint_globals: The globals to use to resolving a type hint str.
  639. Returns:
  640. The create functiondef node for the ast.
  641. """
  642. # add the imports needed by get_type_hint later
  643. type_hint_globals.update(
  644. {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
  645. )
  646. clz = classes[clz_name]
  647. if not hasattr(clz.__call__, "__self__"):
  648. return _generate_staticmethod_call_functiondef(node, clz, type_hint_globals)
  649. # Determine which class is wrapped by the namespace __call__ method
  650. component_clz = clz.__call__.__self__
  651. if clz.__call__.__func__.__name__ != "create": # pyright: ignore [reportFunctionMemberAccess]
  652. return None
  653. if not issubclass(component_clz, Component):
  654. return None
  655. definition = _generate_component_create_functiondef(
  656. clz=component_clz,
  657. type_hint_globals=type_hint_globals,
  658. lineno=node.lineno,
  659. decorator_list=[],
  660. )
  661. definition.name = "__call__"
  662. # Turn the definition into a staticmethod
  663. del definition.args.args[0] # remove `cls` arg
  664. definition.decorator_list = [ast.Name(id="staticmethod")]
  665. return definition
  666. class StubGenerator(ast.NodeTransformer):
  667. """A node transformer that will generate the stubs for a given module."""
  668. def __init__(
  669. self, module: ModuleType, classes: dict[str, type[Component | SimpleNamespace]]
  670. ):
  671. """Initialize the stub generator.
  672. Args:
  673. module: The actual module object module to generate stubs for.
  674. classes: The actual Component class objects to generate stubs for.
  675. """
  676. super().__init__()
  677. # Dict mapping class name to actual class object.
  678. self.classes = classes
  679. # Track the last class node that was visited.
  680. self.current_class = None
  681. # These imports will be included in the AST of stub files.
  682. self.typing_imports = DEFAULT_TYPING_IMPORTS.copy()
  683. # Whether those typing imports have been inserted yet.
  684. self.inserted_imports = False
  685. # Collected import statements from the module.
  686. self.import_statements: list[str] = []
  687. # This dict is used when evaluating type hints.
  688. self.type_hint_globals = module.__dict__.copy()
  689. @staticmethod
  690. def _remove_docstring(
  691. node: ast.Module | ast.ClassDef | ast.FunctionDef,
  692. ) -> ast.Module | ast.ClassDef | ast.FunctionDef:
  693. """Removes any docstring in place.
  694. Args:
  695. node: The node to remove the docstring from.
  696. Returns:
  697. The modified node.
  698. """
  699. if (
  700. node.body
  701. and isinstance(node.body[0], ast.Expr)
  702. and isinstance(node.body[0].value, ast.Constant)
  703. ):
  704. node.body.pop(0)
  705. return node
  706. def _current_class_is_component(self) -> type[Component] | None:
  707. """Check if the current class is a Component.
  708. Returns:
  709. Whether the current class is a Component.
  710. """
  711. if (
  712. self.current_class is not None
  713. and self.current_class in self.classes
  714. and issubclass((clz := self.classes[self.current_class]), Component)
  715. ):
  716. return clz
  717. return None
  718. def visit_Module(self, node: ast.Module) -> ast.Module:
  719. """Visit a Module node and remove docstring from body.
  720. Args:
  721. node: The Module node to visit.
  722. Returns:
  723. The modified Module node.
  724. """
  725. self.generic_visit(node)
  726. return self._remove_docstring(node) # pyright: ignore [reportReturnType]
  727. def visit_Import(
  728. self, node: ast.Import | ast.ImportFrom
  729. ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
  730. """Collect import statements from the module.
  731. If this is the first import statement, insert the typing imports before it.
  732. Args:
  733. node: The import node to visit.
  734. Returns:
  735. The modified import node(s).
  736. """
  737. self.import_statements.append(ast.unparse(node))
  738. if not self.inserted_imports:
  739. self.inserted_imports = True
  740. default_imports = _generate_imports(self.typing_imports)
  741. self.import_statements.extend(ast.unparse(i) for i in default_imports)
  742. return [*default_imports, node]
  743. return node
  744. def visit_ImportFrom(
  745. self, node: ast.ImportFrom
  746. ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom] | None:
  747. """Visit an ImportFrom node.
  748. Remove any `from __future__ import *` statements, and hand off to visit_Import.
  749. Args:
  750. node: The ImportFrom node to visit.
  751. Returns:
  752. The modified ImportFrom node.
  753. """
  754. if node.module == "__future__":
  755. return None # ignore __future__ imports
  756. return self.visit_Import(node)
    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a ClassDef node.

        Runs the module's collected import statements into
        ``type_hint_globals``, strips the class docstring, removes private
        annotated assignments (``_name: ...``), and replaces a
        ``__call__ = ...`` assignment with a generated ``__call__`` function.
        It then visits the children and appends a generated ``create`` for
        Component classes that do not define one.

        Args:
            node: The ClassDef node to visit.

        Returns:
            The modified ClassDef node.
        """
        # Make the module's imports available for resolving type hint strings.
        exec("\n".join(self.import_statements), self.type_hint_globals)
        self.current_class = node.name
        self._remove_docstring(node)

        # Define `__call__` as a real function so the docstring appears in the stub.
        call_definition = None
        # Iterate over a copy because children are removed from node.body.
        for child in node.body[:]:
            found_call = False
            if (
                isinstance(child, ast.AnnAssign)
                and isinstance(child.target, ast.Name)
                and child.target.id.startswith("_")
            ):
                node.body.remove(child)
            if isinstance(child, ast.Assign):
                for target in child.targets[:]:
                    if isinstance(target, ast.Name) and target.id == "__call__":
                        child.targets.remove(target)
                        found_call = True
                if not found_call:
                    continue
                # Drop the assignment entirely if `__call__` was its only target.
                if not child.targets[:]:
                    node.body.remove(child)
                call_definition = _generate_namespace_call_functiondef(
                    node,
                    self.current_class,
                    self.classes,
                    type_hint_globals=self.type_hint_globals,
                )
                # NOTE(review): stopping here means private AnnAssigns that come
                # after the `__call__` assignment are not removed by this loop.
                break

        self.generic_visit(node)  # Visit child nodes.

        if (
            not any(
                isinstance(child, ast.FunctionDef) and child.name == "create"
                for child in node.body
            )
            and (clz := self._current_class_is_component()) is not None
        ):
            # Add a new .create FunctionDef since one does not exist.
            node.body.append(
                _generate_component_create_functiondef(
                    clz=clz,
                    type_hint_globals=self.type_hint_globals,
                    lineno=node.lineno,
                )
            )
        if call_definition is not None:
            node.body.append(call_definition)
        if not node.body:
            # We should never return an empty body.
            node.body.append(ast.Expr(value=ast.Constant(value=Ellipsis)))
        self.current_class = None
        return node
  818. def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
  819. """Visit a FunctionDef node.
  820. Special handling for `.create` functions to add type hints for all props
  821. defined on the component class.
  822. Remove all private functions and blank out the function body of the
  823. remaining public functions.
  824. Args:
  825. node: The FunctionDef node to visit.
  826. Returns:
  827. The modified FunctionDef node (or None).
  828. """
  829. if (
  830. node.name == "create"
  831. and self.current_class in self.classes
  832. and issubclass((clz := self.classes[self.current_class]), Component)
  833. ):
  834. node = _generate_component_create_functiondef(
  835. clz=clz,
  836. type_hint_globals=self.type_hint_globals,
  837. lineno=node.lineno,
  838. decorator_list=node.decorator_list,
  839. )
  840. else:
  841. if node.name.startswith("_") and node.name != "__call__":
  842. return None # remove private methods
  843. if node.body[-1] != ast.Expr(value=ast.Constant(value=Ellipsis)):
  844. # Blank out the function body for public functions.
  845. node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
  846. return node
  847. def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
  848. """Remove non-annotated assignment statements.
  849. Args:
  850. node: The Assign node to visit.
  851. Returns:
  852. The modified Assign node (or None).
  853. """
  854. # Special case for assignments to `typing.Any` as fallback.
  855. if (
  856. node.value is not None
  857. and isinstance(node.value, ast.Name)
  858. and node.value.id == "Any"
  859. ):
  860. return node
  861. if self._current_class_is_component():
  862. # Remove annotated assignments in Component classes (props)
  863. return None
  864. # remove dunder method assignments for lazy_loader.attach
  865. for target in node.targets:
  866. if isinstance(target, ast.Tuple):
  867. for name in target.elts:
  868. if isinstance(name, ast.Name) and name.id.startswith("_"):
  869. return None
  870. return node
  871. def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
  872. """Visit an AnnAssign node (Annotated assignment).
  873. Remove private target and remove the assignment value in the stub.
  874. Args:
  875. node: The AnnAssign node to visit.
  876. Returns:
  877. The modified AnnAssign node (or None).
  878. """
  879. # skip ClassVars
  880. if (
  881. isinstance(node.annotation, ast.Subscript)
  882. and isinstance(node.annotation.value, ast.Name)
  883. and node.annotation.value.id == "ClassVar"
  884. ):
  885. return node
  886. if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
  887. return None
  888. if self._current_class_is_component():
  889. # Remove annotated assignments in Component classes (props)
  890. return None
  891. # Blank out assignments in type stubs.
  892. node.value = None
  893. return node
  894. class InitStubGenerator(StubGenerator):
  895. """A node transformer that will generate the stubs for a given init file."""
  896. def visit_Import(
  897. self, node: ast.Import | ast.ImportFrom
  898. ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
  899. """Collect import statements from the init module.
  900. Args:
  901. node: The import node to visit.
  902. Returns:
  903. The modified import node(s).
  904. """
  905. return [node]
  906. class PyiGenerator:
  907. """A .pyi file generator that will scan all defined Component in Reflex and
  908. generate the appropriate stub.
  909. """
  910. modules: list = []
  911. root: str = ""
  912. current_module: Any = {}
  913. written_files: list[tuple[str, str]] = []
  914. def _write_pyi_file(self, module_path: Path, source: str) -> str:
  915. relpath = str(_relative_to_pwd(module_path)).replace("\\", "/")
  916. pyi_content = (
  917. "\n".join(
  918. [
  919. f'"""Stub file for {relpath}"""',
  920. "# ------------------- DO NOT EDIT ----------------------",
  921. "# This file was generated by `reflex/utils/pyi_generator.py`!",
  922. "# ------------------------------------------------------",
  923. "",
  924. ]
  925. )
  926. + source
  927. )
  928. pyi_path = module_path.with_suffix(".pyi")
  929. pyi_path.write_text(pyi_content)
  930. logger.info(f"Wrote {relpath}")
  931. return md5(pyi_content.encode()).hexdigest()
  932. def _get_init_lazy_imports(self, mod: tuple | ModuleType, new_tree: ast.AST):
  933. # retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present.
  934. sub_mods = getattr(mod, "_SUBMODULES", None)
  935. sub_mod_attrs = getattr(mod, "_SUBMOD_ATTRS", None)
  936. pyright_ignore_imports = getattr(mod, "_PYRIGHT_IGNORE_IMPORTS", [])
  937. if not sub_mods and not sub_mod_attrs:
  938. return None
  939. sub_mods_imports = []
  940. sub_mod_attrs_imports = []
  941. if sub_mods:
  942. sub_mods_imports = [
  943. f"from . import {mod} as {mod}" for mod in sorted(sub_mods)
  944. ]
  945. sub_mods_imports.append("")
  946. if sub_mod_attrs:
  947. sub_mod_attrs = {
  948. attr: mod for mod, attrs in sub_mod_attrs.items() for attr in attrs
  949. }
  950. # construct the import statement and handle special cases for aliases
  951. sub_mod_attrs_imports = [
  952. f"from .{path} import {mod if not isinstance(mod, tuple) else mod[0]} as {mod if not isinstance(mod, tuple) else mod[1]}"
  953. + (
  954. " # type: ignore"
  955. if mod in pyright_ignore_imports
  956. else " # noqa: F401" # ignore ruff formatting here for cases like rx.list.
  957. if isinstance(mod, tuple)
  958. else ""
  959. )
  960. for mod, path in sub_mod_attrs.items()
  961. ]
  962. sub_mod_attrs_imports.append("")
  963. text = "\n" + "\n".join([*sub_mods_imports, *sub_mod_attrs_imports])
  964. text += ast.unparse(new_tree) + "\n"
  965. return text
  966. def _scan_file(self, module_path: Path) -> tuple[str, str] | None:
  967. module_import = (
  968. _relative_to_pwd(module_path)
  969. .with_suffix("")
  970. .as_posix()
  971. .replace("/", ".")
  972. .replace("\\", ".")
  973. )
  974. module = importlib.import_module(module_import)
  975. logger.debug(f"Read {module_path}")
  976. class_names = {
  977. name: obj
  978. for name, obj in vars(module).items()
  979. if inspect.isclass(obj)
  980. and (
  981. rx_types.safe_issubclass(obj, Component)
  982. or rx_types.safe_issubclass(obj, SimpleNamespace)
  983. )
  984. and obj != Component
  985. and inspect.getmodule(obj) == module
  986. }
  987. is_init_file = _relative_to_pwd(module_path).name == "__init__.py"
  988. if not class_names and not is_init_file:
  989. return None
  990. if is_init_file:
  991. new_tree = InitStubGenerator(module, class_names).visit(
  992. ast.parse(inspect.getsource(module))
  993. )
  994. init_imports = self._get_init_lazy_imports(module, new_tree)
  995. if not init_imports:
  996. return None
  997. content_hash = self._write_pyi_file(module_path, init_imports)
  998. else:
  999. new_tree = StubGenerator(module, class_names).visit(
  1000. ast.parse(inspect.getsource(module))
  1001. )
  1002. content_hash = self._write_pyi_file(module_path, ast.unparse(new_tree))
  1003. return str(module_path.with_suffix(".pyi").resolve()), content_hash
  1004. def _scan_files_multiprocess(self, files: list[Path]):
  1005. with Pool(processes=cpu_count()) as pool:
  1006. self.written_files.extend(f for f in pool.map(self._scan_file, files) if f)
  1007. def _scan_files(self, files: list[Path]):
  1008. for file in files:
  1009. pyi_path = self._scan_file(file)
  1010. if pyi_path:
  1011. self.written_files.append(pyi_path)
  1012. def scan_all(
  1013. self,
  1014. targets: list,
  1015. changed_files: list[Path] | None = None,
  1016. use_json: bool = False,
  1017. ):
  1018. """Scan all targets for class inheriting Component and generate the .pyi files.
  1019. Args:
  1020. targets: the list of file/folders to scan.
  1021. changed_files (optional): the list of changed files since the last run.
  1022. use_json: whether to use json to store the hashes.
  1023. """
  1024. file_targets = []
  1025. for target in targets:
  1026. target_path = Path(target)
  1027. if (
  1028. target_path.is_file()
  1029. and target_path.suffix == ".py"
  1030. and target_path.name not in EXCLUDED_FILES
  1031. ):
  1032. file_targets.append(target_path)
  1033. continue
  1034. if not target_path.is_dir():
  1035. continue
  1036. for file_path in _walk_files(target_path):
  1037. relative = _relative_to_pwd(file_path)
  1038. if relative.name in EXCLUDED_FILES or file_path.suffix != ".py":
  1039. continue
  1040. if (
  1041. changed_files is not None
  1042. and _relative_to_pwd(file_path) not in changed_files
  1043. ):
  1044. continue
  1045. file_targets.append(file_path)
  1046. # check if pyi changed but not the source
  1047. if changed_files is not None:
  1048. for changed_file in changed_files:
  1049. if changed_file.suffix != ".pyi":
  1050. continue
  1051. py_file_path = changed_file.with_suffix(".py")
  1052. if not py_file_path.exists() and changed_file.exists():
  1053. changed_file.unlink()
  1054. if py_file_path in file_targets:
  1055. continue
  1056. subprocess.run(["git", "checkout", changed_file])
  1057. if True:
  1058. self._scan_files(file_targets)
  1059. else:
  1060. self._scan_files_multiprocess(file_targets)
  1061. file_paths, hashes = (
  1062. [f[0] for f in self.written_files],
  1063. [f[1] for f in self.written_files],
  1064. )
  1065. # Fix generated pyi files with ruff.
  1066. if file_paths:
  1067. subprocess.run(["ruff", "format", *file_paths])
  1068. subprocess.run(["ruff", "check", "--fix", *file_paths])
  1069. # For some reason, we need to format the __init__.pyi files again after fixing...
  1070. init_files = [f for f in file_paths if "/__init__.pyi" in f]
  1071. subprocess.run(["ruff", "format", *init_files])
  1072. if use_json:
  1073. if file_paths and changed_files is None:
  1074. file_paths = list(map(Path, file_paths))
  1075. top_dir = file_paths[0].parent
  1076. for file_path in file_paths:
  1077. file_parent = file_path.parent
  1078. while len(file_parent.parts) > len(top_dir.parts):
  1079. file_parent = file_parent.parent
  1080. while len(top_dir.parts) > len(file_parent.parts):
  1081. top_dir = top_dir.parent
  1082. while not file_parent.samefile(top_dir):
  1083. file_parent = file_parent.parent
  1084. top_dir = top_dir.parent
  1085. while (
  1086. not top_dir.samefile(top_dir.parent)
  1087. and not (top_dir / PYI_HASHES).exists()
  1088. ):
  1089. top_dir = top_dir.parent
  1090. pyi_hashes_file = top_dir / PYI_HASHES
  1091. if pyi_hashes_file.exists():
  1092. pyi_hashes_file.write_text(
  1093. json.dumps(
  1094. dict(
  1095. zip(
  1096. [
  1097. f.relative_to(pyi_hashes_file.parent).as_posix()
  1098. for f in file_paths
  1099. ],
  1100. hashes,
  1101. strict=True,
  1102. )
  1103. ),
  1104. indent=2,
  1105. sort_keys=True,
  1106. )
  1107. + "\n",
  1108. )
  1109. elif file_paths:
  1110. file_paths = list(map(Path, file_paths))
  1111. pyi_hashes_parent = file_paths[0].parent
  1112. while (
  1113. not pyi_hashes_parent.samefile(pyi_hashes_parent.parent)
  1114. and not (pyi_hashes_parent / PYI_HASHES).exists()
  1115. ):
  1116. pyi_hashes_parent = pyi_hashes_parent.parent
  1117. pyi_hashes_file = pyi_hashes_parent / PYI_HASHES
  1118. if pyi_hashes_file.exists():
  1119. pyi_hashes = json.loads(pyi_hashes_file.read_text())
  1120. for file_path, hashed_content in zip(
  1121. file_paths, hashes, strict=False
  1122. ):
  1123. formatted_path = file_path.relative_to(
  1124. pyi_hashes_parent
  1125. ).as_posix()
  1126. pyi_hashes[formatted_path] = hashed_content
  1127. pyi_hashes_file.write_text(
  1128. json.dumps(pyi_hashes, indent=2, sort_keys=True) + "\n"
  1129. )
  1130. # Post-process the generated pyi files to add hacky type: ignore comments
  1131. for file_path in file_paths:
  1132. with FileInput(file_path, inplace=True) as f:
  1133. for line in f:
  1134. # Hack due to ast not supporting comments in the tree.
  1135. if (
  1136. "def create(" in line
  1137. or "Var[Figure]" in line
  1138. or "Var[Template]" in line
  1139. ):
  1140. line = line.rstrip() + " # type: ignore\n"
  1141. print(line, end="") # noqa: T201
  1142. if __name__ == "__main__":
  1143. logging.basicConfig(level=logging.INFO)
  1144. logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.INFO)
  1145. gen = PyiGenerator()
  1146. gen.scan_all(
  1147. ["reflex/components", "reflex/experimental", "reflex/__init__.py"],
  1148. None,
  1149. use_json=True,
  1150. )