generate_pyi.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import ast
  12. import re
  13. from pathlib import Path
  14. from typing import List
  15. _end_doc = re.compile(r"\"\"\"\s*(#\s*noqa\s*:\s*E501)?\s*\n")
  16. def _get_function_delimiters(initial_line, lines):
  17. begin = end = initial_line
  18. while True:
  19. if lines[begin - 1] == "\n":
  20. break
  21. begin -= 1
  22. if lines[end].endswith("(\n"):
  23. while ":\n" not in lines[end]:
  24. end += 1
  25. if '"""' in lines[end + 1]:
  26. while True:
  27. if _end_doc.search(lines[end]):
  28. break
  29. end += 1
  30. return begin, end + 1
  31. def _get_file_lines(filename: str) -> List[str]:
  32. # Get file lines for later
  33. with open(filename) as f:
  34. return f.readlines()
  35. def _get_file_ast(filename: str):
  36. # Get raw text and build ast
  37. _config = Path(filename)
  38. _tree = _config.read_text()
  39. return ast.parse(_tree)
  40. def _build_base_config_pyi(filename, base_pyi) -> str:
  41. lines = _get_file_lines(filename)
  42. tree = _get_file_ast(filename)
  43. class_lineno = [f.lineno for f in ast.walk(tree) if isinstance(f, ast.ClassDef) and f.name == "Config"]
  44. begin_class, end_class = _get_function_delimiters(class_lineno[0] - 1, lines)
  45. base_pyi += "".join(lines[begin_class:end_class])
  46. functions = [f.lineno for f in ast.walk(tree) if isinstance(f, ast.FunctionDef) and not f.name.startswith("__")]
  47. for ln in functions:
  48. begin_line, end_line = _get_function_delimiters(ln - 1, lines)
  49. base_pyi += "".join(lines[begin_line:end_line])
  50. base_pyi = __add_docstring(base_pyi, lines, end_line)
  51. base_pyi += "\n"
  52. return base_pyi
  53. def __add_docstring(base_pyi, lines, end_line):
  54. if '"""' not in lines[end_line - 1]:
  55. base_pyi += '\t\t""""""\n'.replace("\t", " ")
  56. return base_pyi
  57. def _build_entity_config_pyi(base_pyi, filename, entity_map) -> str:
  58. lines = _get_file_lines(filename)
  59. tree = _get_file_ast(filename)
  60. functions = {}
  61. for f in ast.walk(tree):
  62. if isinstance(f, ast.FunctionDef):
  63. if "_configure" in f.name and not f.name.startswith("__"):
  64. functions[f.name] = f.lineno
  65. elif "_set_default" in f.name and not f.name.startswith("__"):
  66. functions[f.name] = f.lineno
  67. elif "_add" in f.name and not f.name.startswith("__"):
  68. functions[f.name] = f.lineno
  69. for k, v in functions.items():
  70. begin_line, end_line = _get_function_delimiters(v - 1, lines)
  71. try:
  72. func = "".join(lines[begin_line:end_line])
  73. func = func if not k.startswith("_") else func.replace(k, entity_map.get(k))
  74. func = __add_docstring(func, lines, end_line) + "\n"
  75. base_pyi += func
  76. except Exception:
  77. print(f"key={k}") # noqa: T201
  78. raise
  79. return base_pyi
  80. def _generate_entity_and_property_maps(filenames):
  81. entities_map = {}
  82. property_map = {}
  83. for filename in filenames:
  84. etty_tree = _get_file_ast(filename)
  85. functions = [
  86. f for f in ast.walk(etty_tree) if isinstance(f, ast.Call) and getattr(f.func, "id", "") == "_inject_section"
  87. ]
  88. for f in functions:
  89. entity = ast.unparse(f.args[0])
  90. entities_map[entity] = {}
  91. property_map[eval(ast.unparse(f.args[1]))] = entity
  92. # Remove class name from function map
  93. text = ast.unparse(f.args[-1]).replace(f"{entity}.", "")
  94. matches = re.findall(r"\((.*?)\)", text)
  95. for m in matches:
  96. v, k = m.replace("'", "").split(",")
  97. entities_map[entity][k.strip()] = v
  98. return entities_map, property_map
  99. def _generate_acessors(base_pyi, property_map) -> str:
  100. for property, cls in property_map.items():
  101. return_template = f"Dict[str, {cls}]" if property != "job_config" else f"{cls}"
  102. template = ("\t@_Classproperty\n" + f'\tdef {property}(cls) -> {return_template}:\n\t\t""""""\n').replace(
  103. "\t", " "
  104. )
  105. base_pyi += template + "\n"
  106. return base_pyi
  107. def _build_header(filename) -> str:
  108. _file = Path(filename)
  109. return _file.read_text() + "\n"
  110. if __name__ == "__main__":
  111. header_file = "tools/config/pyi_header.py"
  112. base_config = "taipy/common/config/config.py"
  113. config_init = [Path("taipy/core/config/__init__.py"), Path("taipy/rest/config/__init__.py")]
  114. dn_filename = "taipy/core/config/data_node_config.py"
  115. job_filename = "taipy/core/config/job_config.py"
  116. scenario_filename = "taipy/core/config/scenario_config.py"
  117. task_filename = "taipy/core/config/task_config.py"
  118. core_filename = "taipy/core/config/core_section.py"
  119. rest_filename = "taipy/rest/config/rest_config.py"
  120. entities_map, property_map = _generate_entity_and_property_maps(config_init)
  121. pyi = _build_header(header_file)
  122. pyi = _build_base_config_pyi(base_config, pyi)
  123. pyi = _generate_acessors(pyi, property_map)
  124. pyi = _build_entity_config_pyi(pyi, scenario_filename, entities_map["ScenarioConfig"])
  125. pyi = _build_entity_config_pyi(pyi, dn_filename, entities_map["DataNodeConfig"])
  126. pyi = _build_entity_config_pyi(pyi, task_filename, entities_map["TaskConfig"])
  127. pyi = _build_entity_config_pyi(pyi, job_filename, entities_map["JobConfig"])
  128. pyi = _build_entity_config_pyi(pyi, core_filename, entities_map["CoreSection"])
  129. pyi = _build_entity_config_pyi(pyi, rest_filename, entities_map["RestConfig"])
  130. # Remove the final redundant \n
  131. pyi = pyi[:-1]
  132. with open("taipy/common/config/config.pyi", "w") as f:
  133. f.writelines(pyi)