|
@@ -405,7 +405,31 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
self.sio.register_namespace(self.event_namespace)
|
|
self.sio.register_namespace(self.event_namespace)
|
|
# Mount the socket app with the API.
|
|
# Mount the socket app with the API.
|
|
if self.api:
|
|
if self.api:
|
|
- self.api.mount(str(constants.Endpoint.EVENT), socket_app)
|
|
|
|
|
|
+
|
|
|
|
+ class HeaderMiddleware:
|
|
|
|
+ def __init__(self, app):
|
|
|
|
+ self.app = app
|
|
|
|
+
|
|
|
|
+ async def __call__(self, scope, receive, send):
|
|
|
|
+ original_send = send
|
|
|
|
+
|
|
|
|
+ async def modified_send(message):
|
|
|
|
+ headers = dict(scope["headers"])
|
|
|
|
+ protocol_key = b"sec-websocket-protocol"
|
|
|
|
+ if (
|
|
|
|
+ message["type"] == "websocket.accept"
|
|
|
|
+ and protocol_key in headers
|
|
|
|
+ ):
|
|
|
|
+ message["headers"] = [
|
|
|
|
+ *message.get("headers", []),
|
|
|
|
+ (b"sec-websocket-protocol", headers[protocol_key]),
|
|
|
|
+ ]
|
|
|
|
+ return await original_send(message)
|
|
|
|
+
|
|
|
|
+ return await self.app(scope, receive, modified_send)
|
|
|
|
+
|
|
|
|
+ socket_app_with_headers = HeaderMiddleware(socket_app)
|
|
|
|
+ self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
|
|
|
|
|
|
# Check the exception handlers
|
|
# Check the exception handlers
|
|
self._validate_exception_handlers()
|
|
self._validate_exception_handlers()
|