Ver Fonte

Fixed unused kwargs in reflex.App (#3170)

Angelina Sheyko há 1 ano atrás
pai
commit
4b6a7ff8e3
1 ficheiros alterados com 25 adições e 18 exclusões
  1. 25 18
      reflex/app.py

+ 25 - 18
reflex/app.py

@@ -116,9 +116,6 @@ class App(Base):
     # The Socket.IO AsyncServer.
     # The Socket.IO AsyncServer.
     sio: Optional[AsyncServer] = None
     sio: Optional[AsyncServer] = None
 
 
-    # The socket app.
-    socket_app: Optional[ASGIApp] = None
-
     # The state class to use for the app.
     # The state class to use for the app.
     state: Optional[Type[BaseState]] = None
     state: Optional[Type[BaseState]] = None
 
 
@@ -213,7 +210,11 @@ class App(Base):
             self.setup_state()
             self.setup_state()
 
 
     def setup_state(self) -> None:
     def setup_state(self) -> None:
-        """Set up the state for the app."""
+        """Set up the state for the app.
+
+        Raises:
+            RuntimeError: If custom `sio` does not use `async_mode='asgi'`.
+        """
         if not self.state:
         if not self.state:
             return
             return
 
 
@@ -223,21 +224,27 @@ class App(Base):
         self._state_manager = StateManager.create(state=self.state)
         self._state_manager = StateManager.create(state=self.state)
 
 
         # Set up the Socket.IO AsyncServer.
         # Set up the Socket.IO AsyncServer.
-        self.sio = AsyncServer(
-            async_mode="asgi",
-            cors_allowed_origins=(
-                "*"
-                if config.cors_allowed_origins == ["*"]
-                else config.cors_allowed_origins
-            ),
-            cors_credentials=True,
-            max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
-            ping_interval=constants.Ping.INTERVAL,
-            ping_timeout=constants.Ping.TIMEOUT,
-        )
+        if not self.sio:
+            self.sio = AsyncServer(
+                async_mode="asgi",
+                cors_allowed_origins=(
+                    "*"
+                    if config.cors_allowed_origins == ["*"]
+                    else config.cors_allowed_origins
+                ),
+                cors_credentials=True,
+                max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
+                ping_interval=constants.Ping.INTERVAL,
+                ping_timeout=constants.Ping.TIMEOUT,
+            )
+        elif getattr(self.sio, "async_mode", "") != "asgi":
+            raise RuntimeError(
+                f"Custom `sio` must use `async_mode='asgi'`, not '{self.sio.async_mode}'."
+            )
 
 
         # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
         # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
-        self.socket_app = ASGIApp(self.sio, socketio_path="")
+        socket_app = ASGIApp(self.sio, socketio_path="")
+
         namespace = config.get_event_namespace()
         namespace = config.get_event_namespace()
 
 
         # Create the event namespace and attach the main app. Not related to any paths.
         # Create the event namespace and attach the main app. Not related to any paths.
@@ -246,7 +253,7 @@ class App(Base):
         # Register the event namespace with the socket.
         # Register the event namespace with the socket.
         self.sio.register_namespace(self.event_namespace)
         self.sio.register_namespace(self.event_namespace)
         # Mount the socket app with the API.
         # Mount the socket app with the API.
-        self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
+        self.api.mount(str(constants.Endpoint.EVENT), socket_app)
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Get the string representation of the app.
         """Get the string representation of the app.