pyi_generator.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. """The pyi generator module."""
  2. import importlib
  3. import inspect
  4. import os
  5. import re
  6. import sys
  7. from inspect import getfullargspec
  8. from pathlib import Path
  9. from typing import Any, Dict, List, Literal, Optional, Set, Union, get_args # NOQA
  10. import black
  11. from reflex.components.component import Component
  12. # NOQA
  13. from reflex.components.graphing.recharts.recharts import (
  14. LiteralAnimationEasing,
  15. LiteralAreaType,
  16. LiteralComposedChartBaseValue,
  17. LiteralDirection,
  18. LiteralGridType,
  19. LiteralIconType,
  20. LiteralIfOverflow,
  21. LiteralInterval,
  22. LiteralLayout,
  23. LiteralLegendAlign,
  24. LiteralLineType,
  25. LiteralOrientationTopBottom,
  26. LiteralOrientationTopBottomLeftRight,
  27. LiteralPolarRadiusType,
  28. LiteralPosition,
  29. LiteralScale,
  30. LiteralShape,
  31. LiteralStackOffset,
  32. LiteralSyncMethod,
  33. LiteralVerticalAlign,
  34. )
  35. from reflex.components.libs.chakra import (
  36. LiteralAlertDialogSize,
  37. LiteralAvatarSize,
  38. LiteralChakraDirection,
  39. LiteralColorScheme,
  40. LiteralDrawerSize,
  41. LiteralImageLoading,
  42. LiteralInputVariant,
  43. LiteralMenuOption,
  44. LiteralMenuStrategy,
  45. LiteralTagSize,
  46. )
  47. # NOQA
  48. from reflex.event import EventChain
  49. from reflex.style import Style
  50. from reflex.utils import format
  51. from reflex.utils import types as rx_types
  52. from reflex.vars import Var
# Most of these imports are not referenced anywhere else in this file. Listing them
# here keeps linters (ruff) from stripping them as unused. They must stay importable
# because `_get_type_hint` evaluates string annotations with `eval` in this module's
# namespace.
ruff_dont_remove = [
    Var,
    Optional,
    Dict,
    List,
    EventChain,
    Style,
    LiteralInputVariant,
    LiteralColorScheme,
    LiteralChakraDirection,
    LiteralTagSize,
    LiteralDrawerSize,
    LiteralMenuStrategy,
    LiteralMenuOption,
    LiteralAlertDialogSize,
    LiteralAvatarSize,
    LiteralImageLoading,
    LiteralLayout,
    LiteralAnimationEasing,
    LiteralGridType,
    LiteralPolarRadiusType,
    LiteralScale,
    LiteralSyncMethod,
    LiteralStackOffset,
    LiteralComposedChartBaseValue,
    LiteralOrientationTopBottom,
    LiteralAreaType,
    LiteralShape,
    LiteralLineType,
    LiteralDirection,
    LiteralIfOverflow,
    LiteralOrientationTopBottomLeftRight,
    LiteralInterval,
    LiteralLegendAlign,
    LiteralVerticalAlign,
    LiteralIconType,
    LiteralPosition,
]

# File names skipped while walking a folder (see `PyiGenerator._scan_folder`).
# Explicitly named single-file targets given to `scan_all` are not checked against this list.
EXCLUDED_FILES = [
    "__init__.py",
    "component.py",
    "bare.py",
    "foreach.py",
    "cond.py",
    "multiselect.py",
]

# These props exist on the base component, but should not be exposed in create methods.
# They are skipped when building the generated `create` signatures.
EXCLUDED_PROPS = [
    "alias",
    "children",
    "event_triggers",
    "invalid_children",
    "library",
    "lib_dependencies",
    "tag",
    "is_default",
    "special_props",
    "valid_children",
]

# `typing` names imported by every generated stub. `_generate_imports` merges them with
# whatever the scanned module itself imports from `typing`.
DEFAULT_TYPING_IMPORTS = {"overload", "Any", "Dict", "List", "Optional", "Union"}
  113. def _get_type_hint(value, top_level=True, no_union=False):
  114. res = ""
  115. args = get_args(value)
  116. if args:
  117. inner_container_type_args = (
  118. [format.wrap(arg, '"') for arg in args]
  119. if rx_types.is_literal(value)
  120. else [
  121. _get_type_hint(arg, top_level=False)
  122. for arg in args
  123. if arg is not type(None)
  124. ]
  125. )
  126. res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"
  127. if value.__name__ == "Var":
  128. types = [res] + [
  129. _get_type_hint(arg, top_level=False)
  130. for arg in args
  131. if arg is not type(None)
  132. ]
  133. if len(types) > 1 and not no_union:
  134. res = ", ".join(types)
  135. res = f"Union[{res}]"
  136. elif isinstance(value, str):
  137. ev = eval(value)
  138. res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
  139. else:
  140. res = value.__name__
  141. if top_level and not res.startswith("Optional"):
  142. res = f"Optional[{res}]"
  143. return res
  144. def _get_typing_import(_module):
  145. src = [
  146. line
  147. for line in inspect.getsource(_module).split("\n")
  148. if line.startswith("from typing")
  149. ]
  150. if len(src):
  151. return set(src[0].rpartition("from typing import ")[-1].split(", "))
  152. return set()
  153. def _get_var_definition(_module, _var_name):
  154. return [
  155. line.split(" = ")[0]
  156. for line in inspect.getsource(_module).splitlines()
  157. if line.startswith(_var_name)
  158. ]
  159. class PyiGenerator:
  160. """A .pyi file generator that will scan all defined Component in Reflex and
  161. generate the approriate stub.
  162. """
  163. modules: list = []
  164. root: str = ""
  165. current_module: Any = {}
  166. default_typing_imports: set = DEFAULT_TYPING_IMPORTS
  167. def _generate_imports(self, variables, classes):
  168. variables_imports = {
  169. type(_var) for _, _var in variables if isinstance(_var, Component)
  170. }
  171. bases = {
  172. base
  173. for _, _class in classes
  174. for base in _class.__bases__
  175. if inspect.getmodule(base) != self.current_module
  176. } | variables_imports
  177. bases.add(Component)
  178. typing_imports = self.default_typing_imports | _get_typing_import(
  179. self.current_module
  180. )
  181. bases = sorted(bases, key=lambda base: base.__name__)
  182. return [
  183. f"from typing import {','.join(sorted(typing_imports))}",
  184. *[f"from {base.__module__} import {base.__name__}" for base in bases],
  185. "from reflex.vars import Var, BaseVar, ComputedVar",
  186. "from reflex.event import EventHandler, EventChain, EventSpec",
  187. "from reflex.style import Style",
  188. ]
  189. def _generate_pyi_class(self, _class: type[Component]):
  190. create_spec = getfullargspec(_class.create)
  191. lines = [
  192. "",
  193. f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
  194. ]
  195. definition = f" @overload\n @classmethod\n def create( # type: ignore\n cls, *children, "
  196. for kwarg in create_spec.kwonlyargs:
  197. if kwarg in create_spec.annotations:
  198. definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
  199. else:
  200. definition += f"{kwarg}, "
  201. all_classes = [c for c in _class.__mro__ if issubclass(c, Component)]
  202. all_props = []
  203. for target_class in all_classes:
  204. for name, value in target_class.__annotations__.items():
  205. if (
  206. name in create_spec.kwonlyargs
  207. or name in EXCLUDED_PROPS
  208. or name in all_props
  209. ):
  210. continue
  211. all_props.append(name)
  212. definition += f"{name}: {_get_type_hint(value)} = None, "
  213. for trigger in sorted(_class().get_event_triggers().keys()):
  214. definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "
  215. definition = definition.rstrip(", ")
  216. definition += f", **props) -> '{_class.__name__}':\n"
  217. definition += self._generate_docstrings(all_classes, all_props)
  218. lines.append(definition)
  219. lines.append(" ...")
  220. return lines
  221. def _generate_docstrings(self, _classes, _props):
  222. props_comments = {}
  223. comments = []
  224. for _class in _classes:
  225. for _i, line in enumerate(inspect.getsource(_class).splitlines()):
  226. reached_functions = re.search("def ", line)
  227. if reached_functions:
  228. # We've reached the functions, so stop.
  229. break
  230. # Get comments for prop
  231. if line.strip().startswith("#"):
  232. comments.append(line)
  233. continue
  234. # Check if this line has a prop.
  235. match = re.search("\\w+:", line)
  236. if match is None:
  237. # This line doesn't have a var, so continue.
  238. continue
  239. # Get the prop.
  240. prop = match.group(0).strip(":")
  241. if prop in _props:
  242. if not comments: # do not include undocumented props
  243. continue
  244. props_comments[prop] = "\n".join(
  245. [comment.strip().strip("#") for comment in comments]
  246. )
  247. comments.clear()
  248. continue
  249. if prop in EXCLUDED_PROPS:
  250. comments.clear() # throw away comments for excluded props
  251. _class = _classes[0]
  252. new_docstring = []
  253. for i, line in enumerate(_class.create.__doc__.splitlines()):
  254. if i == 0:
  255. new_docstring.append(" " * 8 + '"""' + line)
  256. else:
  257. new_docstring.append(line)
  258. if "*children" in line:
  259. for nline in [
  260. f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
  261. ]:
  262. new_docstring.append(nline)
  263. new_docstring += ['"""']
  264. return "\n".join(new_docstring)
  265. def _generate_pyi_variable(self, _name, _var):
  266. return _get_var_definition(self.current_module, _name)
  267. def _generate_function(self, _name, _func):
  268. import textwrap
  269. # Don't generate indented functions.
  270. source = inspect.getsource(_func)
  271. if textwrap.dedent(source) != source:
  272. return []
  273. definition = "".join([line for line in source.split(":\n")[0].split("\n")])
  274. return [f"{definition}:", " ..."]
  275. def _write_pyi_file(self, variables, functions, classes):
  276. pyi_content = [
  277. f'"""Stub file for {self.current_module_path}.py"""',
  278. "# ------------------- DO NOT EDIT ----------------------",
  279. "# This file was generated by `scripts/pyi_generator.py`!",
  280. "# ------------------------------------------------------",
  281. "",
  282. ]
  283. pyi_content.extend(self._generate_imports(variables, classes))
  284. for _name, _var in variables:
  285. pyi_content.extend(self._generate_pyi_variable(_name, _var))
  286. for _fname, _func in functions:
  287. pyi_content.extend(self._generate_function(_fname, _func))
  288. for _, _class in classes:
  289. pyi_content.extend(self._generate_pyi_class(_class))
  290. pyi_filename = f"{self.current_module_path}.pyi"
  291. pyi_path = os.path.join(self.root, pyi_filename)
  292. with open(pyi_path, "w") as pyi_file:
  293. pyi_file.write("\n".join(pyi_content))
  294. black.format_file_in_place(
  295. src=Path(pyi_path),
  296. fast=True,
  297. mode=black.FileMode(),
  298. write_back=black.WriteBack.YES,
  299. )
  300. def _scan_file(self, file):
  301. self.current_module_path = os.path.splitext(file)[0]
  302. module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
  303. "/", "."
  304. )
  305. self.current_module = importlib.import_module(module_import)
  306. local_variables = [
  307. (name, obj)
  308. for name, obj in vars(self.current_module).items()
  309. if not name.startswith("_")
  310. and not inspect.isclass(obj)
  311. and not inspect.isfunction(obj)
  312. ]
  313. functions = [
  314. (name, obj)
  315. for name, obj in vars(self.current_module).items()
  316. if not name.startswith("__")
  317. and (
  318. not inspect.getmodule(obj)
  319. or inspect.getmodule(obj) == self.current_module
  320. )
  321. and inspect.isfunction(obj)
  322. ]
  323. class_names = [
  324. (name, obj)
  325. for name, obj in vars(self.current_module).items()
  326. if inspect.isclass(obj)
  327. and issubclass(obj, Component)
  328. and obj != Component
  329. and inspect.getmodule(obj) == self.current_module
  330. ]
  331. if not class_names:
  332. return
  333. print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
  334. self._write_pyi_file(local_variables, functions, class_names)
  335. def _scan_folder(self, folder):
  336. for root, _, files in os.walk(folder):
  337. self.root = root
  338. for file in files:
  339. if file in EXCLUDED_FILES:
  340. continue
  341. if file.endswith(".py"):
  342. self._scan_file(file)
  343. def scan_all(self, targets):
  344. """Scan all targets for class inheriting Component and generate the .pyi files.
  345. Args:
  346. targets: the list of file/folders to scan.
  347. """
  348. for target in targets:
  349. if target.endswith(".py"):
  350. self.root, _, file = target.rpartition("/")
  351. self._scan_file(file)
  352. else:
  353. self._scan_folder(target)
  354. if __name__ == "__main__":
  355. targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
  356. print(f"Running .pyi generator for {targets}")
  357. gen = PyiGenerator()
  358. gen.scan_all(targets)