Sfoglia il codice sorgente

fix subprotocol for granian (#4698)

* fix subprotocol for granian

* use scope subprotocols

* use subprotocols or headers

* separate the logic
Khaleel Al-Adhami 3 mesi fa
parent
commit
64fb78ac5e
1 ha cambiato i file con 17 aggiunte e 11 eliminazioni
  1. 17 11
      reflex/app.py

+ 17 - 11
reflex/app.py

@@ -27,6 +27,7 @@ from typing import (
     Dict,
     Generic,
     List,
+    MutableMapping,
     Optional,
     Set,
     Type,
@@ -410,20 +411,25 @@ class App(MiddlewareMixin, LifespanMixin):
                 def __init__(self, app):
                     self.app = app
 
-                async def __call__(self, scope, receive, send):
+                async def __call__(
+                    self, scope: MutableMapping[str, Any], receive, send
+                ):
                     original_send = send
 
                     async def modified_send(message):
-                        headers = dict(scope["headers"])
-                        protocol_key = b"sec-websocket-protocol"
-                        if (
-                            message["type"] == "websocket.accept"
-                            and protocol_key in headers
-                        ):
-                            message["headers"] = [
-                                *message.get("headers", []),
-                                (b"sec-websocket-protocol", headers[protocol_key]),
-                            ]
+                        if message["type"] == "websocket.accept":
+                            if scope.get("subprotocols"):
+                                # The following *does* say "subprotocol" instead of "subprotocols", intentionally.
+                                message["subprotocol"] = scope["subprotocols"][0]
+
+                            headers = dict(message.get("headers", []))
+                            header_key = b"sec-websocket-protocol"
+                            if subprotocol := headers.get(header_key):
+                                message["headers"] = [
+                                    *message.get("headers", []),
+                                    (header_key, subprotocol),
+                                ]
+
                         return await original_send(message)
 
                     return await self.app(scope, receive, modified_send)