pyi_generator.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. """The pyi generator module."""
  2. import importlib
  3. import inspect
  4. import os
  5. import sys
  6. from inspect import getfullargspec
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional, get_args
  9. import black
  10. from reflex.components.component import Component
  11. from reflex.vars import Var
  12. ruff_dont_remove = [Var, Optional, Dict, List]
  13. EXCLUDED_FILES = [
  14. "__init__.py",
  15. "component.py",
  16. "bare.py",
  17. "foreach.py",
  18. "cond.py",
  19. "multiselect.py",
  20. ]
  21. DEFAULT_TYPING_IMPORTS = {"overload", "Optional", "Union"}
  22. def _get_type_hint(value, top_level=True, no_union=False):
  23. res = ""
  24. args = get_args(value)
  25. if args:
  26. res = f"{value.__name__}[{', '.join([_get_type_hint(arg, top_level=False) for arg in args if arg is not type(None)])}]"
  27. if value.__name__ == "Var":
  28. types = [res] + [
  29. _get_type_hint(arg, top_level=False)
  30. for arg in args
  31. if arg is not type(None)
  32. ]
  33. if len(types) > 1 and not no_union:
  34. res = ", ".join(types)
  35. res = f"Union[{res}]"
  36. elif isinstance(value, str):
  37. ev = eval(value)
  38. res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
  39. else:
  40. res = value.__name__
  41. if top_level and not res.startswith("Optional"):
  42. res = f"Optional[{res}]"
  43. return res
  44. def _get_typing_import(_module):
  45. src = [
  46. line
  47. for line in inspect.getsource(_module).split("\n")
  48. if line.startswith("from typing")
  49. ]
  50. if len(src):
  51. return set(src[0].rpartition("from typing import ")[-1].split(", "))
  52. return set()
  53. def _get_var_definition(_module, _var_name):
  54. return [
  55. line.split(" = ")[0]
  56. for line in inspect.getsource(_module).splitlines()
  57. if line.startswith(_var_name)
  58. ]
  59. class PyiGenerator:
  60. """A .pyi file generator that will scan all defined Component in Reflex and
  61. generate the approriate stub.
  62. """
  63. modules: list = []
  64. root: str = ""
  65. current_module: Any = {}
  66. default_typing_imports: set = DEFAULT_TYPING_IMPORTS
  67. def _generate_imports(self, variables, classes):
  68. variables_imports = {
  69. type(_var) for _, _var in variables if isinstance(_var, Component)
  70. }
  71. bases = {
  72. base
  73. for _, _class in classes
  74. for base in _class.__bases__
  75. if inspect.getmodule(base) != self.current_module
  76. } | variables_imports
  77. bases.add(Component)
  78. typing_imports = self.default_typing_imports | _get_typing_import(
  79. self.current_module
  80. )
  81. return [
  82. f"from typing import {','.join(typing_imports)}",
  83. *[f"from {base.__module__} import {base.__name__}" for base in bases],
  84. "from reflex.vars import Var, BaseVar, ComputedVar",
  85. "from reflex.event import EventChain",
  86. ]
  87. def _generate_pyi_class(self, _class):
  88. create_spec = getfullargspec(_class.create)
  89. lines = [
  90. "",
  91. f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
  92. ]
  93. definition = f" @overload\n @classmethod\n def create(cls, *children, "
  94. for kwarg in create_spec.kwonlyargs:
  95. if kwarg in create_spec.annotations:
  96. definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
  97. else:
  98. definition += f"{kwarg}, "
  99. for name, value in _class.__annotations__.items():
  100. if name in create_spec.kwonlyargs:
  101. continue
  102. definition += f"{name}: {_get_type_hint(value)} = None, "
  103. definition = definition.rstrip(", ")
  104. definition += f", **props) -> '{_class.__name__}': ... # type: ignore"
  105. lines.append(definition)
  106. return lines
  107. def _generate_pyi_variable(self, _name, _var):
  108. return _get_var_definition(self.current_module, _name)
  109. def _generate_function(self, _name, _func):
  110. definition = "".join(inspect.getsource(_func).split(":\n")[0].split("\n"))
  111. return [f"{definition}:", " ..."]
  112. def _write_pyi_file(self, variables, functions, classes):
  113. pyi_content = [
  114. f'"""Stub file for {self.current_module_path}.py"""',
  115. "# ------------------- DO NOT EDIT ----------------------",
  116. "# This file was generated by `scripts/pyi_generator.py`!",
  117. "# ------------------------------------------------------",
  118. "",
  119. ]
  120. pyi_content.extend(self._generate_imports(variables, classes))
  121. for _name, _var in variables:
  122. pyi_content.extend(self._generate_pyi_variable(_name, _var))
  123. for _fname, _func in functions:
  124. pyi_content.extend(self._generate_function(_fname, _func))
  125. for _, _class in classes:
  126. pyi_content.extend(self._generate_pyi_class(_class))
  127. pyi_filename = f"{self.current_module_path}.pyi"
  128. pyi_path = os.path.join(self.root, pyi_filename)
  129. with open(pyi_path, "w") as pyi_file:
  130. pyi_file.write("\n".join(pyi_content))
  131. black.format_file_in_place(
  132. src=Path(pyi_path),
  133. fast=True,
  134. mode=black.FileMode(),
  135. write_back=black.WriteBack.YES,
  136. )
  137. def _scan_file(self, file):
  138. self.current_module_path = os.path.splitext(file)[0]
  139. module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
  140. "/", "."
  141. )
  142. self.current_module = importlib.import_module(module_import)
  143. local_variables = [
  144. (name, obj)
  145. for name, obj in vars(self.current_module).items()
  146. if not name.startswith("__")
  147. # and (
  148. # not inspect.getmodule(obj)
  149. # or inspect.getmodule(obj) == self.current_module
  150. # )
  151. and not inspect.isclass(obj) and not inspect.isfunction(obj)
  152. ]
  153. functions = [
  154. (name, obj)
  155. for name, obj in vars(self.current_module).items()
  156. if not name.startswith("__")
  157. and (
  158. not inspect.getmodule(obj)
  159. or inspect.getmodule(obj) == self.current_module
  160. )
  161. and inspect.isfunction(obj)
  162. ]
  163. class_names = [
  164. (name, obj)
  165. for name, obj in vars(self.current_module).items()
  166. if inspect.isclass(obj)
  167. and issubclass(obj, Component)
  168. and obj != Component
  169. and inspect.getmodule(obj) == self.current_module
  170. ]
  171. if not class_names:
  172. return
  173. print(f"Parsed {file}: Found {[n for n,_ in class_names]}")
  174. self._write_pyi_file(local_variables, functions, class_names)
  175. def _scan_folder(self, folder):
  176. for root, _, files in os.walk(folder):
  177. self.root = root
  178. for file in files:
  179. if file in EXCLUDED_FILES:
  180. continue
  181. if file.endswith(".py"):
  182. self._scan_file(file)
  183. def scan_all(self, targets):
  184. """Scan all targets for class inheriting Component and generate the .pyi files.
  185. Args:
  186. targets: the list of file/folders to scan.
  187. """
  188. for target in targets:
  189. if target.endswith(".py"):
  190. self.root, _, file = target.rpartition("/")
  191. self._scan_file(file)
  192. else:
  193. self._scan_folder(target)
  194. if __name__ == "__main__":
  195. targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
  196. print(f"Running .pyi generator for {targets}")
  197. gen = PyiGenerator()
  198. gen.scan_all(targets)