浏览代码

Use concurrent.futures for threading (#1483)

Nikhil Rao 1 年之前
父节点
当前提交
91c0de4b5f
共有 5 个文件被更改,包括 59 次插入35 次删除
  1. 8 10
      reflex/reflex.py
  2. 1 1
      reflex/utils/console.py
  3. 6 16
      reflex/utils/prerequisites.py
  4. 40 8
      reflex/utils/processes.py
  5. 4 0
      tests/test_testing.py

+ 8 - 10
reflex/reflex.py

@@ -2,7 +2,6 @@
 
 import os
 import signal
-import threading
 from pathlib import Path
 
 import httpx
@@ -167,18 +166,17 @@ def run(
     # Post a telemetry event.
     telemetry.send(f"run-{env.value}", config.telemetry_enabled)
 
-    # Run the frontend and backend.
+    # Display custom message when there is a keyboard interrupt.
+    signal.signal(signal.SIGINT, processes.catch_keyboard_interrupt)
+
+    # Run the frontend and backend together.
+    commands = []
     if frontend:
         setup_frontend(Path.cwd())
-        threading.Thread(target=frontend_cmd, args=(Path.cwd(), frontend_port)).start()
+        commands.append((frontend_cmd, Path.cwd(), frontend_port))
     if backend:
-        threading.Thread(
-            target=backend_cmd,
-            args=(app.__name__, backend_host, backend_port),
-        ).start()
-
-    # Display custom message when there is a keyboard interrupt.
-    signal.signal(signal.SIGINT, processes.catch_keyboard_interrupt)
+        commands.append((backend_cmd, app.__name__, backend_host, backend_port))
+    processes.run_concurrently(*commands)
 
 
 @cli.command()

+ 1 - 1
reflex/utils/console.py

@@ -121,7 +121,7 @@ def error(msg: str, **kwargs):
         kwargs: Keyword arguments to pass to the print function.
     """
     if LOG_LEVEL <= LogLevel.ERROR:
-        print(f"[red]Error: {msg}[/red]", **kwargs)
+        print(f"[red]{msg}[/red]", **kwargs)
 
 
 def ask(

+ 6 - 16
reflex/utils/prerequisites.py

@@ -9,7 +9,6 @@ import platform
 import re
 import sys
 import tempfile
-import threading
 from fileinput import FileInput
 from pathlib import Path
 from types import ModuleType
@@ -167,7 +166,7 @@ def get_default_app_name() -> str:
         console.error(
             f"The app directory cannot be named [bold]{constants.MODULE_NAME}[/bold]."
         )
-        raise typer.Exit()
+        raise typer.Exit(1)
 
     return app_name
 
@@ -315,7 +314,7 @@ def install_node():
         console.error(
             f"Node.js version {constants.NODE_VERSION} or higher is required to run Reflex."
         )
-        raise typer.Exit()
+        raise typer.Exit(1)
 
     # Create the nvm directory and install.
     path_ops.mkdir(constants.NVM_DIR)
@@ -332,7 +331,7 @@ def install_node():
         ],
         env=env,
     )
-    processes.show_status("", process)
+    processes.show_status("Installing node", process)
 
 
 def install_bun():
@@ -401,14 +400,14 @@ def check_initialized(frontend: bool = True):
         console.error(
             f"The app is not initialized. Run [bold]{constants.MODULE_NAME} init[/bold] first."
         )
-        raise typer.Exit()
+        raise typer.Exit(1)
 
     # Check that the template is up to date.
     if frontend and not is_latest_template():
         console.error(
             "The base app template has updated. Run [bold]reflex init[/bold] again."
         )
-        raise typer.Exit()
+        raise typer.Exit(1)
 
     # Print a warning for Windows users.
     if IS_WINDOWS:
@@ -436,20 +435,11 @@ def initialize_frontend_dependencies():
     path_ops.mkdir(constants.REFLEX_DIR)
 
     # Install the frontend dependencies.
-    threads = [
-        threading.Thread(target=initialize_bun),
-        threading.Thread(target=initialize_node),
-    ]
-    for thread in threads:
-        thread.start()
+    processes.run_concurrently(install_node, install_bun)
 
     # Set up the web directory.
     initialize_web_directory()
 
-    # Wait for the threads to finish.
-    for thread in threads:
-        thread.join()
-
 
 def check_admin_settings():
     """Check if admin settings are set and valid for logging in cli app."""

+ 40 - 8
reflex/utils/processes.py

@@ -2,15 +2,17 @@
 
 from __future__ import annotations
 
+import collections
 import contextlib
 import os
 import signal
 import subprocess
-import sys
-from typing import List, Optional
+from concurrent import futures
+from typing import Callable, List, Optional, Tuple, Union
 from urllib.parse import urlparse
 
 import psutil
+import typer
 
 from reflex import constants
 from reflex.config import get_config
@@ -99,6 +101,10 @@ def change_or_terminate_port(port, _type) -> str:
 
     Returns:
         The new port or the current one.
+
+
+    Raises:
+        Exit: If the user wants to exit.
     """
     console.info(
         f"Something is already running on port [bold underline]{port}[/bold underline]. This is the port the {_type} runs on."
@@ -120,7 +126,7 @@ def change_or_terminate_port(port, _type) -> str:
             return new_port
     else:
         console.log("Exiting...")
-        sys.exit()
+        raise typer.Exit()
 
 
 def new_process(args, run: bool = False, show_logs: bool = False, **kwargs):
@@ -153,6 +159,26 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs):
     return fn(args, **kwargs)
 
 
+def run_concurrently(*fns: Union[Callable, Tuple]):
+    """Run functions concurrently in a thread pool.
+
+
+    Args:
+        *fns: The functions to run.
+    """
+    # Convert the functions to tuples.
+    fns = [fn if isinstance(fn, tuple) else (fn,) for fn in fns]  # type: ignore
+
+    # Run the functions concurrently.
+    with futures.ThreadPoolExecutor(max_workers=len(fns)) as executor:
+        # Submit the tasks.
+        tasks = [executor.submit(*fn) for fn in fns]  # type: ignore
+
+        # Get the results in the order completed to check any exceptions.
+        for task in futures.as_completed(tasks):
+            task.result()
+
+
 def stream_logs(
     message: str,
     process: subprocess.Popen,
@@ -165,21 +191,27 @@ def stream_logs(
 
     Yields:
         The lines of the process output.
+
+    Raises:
+        Exit: If the process failed.
     """
+    # Store the tail of the logs.
+    logs = collections.deque(maxlen=512)
     with process:
         console.debug(message)
         if process.stdout is None:
             return
         for line in process.stdout:
             console.debug(line, end="")
+            logs.append(line)
             yield line
 
     if process.returncode != 0:
-        console.error(f"Error during {message}")
-        console.error(
-            "Run in with [bold]--loglevel debug[/bold] to see the full error."
-        )
-        os._exit(1)
+        console.error(f"{message} failed with exit code {process.returncode}")
+        for line in logs:
+            console.error(line, end="")
+        console.error("Run with [bold]--loglevel debug [/bold] for the full log.")
+        raise typer.Exit(1)
 
 
 def show_logs(

+ 4 - 0
tests/test_testing.py

@@ -1,5 +1,6 @@
 """Unit tests for the included testing tools."""
 from reflex.testing import AppHarness
+from reflex.utils.prerequisites import IS_WINDOWS
 
 
 def test_app_harness(tmp_path):
@@ -8,6 +9,9 @@ def test_app_harness(tmp_path):
     Args:
         tmp_path: pytest tmp_path fixture
     """
+    # Skip in Windows CI.
+    if IS_WINDOWS:
+        return
 
     def BasicApp():
         import reflex as rx