run.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import asyncio
  2. import sys
  3. import traceback
  4. from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  5. from functools import partial
  6. from typing import Any, Callable, TypeVar
  7. from typing_extensions import ParamSpec
  8. process_pool = ProcessPoolExecutor()
  9. thread_pool = ThreadPoolExecutor()
  10. P = ParamSpec('P')
  11. R = TypeVar('R')
  class SubprocessException(Exception):
      """Picklable stand-in for an exception that was raised inside a worker process.

      The original exception is represented only by the three values passed to the
      constructor (``safe_callback`` passes its type name, its message and the formatted
      traceback). Because of this, the object can be sent back to the parent process even
      when the original exception object itself could not be pickled.
      """

      def __init__(self, original_type, original_message, original_traceback) -> None:
          (self.original_type,
           self.original_message,
           self.original_traceback) = original_type, original_message, original_traceback
          super().__init__(f'{original_type}: {original_message}')

      def __reduce__(self):
          # Default exception pickling would rebuild the object from ``self.args``, which
          # holds only the combined "type: message" text and so does not match __init__.
          # Rebuild from the three original values instead.
          constructor_args = (self.original_type, self.original_message, self.original_traceback)
          return SubprocessException, constructor_args

      def __str__(self):
          report = (f'Exception in subprocess:\n'
                    f' Type: {self.original_type}\n'
                    f' Message: {self.original_message}\n'
                    f' {self.original_traceback}')
          return report
  def safe_callback(callback: Callable, *args, **kwargs) -> Any:
      """Call ``callback(*args, **kwargs)`` and return whatever it returns.

      Any ``Exception`` raised by the callback is converted into a ``SubprocessException``
      holding only the exception's type name, its message and the formatted traceback.
      Exceptions that are not ``Exception`` subclasses (e.g. ``KeyboardInterrupt``)
      propagate unchanged.
      """
      try:
          outcome = callback(*args, **kwargs)
      except Exception as exc:
          # The original exception object is deliberately not forwarded: it might not be
          # picklable. ``from None`` hides it from the printed exception chain.
          type_name = type(exc).__name__
          message = str(exc)
          formatted = traceback.format_exc()
          raise SubprocessException(type_name, message, formatted) from None
      return outcome
  async def _run(executor: Any, callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
      """Await ``callback(*args, **kwargs)`` executed on ``executor`` via the running loop.

      Returns the callback's result. In two cases it returns ``None`` instead: when the
      executor refuses new work because it has been shut down ("cannot schedule new
      futures after shutdown"), and when the wait is cancelled. The assumption is that
      the caller no longer cares about the value then. Any other ``RuntimeError``
      (including "no running event loop") is re-raised.
      """
      # TODO: also return None right away while the application is stopping
      # (the intended check is ``core.app.is_stopping``), before scheduling any work.
      try:
          event_loop = asyncio.get_running_loop()
          bound_call = partial(callback, *args, **kwargs)
          return await event_loop.run_in_executor(executor, bound_call)
      except asyncio.CancelledError:
          # NOTE(review): swallowing CancelledError stops the cancellation from propagating
          # to the awaiting task; kept as-is to preserve the existing behaviour.
          return None  # type: ignore
      except RuntimeError as err:
          if 'cannot schedule new futures after shutdown' not in str(err):
              raise
          return None  # type: ignore
  async def cpu_bound(callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
      """Run a CPU-bound function in a separate process.

      ``run.cpu_bound`` executes ``callback`` on the module's process pool. For this,
      the function and its arguments must be transferred to the worker process, which
      is done with pickle. Prefer static methods or free functions that receive all of
      their data as simple parameters (e.g. no class or UI logic). They should return
      their result rather than writing it to class attributes or global variables.

      Exceptions raised by ``callback`` arrive as ``SubprocessException`` (via
      ``safe_callback``). ``None`` is returned if the pool has already been shut down
      or the wait is cancelled.
      """
      wrapper = safe_callback  # turns worker exceptions into a picklable SubprocessException
      return await _run(process_pool, wrapper, callback, *args, **kwargs)
  async def io_bound(callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
      """Run an I/O-bound function in a separate thread.

      ``callback(*args, **kwargs)`` is executed on the module's thread pool, so the
      event loop is not blocked while it waits. Exceptions raised by ``callback`` are
      re-raised unchanged when awaited. ``None`` is returned if the pool has already
      been shut down or the wait is cancelled.
      """
      executor = thread_pool
      return await _run(executor, callback, *args, **kwargs)
  def tear_down() -> None:
      """Kill all worker processes and shut down both executors.

      Worker processes are killed outright. Then the process pool is shut down, waiting
      for its internals to finish, and the thread pool is shut down without waiting.
      On Python 3.9+ pending futures of both pools are cancelled as well.

      Safe to call more than once: ``ProcessPoolExecutor.shutdown`` releases the pool's
      internal process table (it becomes ``None``), and that is treated as "no workers".
      """
      # TODO: skip the tear-down when running under pytest (intended check:
      # ``helpers.is_pytest()``).
      workers = process_pool._processes or {}  # pylint: disable=protected-access
      # Iterate over a snapshot: the pool's management thread may mutate the live dict
      # concurrently (e.g. when it notices the workers we are killing).
      for worker in list(workers.values()):
          worker.kill()
      kwargs = {'cancel_futures': True} if sys.version_info >= (3, 9) else {}
      process_pool.shutdown(wait=True, **kwargs)
      thread_pool.shutdown(wait=False, **kwargs)