ソースを参照

do not broadcast event to all clients if modify state is misused (#5322)

* do not broadcast event to all clients if modify state is misused

* fix tests

* maybe?
Khaleel Al-Adhami 1 週間 前
コミット
5f2ebf1c06
2 ファイル変更32 行追加4 行削除
  1. 7 0
      reflex/app.py
  2. 25 4
      tests/units/test_state.py

+ 7 - 0
reflex/app.py

@@ -1996,6 +1996,13 @@ class EventNamespace(AsyncNamespace):
             update: The state update to send.
             sid: The Socket.IO session id.
         """
+        if not sid:
+            # If the sid is None, we are not connected to a client. Prevent sending
+            # updates to all clients.
+            return
+        if sid not in self.sid_to_token:
+            console.warn(f"Attempting to send delta to disconnected websocket {sid}")
+            return
         # Creating a task prevents the update from being blocked behind other coroutines.
         await asyncio.create_task(
             self.emit(str(constants.SocketEvent.EVENT), update, to=sid)

+ 25 - 4
tests/units/test_state.py

@@ -1948,21 +1948,37 @@ class ModelDC:
 
 
 @pytest.mark.asyncio
-async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
+async def test_state_proxy(
+    grandchild_state: GrandchildState, mock_app: rx.App, token: str
+):
     """Test that the state proxy works.
 
     Args:
         grandchild_state: A grandchild state.
         mock_app: An app that will be returned by `get_app()`
+        token: A token.
     """
     child_state = grandchild_state.parent_state
     assert child_state is not None
     parent_state = child_state.parent_state
     assert parent_state is not None
+    router_data = RouterData({"query": {}, "token": token, "sid": "test_sid"})
+    grandchild_state.router = router_data
+    namespace = mock_app.event_namespace
+    assert namespace is not None
+    namespace.sid_to_token[router_data.session.session_id] = token
     if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
         mock_app.state_manager.states[parent_state.router.session.client_token] = (
             parent_state
         )
+    elif isinstance(mock_app.state_manager, StateManagerRedis):
+        pickle_state = parent_state._serialize()
+        if pickle_state:
+            await mock_app.state_manager.redis.set(
+                _substate_key(parent_state.router.session.client_token, parent_state),
+                pickle_state,
+                ex=mock_app.state_manager.token_expiration,
+            )
 
     sp = StateProxy(grandchild_state)
     assert sp.__wrapped__ == grandchild_state
@@ -2029,6 +2045,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
         delta={
+            TestState.get_full_name(): {"router": router_data},
             grandchild_state.get_full_name(): {
                 "value2": "42",
             },
@@ -2154,7 +2171,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         mock_app: An app that will be returned by `get_app()`
         token: A token.
     """
-    router_data = {"query": {}}
+    router_data = {"query": {}, "token": token}
+    sid = "test_sid"
+    namespace = mock_app.event_namespace
+    assert namespace is not None
+    namespace.sid_to_token[sid] = token
     mock_app.state_manager.state = mock_app._state = BackgroundTaskState
     async for update in rx.app.process(
         mock_app,
@@ -2164,7 +2185,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
             router_data=router_data,
             payload={},
         ),
-        sid="",
+        sid=sid,
         headers={},
         client_ip="",
     ):
@@ -2184,7 +2205,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
             router_data=router_data,
             payload={},
         ),
-        sid="",
+        sid=sid,
         headers={},
         client_ip="",
     ):