浏览代码

return websocket protocol when asked (#4683)

Khaleel Al-Adhami 4 月之前
父节点
当前提交
abc9038580
共有 2 个文件被更改,包括 26 次插入2 次删除
  1. 1 1
      reflex/.templates/web/utils/state.js
  2. 25 1
      reflex/app.py

+ 1 - 1
reflex/.templates/web/utils/state.js

@@ -408,7 +408,7 @@ export const connect = async (
   socket.current = io(endpoint.href, {
     path: endpoint["pathname"],
     transports: transports,
-    protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version],
+    protocols: [reflexEnvironment.version],
     autoUnref: false,
   });
   // Ensure undefined fields in events are sent as null instead of removed

+ 25 - 1
reflex/app.py

@@ -405,7 +405,31 @@ class App(MiddlewareMixin, LifespanMixin):
         self.sio.register_namespace(self.event_namespace)
         # Mount the socket app with the API.
         if self.api:
-            self.api.mount(str(constants.Endpoint.EVENT), socket_app)
+
+            class HeaderMiddleware:
+                def __init__(self, app):
+                    self.app = app
+
+                async def __call__(self, scope, receive, send):
+                    original_send = send
+
+                    async def modified_send(message):
+                        headers = dict(scope["headers"])
+                        protocol_key = b"sec-websocket-protocol"
+                        if (
+                            message["type"] == "websocket.accept"
+                            and protocol_key in headers
+                        ):
+                            message["headers"] = [
+                                *message.get("headers", []),
+                                (b"sec-websocket-protocol", headers[protocol_key]),
+                            ]
+                        return await original_send(message)
+
+                    return await self.app(scope, receive, modified_send)
+
+            socket_app_with_headers = HeaderMiddleware(socket_app)
+            self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
 
         # Check the exception handlers
         self._validate_exception_handlers()