Jelajahi Sumber

Add timeout parameter to run_in_thread function

Co-Authored-By: Alek Petuskey <alek@pynecone.io>
Devin AI 2 minggu lalu
induk
melakukan
21ab114ed1
2 mengubah file dengan 51 tambahan dan 2 penghapusan
  1. 9 2
      reflex/utils/misc.py
  2. 42 0
      tests/units/utils/test_misc.py

+ 9 - 2
reflex/utils/misc.py

@@ -5,20 +5,27 @@ from collections.abc import Callable
 from typing import Any
 
 
-async def run_in_thread(func: Callable) -> Any:
+async def run_in_thread(func: Callable, *, timeout: float | None = None) -> Any:
     """Run a function in a separate thread.
 
     To not block the UI event queue, run_in_thread must be inside inside a rx.event(background=True) decorated method.
 
     Args:
         func: The non-async function to run.
+        timeout: Maximum number of seconds to wait for the function to complete.
+                If None (default), wait indefinitely.
 
     Raises:
         ValueError: If the function is an async function.
+        asyncio.TimeoutError: If the function execution exceeds the specified timeout.
 
     Returns:
         Any: The return value of the function.
     """
     if asyncio.coroutines.iscoroutinefunction(func):
         raise ValueError("func must be a non-async function")
-    return await asyncio.get_event_loop().run_in_executor(None, func)
+
+    task = asyncio.get_event_loop().run_in_executor(None, func)
+    if timeout is not None:
+        return await asyncio.wait_for(task, timeout=timeout)
+    return await task

+ 42 - 0
tests/units/utils/test_misc.py

@@ -0,0 +1,42 @@
+"""Test misc utilities."""
+
+import asyncio
+import time
+
+import pytest
+
+from reflex.utils.misc import run_in_thread
+
+
+async def test_run_in_thread():
+    """Test that run_in_thread runs a function in a separate thread."""
+
+    def simple_function():
+        return 42
+
+    result = await run_in_thread(simple_function)
+    assert result == 42
+
+    def slow_function():
+        time.sleep(0.1)
+        return "completed"
+
+    result = await run_in_thread(slow_function, timeout=0.5)
+    assert result == "completed"
+
+    async def async_function():
+        return 42
+
+    with pytest.raises(ValueError):
+        await run_in_thread(async_function)
+
+
+async def test_run_in_thread_timeout():
+    """Test that run_in_thread raises TimeoutError when timeout is exceeded."""
+
+    def very_slow_function():
+        time.sleep(0.5)
+        return "should not reach here"
+
+    with pytest.raises(asyncio.TimeoutError):
+        await run_in_thread(very_slow_function, timeout=0.1)