pyi_generator.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. """The pyi generator module."""
  2. import ast
  3. import importlib
  4. import inspect
  5. import os
  6. import re
  7. import sys
  8. from inspect import getfullargspec
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Literal, Optional, Set, Union, get_args # NOQA
  11. import black
  12. from reflex.components.component import Component
  13. # NOQA
  14. from reflex.components.graphing.recharts.recharts import (
  15. LiteralAnimationEasing,
  16. LiteralAreaType,
  17. LiteralComposedChartBaseValue,
  18. LiteralDirection,
  19. LiteralGridType,
  20. LiteralIconType,
  21. LiteralIfOverflow,
  22. LiteralInterval,
  23. LiteralLayout,
  24. LiteralLegendAlign,
  25. LiteralLineType,
  26. LiteralOrientationTopBottom,
  27. LiteralOrientationTopBottomLeftRight,
  28. LiteralPolarRadiusType,
  29. LiteralPosition,
  30. LiteralScale,
  31. LiteralShape,
  32. LiteralStackOffset,
  33. LiteralSyncMethod,
  34. LiteralVerticalAlign,
  35. )
  36. from reflex.components.libs.chakra import (
  37. LiteralAlertDialogSize,
  38. LiteralAvatarSize,
  39. LiteralChakraDirection,
  40. LiteralColorScheme,
  41. LiteralDrawerSize,
  42. LiteralImageLoading,
  43. LiteralInputVariant,
  44. LiteralMenuOption,
  45. LiteralMenuStrategy,
  46. LiteralTagSize,
  47. )
  48. from reflex.components.radix.themes.base import (
  49. LiteralAccentColor,
  50. LiteralAlign,
  51. LiteralAppearance,
  52. LiteralGrayColor,
  53. LiteralJustify,
  54. LiteralPanelBackground,
  55. LiteralRadius,
  56. LiteralScaling,
  57. LiteralSize,
  58. LiteralVariant,
  59. )
  60. from reflex.components.radix.themes.components import (
  61. LiteralButtonSize,
  62. LiteralSwitchSize,
  63. )
  64. from reflex.components.radix.themes.layout import (
  65. LiteralBoolNumber,
  66. LiteralContainerSize,
  67. LiteralFlexDirection,
  68. LiteralFlexDisplay,
  69. LiteralFlexWrap,
  70. LiteralGridDisplay,
  71. LiteralGridFlow,
  72. LiteralSectionSize,
  73. )
  74. from reflex.components.radix.themes.typography import (
  75. LiteralLinkUnderline,
  76. LiteralTextAlign,
  77. LiteralTextSize,
  78. LiteralTextTrim,
  79. LiteralTextWeight,
  80. )
  81. # NOQA
  82. from reflex.event import EventChain
  83. from reflex.style import Style
  84. from reflex.utils import format
  85. from reflex.utils import types as rx_types
  86. from reflex.vars import Var
  87. ruff_dont_remove = [
  88. Var,
  89. Optional,
  90. Dict,
  91. List,
  92. EventChain,
  93. Style,
  94. LiteralInputVariant,
  95. LiteralColorScheme,
  96. LiteralChakraDirection,
  97. LiteralTagSize,
  98. LiteralDrawerSize,
  99. LiteralMenuStrategy,
  100. LiteralMenuOption,
  101. LiteralAlertDialogSize,
  102. LiteralAvatarSize,
  103. LiteralImageLoading,
  104. LiteralLayout,
  105. LiteralAnimationEasing,
  106. LiteralGridType,
  107. LiteralPolarRadiusType,
  108. LiteralScale,
  109. LiteralSyncMethod,
  110. LiteralStackOffset,
  111. LiteralComposedChartBaseValue,
  112. LiteralOrientationTopBottom,
  113. LiteralAreaType,
  114. LiteralShape,
  115. LiteralLineType,
  116. LiteralDirection,
  117. LiteralIfOverflow,
  118. LiteralOrientationTopBottomLeftRight,
  119. LiteralInterval,
  120. LiteralLegendAlign,
  121. LiteralVerticalAlign,
  122. LiteralIconType,
  123. LiteralPosition,
  124. LiteralAccentColor,
  125. LiteralAlign,
  126. LiteralAppearance,
  127. LiteralBoolNumber,
  128. LiteralButtonSize,
  129. LiteralContainerSize,
  130. LiteralFlexDirection,
  131. LiteralFlexDisplay,
  132. LiteralFlexWrap,
  133. LiteralGrayColor,
  134. LiteralGridDisplay,
  135. LiteralGridFlow,
  136. LiteralJustify,
  137. LiteralLinkUnderline,
  138. LiteralPanelBackground,
  139. LiteralRadius,
  140. LiteralScaling,
  141. LiteralSectionSize,
  142. LiteralSize,
  143. LiteralSwitchSize,
  144. LiteralTextAlign,
  145. LiteralTextSize,
  146. LiteralTextTrim,
  147. LiteralTextWeight,
  148. LiteralVariant,
  149. ]
  150. EXCLUDED_FILES = [
  151. "__init__.py",
  152. "component.py",
  153. "bare.py",
  154. "foreach.py",
  155. "cond.py",
  156. "multiselect.py",
  157. ]
  158. # These props exist on the base component, but should not be exposed in create methods.
  159. EXCLUDED_PROPS = [
  160. "alias",
  161. "children",
  162. "event_triggers",
  163. "invalid_children",
  164. "library",
  165. "lib_dependencies",
  166. "tag",
  167. "is_default",
  168. "special_props",
  169. "valid_children",
  170. ]
  171. DEFAULT_TYPING_IMPORTS = {"overload", "Any", "Dict", "List", "Optional", "Union"}
  172. def _get_type_hint(value, top_level=True, no_union=False):
  173. res = ""
  174. args = get_args(value)
  175. if args:
  176. inner_container_type_args = (
  177. [format.wrap(arg, '"') for arg in args]
  178. if rx_types.is_literal(value)
  179. else [
  180. _get_type_hint(arg, top_level=False)
  181. for arg in args
  182. if arg is not type(None)
  183. ]
  184. )
  185. res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"
  186. if value.__name__ == "Var":
  187. types = [res] + [
  188. _get_type_hint(arg, top_level=False)
  189. for arg in args
  190. if arg is not type(None)
  191. ]
  192. if len(types) > 1 and not no_union:
  193. res = ", ".join(types)
  194. res = f"Union[{res}]"
  195. elif isinstance(value, str):
  196. ev = eval(value)
  197. res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
  198. else:
  199. res = value.__name__
  200. if top_level and not res.startswith("Optional"):
  201. res = f"Optional[{res}]"
  202. return res
  203. def _get_typing_import(_module):
  204. src = [
  205. line
  206. for line in inspect.getsource(_module).split("\n")
  207. if line.startswith("from typing")
  208. ]
  209. if len(src):
  210. return set(src[0].rpartition("from typing import ")[-1].split(", "))
  211. return set()
  212. def _get_var_definition(_module, _var_name):
  213. for node in ast.parse(inspect.getsource(_module)).body:
  214. if isinstance(node, ast.Assign) and _var_name in [
  215. t.id for t in node.targets if isinstance(t, ast.Name)
  216. ]:
  217. return ast.unparse(node)
  218. raise Exception(f"Could not find var {_var_name} in module {_module}")
  219. class PyiGenerator:
  220. """A .pyi file generator that will scan all defined Component in Reflex and
  221. generate the approriate stub.
  222. """
  223. modules: list = []
  224. root: str = ""
  225. current_module: Any = {}
  226. default_typing_imports: set = DEFAULT_TYPING_IMPORTS
  227. def _generate_imports(self, variables, classes):
  228. variables_imports = {
  229. type(_var) for _, _var in variables if isinstance(_var, Component)
  230. }
  231. bases = {
  232. base
  233. for _, _class in classes
  234. for base in _class.__bases__
  235. if inspect.getmodule(base) != self.current_module
  236. } | variables_imports
  237. bases.add(Component)
  238. typing_imports = self.default_typing_imports | _get_typing_import(
  239. self.current_module
  240. )
  241. bases = sorted(bases, key=lambda base: base.__name__)
  242. return [
  243. f"from typing import {','.join(sorted(typing_imports))}",
  244. *[f"from {base.__module__} import {base.__name__}" for base in bases],
  245. "from reflex.vars import Var, BaseVar, ComputedVar",
  246. "from reflex.event import EventHandler, EventChain, EventSpec",
  247. "from reflex.style import Style",
  248. ]
  249. def _generate_pyi_class(self, _class: type[Component]):
  250. create_spec = getfullargspec(_class.create)
  251. lines = [
  252. "",
  253. f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
  254. ]
  255. definition = f" @overload\n @classmethod\n def create( # type: ignore\n cls, *children, "
  256. for kwarg in create_spec.kwonlyargs:
  257. if kwarg in create_spec.annotations:
  258. definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
  259. else:
  260. definition += f"{kwarg}, "
  261. all_classes = [c for c in _class.__mro__ if issubclass(c, Component)]
  262. all_props = []
  263. for target_class in all_classes:
  264. for name, value in target_class.__annotations__.items():
  265. if (
  266. name in create_spec.kwonlyargs
  267. or name in EXCLUDED_PROPS
  268. or name in all_props
  269. ):
  270. continue
  271. all_props.append(name)
  272. definition += f"{name}: {_get_type_hint(value)} = None, "
  273. for trigger in sorted(_class().get_event_triggers().keys()):
  274. definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "
  275. definition = definition.rstrip(", ")
  276. definition += f", **props) -> '{_class.__name__}':\n"
  277. definition += self._generate_docstrings(all_classes, all_props)
  278. lines.append(definition)
  279. lines.append(" ...")
  280. return lines
  281. def _generate_docstrings(self, _classes, _props):
  282. props_comments = {}
  283. comments = []
  284. for _class in _classes:
  285. for _i, line in enumerate(inspect.getsource(_class).splitlines()):
  286. reached_functions = re.search("def ", line)
  287. if reached_functions:
  288. # We've reached the functions, so stop.
  289. break
  290. # Get comments for prop
  291. if line.strip().startswith("#"):
  292. comments.append(line)
  293. continue
  294. # Check if this line has a prop.
  295. match = re.search("\\w+:", line)
  296. if match is None:
  297. # This line doesn't have a var, so continue.
  298. continue
  299. # Get the prop.
  300. prop = match.group(0).strip(":")
  301. if prop in _props:
  302. if not comments: # do not include undocumented props
  303. continue
  304. props_comments[prop] = "\n".join(
  305. [comment.strip().strip("#") for comment in comments]
  306. )
  307. comments.clear()
  308. continue
  309. if prop in EXCLUDED_PROPS:
  310. comments.clear() # throw away comments for excluded props
  311. _class = _classes[0]
  312. new_docstring = []
  313. for i, line in enumerate(_class.create.__doc__.splitlines()):
  314. if i == 0:
  315. new_docstring.append(" " * 8 + '"""' + line)
  316. else:
  317. new_docstring.append(line)
  318. if "*children" in line:
  319. for nline in [
  320. f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
  321. ]:
  322. new_docstring.append(nline)
  323. new_docstring += ['"""']
  324. return "\n".join(new_docstring)
  325. def _generate_pyi_variable(self, _name, _var):
  326. return _get_var_definition(self.current_module, _name)
  327. def _generate_function(self, _name, _func):
  328. import textwrap
  329. # Don't generate indented functions.
  330. source = inspect.getsource(_func)
  331. if textwrap.dedent(source) != source:
  332. return []
  333. definition = "".join([line for line in source.split(":\n")[0].split("\n")])
  334. return [f"{definition}:", " ..."]
  335. def _write_pyi_file(self, variables, functions, classes):
  336. pyi_content = [
  337. f'"""Stub file for {self.current_module_path}.py"""',
  338. "# ------------------- DO NOT EDIT ----------------------",
  339. "# This file was generated by `scripts/pyi_generator.py`!",
  340. "# ------------------------------------------------------",
  341. "",
  342. ]
  343. pyi_content.extend(self._generate_imports(variables, classes))
  344. for _name, _var in variables:
  345. pyi_content.append(self._generate_pyi_variable(_name, _var))
  346. for _fname, _func in functions:
  347. pyi_content.extend(self._generate_function(_fname, _func))
  348. for _, _class in classes:
  349. pyi_content.extend(self._generate_pyi_class(_class))
  350. pyi_filename = f"{self.current_module_path}.pyi"
  351. pyi_path = os.path.join(self.root, pyi_filename)
  352. with open(pyi_path, "w") as pyi_file:
  353. pyi_file.write("\n".join(pyi_content))
  354. black.format_file_in_place(
  355. src=Path(pyi_path),
  356. fast=True,
  357. mode=black.FileMode(),
  358. write_back=black.WriteBack.YES,
  359. )
  360. def _scan_file(self, file):
  361. self.current_module_path = os.path.splitext(file)[0]
  362. module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
  363. "/", "."
  364. )
  365. self.current_module = importlib.import_module(module_import)
  366. local_variables = []
  367. for node in ast.parse(inspect.getsource(self.current_module)).body:
  368. if isinstance(node, ast.Assign):
  369. for t in node.targets:
  370. if not isinstance(t, ast.Name):
  371. # Skip non-var assignment statements
  372. continue
  373. if t.id.startswith("_"):
  374. # Skip private vars
  375. continue
  376. obj = getattr(self.current_module, t.id, None)
  377. if inspect.isclass(obj) or inspect.isfunction(obj):
  378. continue
  379. local_variables.append((t.id, obj))
  380. functions = [
  381. (name, obj)
  382. for name, obj in vars(self.current_module).items()
  383. if not name.startswith("__")
  384. and (
  385. not inspect.getmodule(obj)
  386. or inspect.getmodule(obj) == self.current_module
  387. )
  388. and inspect.isfunction(obj)
  389. ]
  390. class_names = [
  391. (name, obj)
  392. for name, obj in vars(self.current_module).items()
  393. if inspect.isclass(obj)
  394. and issubclass(obj, Component)
  395. and obj != Component
  396. and inspect.getmodule(obj) == self.current_module
  397. ]
  398. if not class_names:
  399. return
  400. print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
  401. self._write_pyi_file(local_variables, functions, class_names)
  402. def _scan_folder(self, folder):
  403. for root, _, files in os.walk(folder):
  404. self.root = root
  405. for file in files:
  406. if file in EXCLUDED_FILES:
  407. continue
  408. if file.endswith(".py"):
  409. self._scan_file(file)
  410. def scan_all(self, targets):
  411. """Scan all targets for class inheriting Component and generate the .pyi files.
  412. Args:
  413. targets: the list of file/folders to scan.
  414. """
  415. for target in targets:
  416. if target.endswith(".py"):
  417. self.root, _, file = target.rpartition("/")
  418. self._scan_file(file)
  419. else:
  420. self._scan_folder(target)
  421. if __name__ == "__main__":
  422. targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
  423. print(f"Running .pyi generator for {targets}")
  424. gen = PyiGenerator()
  425. gen.scan_all(targets)