浏览代码

Get project name from pyproject (#3048)

Nikhil Rao 1 年之前
父节点
当前提交
d5e3ccaed8
共有 1 个文件被更改,包括 28 次插入27 次删除
  1. 28 27
      reflex/custom_components/custom_components.py

+ 28 - 27
reflex/custom_components/custom_components.py

@@ -71,6 +71,28 @@ def _create_package_config(module_name: str, package_name: str):
         )
 
 
+def _get_package_config(exit_on_fail: bool = True) -> dict:
+    """Get the package configuration from the pyproject.toml file.
+
+    Args:
+        exit_on_fail: Whether to exit if the pyproject.toml file is not found.
+
+    Returns:
+        The package configuration.
+
+    Raises:
+        Exit: If the pyproject.toml file is not found.
+    """
+    try:
+        with open(CustomComponents.PYPROJECT_TOML, "rb") as f:
+            return dict(tomlkit.load(f))
+    except (OSError, TOMLKitError) as ex:
+        console.error(f"Unable to read from pyproject.toml due to {ex}")
+        if exit_on_fail:
+            raise typer.Exit(code=1) from ex
+        raise
+
+
 def _create_readme(module_name: str, package_name: str):
     """Create a package README file.
 
@@ -416,9 +438,7 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool:
 
 def _make_pyi_files():
     """Create pyi files for the custom component."""
-    from glob import glob
-
-    package_name = glob("custom_components/*.egg-info")[0].replace(".egg-info", "")
+    package_name = _get_package_config()["project"]["name"]
 
     for dir, _, _ in os.walk(f"./{package_name}"):
         if "__pycache__" in dir:
@@ -514,22 +534,10 @@ def _validate_credentials(
 def _get_version_to_publish() -> str:
     """Get the version to publish from the pyproject.toml.
 
-    Raises:
-        Exit: If the version is not found in the pyproject.toml.
-
     Returns:
         The version to publish.
     """
-    # Get the version from the pyproject.toml.
-    try:
-        with open(CustomComponents.PYPROJECT_TOML, "rb") as f:
-            project_toml = tomlkit.parse(f.read())
-            return project_toml.get("project", {})["version"]
-    except (OSError, KeyError, TOMLKitError) as ex:
-        console.error(
-            f"Cannot find the version in {CustomComponents.PYPROJECT_TOML} due to {ex}"
-        )
-        raise typer.Exit(code=1) from ex
+    return _get_package_config()["project"]["version"]
 
 
 def _ensure_dist_dir(version_to_publish: str, build: bool):
@@ -733,12 +741,7 @@ def _validate_project_info():
     Raises:
         Exit: If the pyproject.toml file is ill-formed.
     """
-    try:
-        with open(CustomComponents.PYPROJECT_TOML, "rb") as f:
-            pyproject_toml = tomlkit.parse(f.read())
-    except TOMLKitError as ex:
-        console.error(f"Unable to read from pyproject.toml due to {ex}")
-        raise typer.Exit(code=1) from ex
+    pyproject_toml = _get_package_config()
 
     try:
         project = pyproject_toml.get("project", {})
@@ -796,7 +799,7 @@ def _validate_project_info():
         with open(CustomComponents.PYPROJECT_TOML, "w") as f:
             tomlkit.dump(pyproject_toml, f)
     except (OSError, TOMLKitError) as ex:
-        console.error(f"Unable to read from pyproject.toml due to {ex}")
+        console.error(f"Unable to write to pyproject.toml due to {ex}")
         raise typer.Exit(code=1) from ex
 
 
@@ -816,10 +819,8 @@ def _collect_details_for_gallery():
     params = {}
     package_name = None
     try:
-        with open(CustomComponents.PYPROJECT_TOML, "rb") as f:
-            project_toml = tomlkit.parse(f.read())
-        package_name = project_toml.get("project", {})["name"]
-    except (OSError, TOMLKitError, KeyError) as ex:
+        package_name = _get_package_config(exit_on_fail=False)["project"]["name"]
+    except (TOMLKitError, KeyError) as ex:
         console.debug(
             f"Unable to read from pyproject.toml in current directory due to {ex}"
         )