فهرست منبع

No state No Websocket (#1950)

Elijah Ahianyo 1 سال پیش
والد
کامیت
433ccda3a6
3فایلهای تغییر یافته به همراه41 افزوده شده و 36 حذف شده
  1. 12 10
      reflex/.templates/web/utils/state.js
  2. 23 22
      reflex/app.py
  3. 6 4
      tests/test_state.py

+ 12 - 10
reflex/.templates/web/utils/state.js

@@ -507,17 +507,19 @@ export const useEventLoop = (
     if (!router.isReady) {
       return;
     }
-
-    // Initialize the websocket connection.
-    if (!socket.current) {
-      connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
-    }
-    (async () => {
-      // Process all outstanding events.
-      while (event_queue.length > 0 && !event_processing) {
-        await processEvent(socket.current)
+    // only use websockets if state is present
+    if (Object.keys(state).length > 0) {
+      // Initialize the websocket connection.
+      if (!socket.current) {
+        connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
       }
-    })()
+      (async () => {
+        // Process all outstanding events.
+        while (event_queue.length > 0 && !event_processing) {
+          await processEvent(socket.current)
+        }
+      })()
+    }
   })
   return [state, addEvents, connectError]
 }

+ 23 - 22
reflex/app.py

@@ -175,32 +175,33 @@ class App(Base):
         self.add_cors()
         self.add_default_endpoints()
 
-        # Set up the Socket.IO AsyncServer.
-        self.sio = AsyncServer(
-            async_mode="asgi",
-            cors_allowed_origins="*"
-            if config.cors_allowed_origins == ["*"]
-            else config.cors_allowed_origins,
-            cors_credentials=True,
-            max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
-            ping_interval=constants.Ping.INTERVAL,
-            ping_timeout=constants.Ping.TIMEOUT,
-        )
+        if self.state is not DefaultState:
+            # Set up the Socket.IO AsyncServer.
+            self.sio = AsyncServer(
+                async_mode="asgi",
+                cors_allowed_origins="*"
+                if config.cors_allowed_origins == ["*"]
+                else config.cors_allowed_origins,
+                cors_credentials=True,
+                max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
+                ping_interval=constants.Ping.INTERVAL,
+                ping_timeout=constants.Ping.TIMEOUT,
+            )
 
-        # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
-        self.socket_app = ASGIApp(self.sio, socketio_path="")
-        namespace = config.get_event_namespace()
+            # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
+            self.socket_app = ASGIApp(self.sio, socketio_path="")
+            namespace = config.get_event_namespace()
 
-        if not namespace:
-            raise ValueError("event namespace must be provided in the config.")
+            if not namespace:
+                raise ValueError("event namespace must be provided in the config.")
 
-        # Create the event namespace and attach the main app. Not related to any paths.
-        self.event_namespace = EventNamespace(namespace, self)
+            # Create the event namespace and attach the main app. Not related to any paths.
+            self.event_namespace = EventNamespace(namespace, self)
 
-        # Register the event namespace with the socket.
-        self.sio.register_namespace(self.event_namespace)
-        # Mount the socket app with the API.
-        self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
+            # Register the event namespace with the socket.
+            self.sio.register_namespace(self.event_namespace)
+            # Mount the socket app with the API.
+            self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
 
         # Set up the admin dash.
         self.setup_admin_dash()

+ 6 - 4
tests/test_state.py

@@ -14,6 +14,7 @@ import pytest
 from plotly.graph_objects import Figure
 
 import reflex as rx
+from reflex.app import App
 from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
@@ -1528,23 +1529,24 @@ async def test_state_manager_lock_expire_contend(
 
 
 @pytest.fixture(scope="function")
-def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
+def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
     """Mock app fixture.
 
     Args:
         monkeypatch: Pytest monkeypatch object.
-        app: An app.
         state_manager: A state manager.
 
     Returns:
         The app, after mocking out prerequisites.get_app()
     """
+    app = App(state=TestState)
+
     app_module = Mock()
+
     setattr(app_module, CompileVars.APP, app)
     app.state = TestState
     app.state_manager = state_manager
-    assert app.event_namespace is not None
-    app.event_namespace.emit = AsyncMock()
+    app.event_namespace.emit = AsyncMock()  # type: ignore
     monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
     return app