|
@@ -148,7 +148,9 @@ class App(Base):
|
|
|
allow_origins=["*"],
|
|
|
)
|
|
|
|
|
|
- def preprocess(self, state: State, event: Event) -> Optional[Delta]:
|
|
|
+ async def preprocess(
|
|
|
+ self, state: State, event: Event
|
|
|
+ ) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
|
|
|
"""Preprocess the event.
|
|
|
|
|
|
This is where middleware can modify the event before it is processed.
|
|
@@ -165,11 +167,13 @@ class App(Base):
|
|
|
An optional state to return.
|
|
|
"""
|
|
|
for middleware in self.middleware:
|
|
|
- out = middleware.preprocess(app=self, state=state, event=event)
|
|
|
+ out = await middleware.preprocess(app=self, state=state, event=event)
|
|
|
if out is not None:
|
|
|
return out
|
|
|
|
|
|
- def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]:
|
|
|
+ async def postprocess(
|
|
|
+ self, state: State, event: Event, delta: Delta
|
|
|
+ ) -> Optional[Delta]:
|
|
|
"""Postprocess the event.
|
|
|
|
|
|
This is where middleware can modify the delta after it is processed.
|
|
@@ -187,7 +191,7 @@ class App(Base):
|
|
|
An optional state to return.
|
|
|
"""
|
|
|
for middleware in self.middleware:
|
|
|
- out = middleware.postprocess(
|
|
|
+ out = await middleware.postprocess(
|
|
|
app=self, state=state, event=event, delta=delta
|
|
|
)
|
|
|
if out is not None:
|
|
@@ -400,7 +404,7 @@ class App(Base):
|
|
|
|
|
|
async def process(
|
|
|
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
|
|
-) -> StateUpdate:
|
|
|
+) -> Union[StateUpdate, List[StateUpdate]]:
|
|
|
"""Process an event.
|
|
|
|
|
|
Args:
|
|
@@ -411,7 +415,7 @@ async def process(
|
|
|
client_ip: The client_ip.
|
|
|
|
|
|
Returns:
|
|
|
- The state update after processing the event.
|
|
|
+ The state update(s) after processing the event.
|
|
|
"""
|
|
|
# Get the state for the session.
|
|
|
state = app.state_manager.get_state(event.token)
|
|
@@ -430,21 +434,27 @@ async def process(
|
|
|
state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
|
|
|
|
|
|
# Preprocess the event.
|
|
|
- pre = app.preprocess(state, event)
|
|
|
- if pre is not None:
|
|
|
- return StateUpdate(delta=pre)
|
|
|
+ pre = await app.preprocess(state, event)
|
|
|
+ if pre is not None and not isinstance(pre, List):
|
|
|
+ return pre
|
|
|
|
|
|
# Apply the event to the state.
|
|
|
- update = await state.process(event)
|
|
|
+ updates = pre if pre else await state.process(event)
|
|
|
app.state_manager.set_state(event.token, state)
|
|
|
|
|
|
+ updates = updates if isinstance(updates, List) else [updates]
|
|
|
+
|
|
|
# Postprocess the event.
|
|
|
- post = app.postprocess(state, event, update.delta)
|
|
|
- if post is not None:
|
|
|
- return StateUpdate(delta=post)
|
|
|
+ post_list = []
|
|
|
+ for update in updates:
|
|
|
+ post = await app.postprocess(state, event, update.delta) # type: ignore
|
|
|
+ post_list.append(post) if post else None
|
|
|
+
|
|
|
+ if post_list:
|
|
|
+ return [StateUpdate(delta=post) for post in post_list]
|
|
|
|
|
|
# Return the update.
|
|
|
- return update
|
|
|
+ return updates
|
|
|
|
|
|
|
|
|
async def ping() -> str:
|
|
@@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace):
|
|
|
# Get the client IP
|
|
|
client_ip = environ["REMOTE_ADDR"]
|
|
|
|
|
|
- # Process the event.
|
|
|
- update = await process(self.app, event, sid, headers, client_ip)
|
|
|
+ # Process the events.
|
|
|
+ updates = await process(self.app, event, sid, headers, client_ip)
|
|
|
|
|
|
# Emit the event.
|
|
|
- await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
|
|
+ for update in updates:
|
|
|
+ await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
|
|
|
|
|
async def on_ping(self, sid):
|
|
|
"""Event for testing the API endpoint.
|