فهرست منبع

pyi generator git diff (#2460)

benedikt-bartscher 1 سال پیش
والد
کامیت
08524e22aa
2فایلهای تغییر یافته به همراه171 افزوده شده و 19 حذف شده
  1. 2 0
      .gitignore
  2. 169 19
      scripts/pyi_generator.py

+ 2 - 0
.gitignore

@@ -9,3 +9,5 @@ examples/
 .venv
 .venv
 venv
 venv
 requirements.txt
 requirements.txt
+.pyi_generator_last_run
+.pyi_generator_diff

+ 169 - 19
scripts/pyi_generator.py

@@ -5,8 +5,8 @@ import contextlib
 import importlib
 import importlib
 import inspect
 import inspect
 import logging
 import logging
-import os
 import re
 import re
+import subprocess
 import sys
 import sys
 import textwrap
 import textwrap
 import typing
 import typing
@@ -25,6 +25,12 @@ from reflex.vars import Var
 
 
 logger = logging.getLogger("pyi_generator")
 logger = logging.getLogger("pyi_generator")
 
 
+LAST_RUN_COMMIT_SHA_FILE = Path(".pyi_generator_last_run").resolve()
+INIT_FILE = Path("reflex/__init__.pyi").resolve()
+PWD = Path(".").resolve()
+GENERATOR_FILE = Path(__file__).resolve()
+GENERATOR_DIFF_FILE = Path(".pyi_generator_diff").resolve()
+
 EXCLUDED_FILES = [
 EXCLUDED_FILES = [
     "__init__.py",
     "__init__.py",
     "component.py",
     "component.py",
@@ -62,6 +68,108 @@ DEFAULT_TYPING_IMPORTS = {
 }
 }
 
 
 
 
+def _walk_files(path):
+    """Walk all files in a path.
+    This can be replaced with Path.walk() in python3.12.
+
+    Args:
+        path: The path to walk.
+
+    Yields:
+        The next file in the path.
+    """
+    for p in Path(path).iterdir():
+        if p.is_dir():
+            yield from _walk_files(p)
+            continue
+        yield p.resolve()
+
+
+def _relative_to_pwd(path: Path) -> Path:
+    """Get the relative path of a path to the current working directory.
+
+    Args:
+        path: The path to get the relative path for.
+
+    Returns:
+        The relative path.
+    """
+    return path.relative_to(PWD)
+
+
+def _git_diff(args: list[str]) -> str:
+    """Run a git diff command.
+
+    Args:
+        args: The args to pass to git diff.
+
+    Returns:
+        The output of the git diff command.
+    """
+    cmd = ["git", "diff", "--no-color", *args]
+    return subprocess.run(cmd, capture_output=True, encoding="utf-8").stdout
+
+
+def _git_changed_files(args: list[str] | None = None) -> list[Path]:
+    """Get the list of changed files for a git diff command.
+
+    Args:
+        args: The args to pass to git diff.
+
+    Returns:
+        The list of changed files.
+    """
+    if not args:
+        args = []
+
+    if "--name-only" not in args:
+        args.insert(0, "--name-only")
+
+    diff = _git_diff(args).splitlines()
+    return [Path(file.strip()) for file in diff]
+
+
+def _get_changed_files() -> list[Path] | None:
+    """Get the list of changed files since the last run of the generator.
+
+    Returns:
+        The list of changed files, or None if all files should be regenerated.
+    """
+    try:
+        last_run_commit_sha = LAST_RUN_COMMIT_SHA_FILE.read_text().strip()
+    except FileNotFoundError:
+        logger.info(
+            "pyi_generator.py last run could not be determined, regenerating all .pyi files"
+        )
+        return None
+    changed_files = _git_changed_files([f"{last_run_commit_sha}..HEAD"])
+    # get all unstaged changes
+    changed_files.extend(_git_changed_files())
+    if _relative_to_pwd(GENERATOR_FILE) not in changed_files:
+        return changed_files
+    logger.info("pyi_generator.py has changed, checking diff now")
+    diff = "".join(_git_diff([GENERATOR_FILE.as_posix()]).splitlines()[2:])
+
+    try:
+        last_diff = GENERATOR_DIFF_FILE.read_text()
+        if diff != last_diff:
+            logger.info("pyi_generator.py has changed, regenerating all .pyi files")
+            changed_files = None
+        else:
+            logger.info(
+                "pyi_generator.py has not changed, only regenerating changed files"
+            )
+    except FileNotFoundError:
+        logger.info(
+            "pyi_generator.py diff could not be determined, regenerating all .pyi files"
+        )
+        changed_files = None
+
+    GENERATOR_DIFF_FILE.write_text(diff)
+
+    return changed_files
+
+
 def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
 def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
     """Resolve the type hint for value.
     """Resolve the type hint for value.
 
 
@@ -592,8 +700,9 @@ class PyiGenerator:
     current_module: Any = {}
     current_module: Any = {}
 
 
     def _write_pyi_file(self, module_path: Path, source: str):
     def _write_pyi_file(self, module_path: Path, source: str):
+        relpath = _relative_to_pwd(module_path)
         pyi_content = [
         pyi_content = [
-            f'"""Stub file for {module_path}"""',
+            f'"""Stub file for {relpath}"""',
             "# ------------------- DO NOT EDIT ----------------------",
             "# ------------------- DO NOT EDIT ----------------------",
             "# This file was generated by `scripts/pyi_generator.py`!",
             "# This file was generated by `scripts/pyi_generator.py`!",
             "# ------------------------------------------------------",
             "# ------------------------------------------------------",
@@ -616,10 +725,13 @@ class PyiGenerator:
 
 
         pyi_path = module_path.with_suffix(".pyi")
         pyi_path = module_path.with_suffix(".pyi")
         pyi_path.write_text("\n".join(pyi_content))
         pyi_path.write_text("\n".join(pyi_content))
-        logger.info(f"Wrote {pyi_path}")
+        logger.info(f"Wrote {relpath}")
 
 
     def _scan_file(self, module_path: Path):
     def _scan_file(self, module_path: Path):
-        module_import = str(module_path.with_suffix("")).replace("/", ".")
+        #  module_import = str(module_path.with_suffix("")).replace("/", ".")
+        module_import = (
+            _relative_to_pwd(module_path).with_suffix("").as_posix().replace("/", ".")
+        )
         module = importlib.import_module(module_import)
         module = importlib.import_module(module_import)
         logger.debug(f"Read {module_path}")
         logger.debug(f"Read {module_path}")
         class_names = {
         class_names = {
@@ -638,29 +750,56 @@ class PyiGenerator:
         )
         )
         self._write_pyi_file(module_path, ast.unparse(new_tree))
         self._write_pyi_file(module_path, ast.unparse(new_tree))
 
 
-    def _scan_files_multiprocess(self, files):
+    def _scan_files_multiprocess(self, files: list[Path]):
         with Pool(processes=cpu_count()) as pool:
         with Pool(processes=cpu_count()) as pool:
             pool.map(self._scan_file, files)
             pool.map(self._scan_file, files)
 
 
-    def scan_all(self, targets):
+    def _scan_files(self, files: list[Path]):
+        for file in files:
+            self._scan_file(file)
+
+    def scan_all(self, targets, changed_files: list[Path] | None = None):
         """Scan all targets for class inheriting Component and generate the .pyi files.
         """Scan all targets for class inheriting Component and generate the .pyi files.
 
 
         Args:
         Args:
             targets: the list of file/folders to scan.
             targets: the list of file/folders to scan.
+            changed_files (optional): the list of changed files since the last run.
         """
         """
         file_targets = []
         file_targets = []
         for target in targets:
         for target in targets:
-            path = Path(target)
-            if target.endswith(".py") and path.is_file():
-                file_targets.append(path)
-            elif path.is_dir():
-                for root, _, files in os.walk(path):
-                    for file in files:
-                        if file in EXCLUDED_FILES or not file.endswith(".py"):
-                            continue
-                        file_targets.append(Path(root) / file)
+            target_path = Path(target)
+            if target_path.is_file() and target_path.suffix == ".py":
+                file_targets.append(target_path)
+                continue
+            if not target_path.is_dir():
+                continue
+            for file_path in _walk_files(target_path):
+                relative = _relative_to_pwd(file_path)
+                if relative.name in EXCLUDED_FILES or file_path.suffix != ".py":
+                    continue
+                if (
+                    changed_files is not None
+                    and _relative_to_pwd(file_path) not in changed_files
+                ):
+                    continue
+                file_targets.append(file_path)
+
+        # check if pyi changed but not the source
+        if changed_files is not None:
+            for changed_file in changed_files:
+                if changed_file.suffix != ".pyi":
+                    continue
+                py_file_path = changed_file.with_suffix(".py")
+                if not py_file_path.exists() and changed_file.exists():
+                    changed_file.unlink()
+                if py_file_path in file_targets:
+                    continue
+                subprocess.run(["git", "checkout", changed_file])
 
 
-        self._scan_files_multiprocess(file_targets)
+        if cpu_count() == 1 or len(file_targets) < 5:
+            self._scan_files(file_targets)
+        else:
+            self._scan_files_multiprocess(file_targets)
 
 
 
 
 def generate_init():
 def generate_init():
@@ -673,8 +812,7 @@ def generate_init():
     ]
     ]
     imports.append("")
     imports.append("")
 
 
-    with open("reflex/__init__.pyi", "w") as pyi_file:
-        pyi_file.writelines("\n".join(imports))
+    INIT_FILE.write_text("\n".join(imports))
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
@@ -683,6 +821,18 @@ if __name__ == "__main__":
 
 
     targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
     targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
     logger.info(f"Running .pyi generator for {targets}")
     logger.info(f"Running .pyi generator for {targets}")
+
+    changed_files = _get_changed_files()
+    if changed_files is None:
+        logger.info("Changed files could not be detected, regenerating all .pyi files")
+    else:
+        logger.info(f"Detected changed files: {changed_files}")
+
     gen = PyiGenerator()
     gen = PyiGenerator()
-    gen.scan_all(targets)
+    gen.scan_all(targets, changed_files)
     generate_init()
     generate_init()
+
+    current_commit_sha = subprocess.run(
+        ["git", "rev-parse", "HEAD"], capture_output=True, encoding="utf-8"
+    ).stdout.strip()
+    LAST_RUN_COMMIT_SHA_FILE.write_text(current_commit_sha)