|
@@ -2,15 +2,18 @@ import logging
|
|
|
import sys
|
|
|
import traceback
|
|
|
from contextlib import contextmanager
|
|
|
-
|
|
|
+import asyncio
|
|
|
from .utils import random_str
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class WebIOFuture:
|
|
|
+ def __init__(self, coro=None):
|
|
|
+ self.coro = coro
|
|
|
+
|
|
|
def __iter__(self):
|
|
|
- result = yield
|
|
|
+ result = yield self
|
|
|
return result
|
|
|
|
|
|
__await__ = __iter__ # make compatible with 'await' expression
|
|
@@ -156,10 +159,10 @@ class Task:
|
|
|
logger.debug('Task[%s] created ', self.coro_id)
|
|
|
|
|
|
def step(self, result=None):
|
|
|
- future_or_none = None
|
|
|
+ coro_yield = None
|
|
|
with self.ws_context():
|
|
|
try:
|
|
|
- future_or_none = self.coro.send(result)
|
|
|
+ coro_yield = self.coro.send(result)
|
|
|
except StopIteration as e:
|
|
|
if len(e.args) == 1:
|
|
|
self.result = e.args[0]
|
|
@@ -168,10 +171,15 @@ class Task:
|
|
|
except Exception as e:
|
|
|
self.ws.on_coro_error()
|
|
|
|
|
|
- if not isinstance(future_or_none, WebIOFuture) and future_or_none is not None:
|
|
|
- if not self.ws.closed():
|
|
|
- future_or_none.add_done_callback(self._tornado_future_callback)
|
|
|
- self.pending_futures[id(future_or_none)] = future_or_none
|
|
|
+ future = None
|
|
|
+ if isinstance(coro_yield, WebIOFuture):
|
|
|
+ if coro_yield.coro:
|
|
|
+ future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
|
|
|
+ elif coro_yield is not None:
|
|
|
+ future = coro_yield
|
|
|
+ if not self.ws.closed() and hasattr(future, 'add_done_callback'):
|
|
|
+ future.add_done_callback(self._tornado_future_callback)
|
|
|
+ self.pending_futures[id(future)] = future
|
|
|
|
|
|
def _tornado_future_callback(self, future):
|
|
|
if not future.cancelled():
|