
Avoid double JSON encode/decode for socket.io (#4449)

* Avoid double JSON encode/decode for socket.io

socket.io (Python and JS) already has a built-in mechanism for JSON encoding
and decoding messages over the websocket. To use it, we pass a custom `json`
namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the
messages. This avoids sending a JSON-encoded string of JSON over the wire, and
reduces the number of serialization/deserialization passes over the message
data.
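
For illustration, a minimal standalone sketch (stdlib only, not code from this change) of what "a JSON-encoded string of JSON" looks like on the wire compared to a single encode:

```python
import json

delta = {"state": {"count": 1}}

# Before: the update was pre-stringified, then socket.io JSON-encoded that
# string again, so the payload arrived as an escaped string-of-JSON.
double_encoded = json.dumps(json.dumps(delta))
# After: the raw object is handed to socket.io, which encodes it exactly once.
single_encoded = json.dumps(delta)

print(double_encoded)  # "{\"state\": {\"count\": 1}}"  <- escaped string payload
print(single_encoded)  # {"state": {"count": 1}}        <- plain JSON object
```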

A side benefit is that browser dev tools now display received websocket messages
as a parsed JSON hierarchy, which is much easier to inspect when debugging.

* JSON5.parse in on_upload_progress handler responses
Masen Furer, 5 months ago
commit a2f14e7713

+ 30 - 21
reflex/.templates/web/utils/state.js

@@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
   if (socket) {
     socket.emit(
       "event",
-      JSON.stringify(event, (k, v) => (v === undefined ? null : v))
+      event,
     );
     return true;
   }
@@ -407,6 +407,8 @@ export const connect = async (
     transports: transports,
     autoUnref: false,
   });
+  // Ensure undefined fields in events are sent as null instead of removed
+  socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
 
   function checkVisibility() {
     if (document.visibilityState === "visible") {
@@ -443,8 +445,7 @@ export const connect = async (
   });
 
   // On each received message, queue the updates and events.
-  socket.current.on("event", async (message) => {
-    const update = JSON5.parse(message);
+  socket.current.on("event", async (update) => {
     for (const substate in update.delta) {
       dispatch[substate](update.delta[substate]);
     }
@@ -456,7 +457,7 @@ export const connect = async (
   });
   socket.current.on("reload", async (event) => {
     event_processing = false;
-    queueEvents([...initialEvents(), JSON5.parse(event)], socket);
+    queueEvents([...initialEvents(), event], socket);
   });
 
   document.addEventListener("visibilitychange", checkVisibility);
@@ -497,23 +498,31 @@ export const uploadFiles = async (
     // Whenever called, responseText will contain the entire response so far.
     const chunks = progressEvent.event.target.responseText.trim().split("\n");
     // So only process _new_ chunks beyond resp_idx.
-    chunks.slice(resp_idx).map((chunk) => {
-      event_callbacks.map((f, ix) => {
-        f(chunk)
-          .then(() => {
-            if (ix === event_callbacks.length - 1) {
-              // Mark this chunk as processed.
-              resp_idx += 1;
-            }
-          })
-          .catch((e) => {
-            if (progressEvent.progress === 1) {
-              // Chunk may be incomplete, so only report errors when full response is available.
-              console.log("Error parsing chunk", chunk, e);
-            }
-            return;
-          });
-      });
+    chunks.slice(resp_idx).map((chunk_json) => {
+      try {
+        const chunk = JSON5.parse(chunk_json);
+        event_callbacks.map((f, ix) => {
+          f(chunk)
+            .then(() => {
+              if (ix === event_callbacks.length - 1) {
+                // Mark this chunk as processed.
+                resp_idx += 1;
+              }
+            })
+            .catch((e) => {
+              if (progressEvent.progress === 1) {
+                // Chunk may be incomplete, so only report errors when full response is available.
+                console.log("Error processing chunk", chunk, e);
+              }
+              return;
+            });
+        });
+      } catch (e) {
+        if (progressEvent.progress === 1) {
+          console.log("Error parsing chunk", chunk_json, e);
+        }
+        return;
+      }
     });
   };
 

+ 7 - 2
reflex/app.py

@@ -17,6 +17,7 @@ import sys
 import traceback
 from datetime import datetime
 from pathlib import Path
+from types import SimpleNamespace
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin):
                 max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
                 ping_interval=constants.Ping.INTERVAL,
                 ping_timeout=constants.Ping.TIMEOUT,
+                json=SimpleNamespace(
+                    dumps=staticmethod(format.json_dumps),
+                    loads=staticmethod(json.loads),
+                ),
                 transports=["websocket"],
             )
         elif getattr(self.sio, "async_mode", "") != "asgi":
@@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace):
         """
         # Creating a task prevents the update from being blocked behind other coroutines.
         await asyncio.create_task(
-            self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
+            self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
         )
 
     async def on_event(self, sid, data):
@@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace):
             sid: The Socket.IO session id.
             data: The event data.
         """
-        fields = json.loads(data)
+        fields = data
         # Get the event.
         event = Event(
             **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
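
For context, python-socketio accepts any object exposing `dumps`/`loads` callables compatible with the standard library as its `json` argument. The sketch below is a hypothetical standalone snippet mirroring the `json=SimpleNamespace(...)` hunk above (the actual change additionally wraps the callables in `staticmethod`), shown here in isolation:

```python
import json
from types import SimpleNamespace

import socketio

from reflex.utils import format

# Any object with json.dumps/json.loads-compatible callables can stand in for
# the json module, so socket.io itself encodes packets with the
# reflex-serializer-aware dumps in a single pass.
custom_json = SimpleNamespace(
    dumps=format.json_dumps,  # handles reflex types via registered serializers
    loads=json.loads,         # plain stdlib decoding is enough for incoming events
)

sio = socketio.AsyncServer(async_mode="asgi", json=custom_json)
```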

+ 6 - 2
reflex/utils/format.py

@@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
     return lib
 
 
-def json_dumps(obj: Any) -> str:
+def json_dumps(obj: Any, **kwargs) -> str:
     """Takes an object and returns a jsonified string.
 
     Args:
         obj: The object to be serialized.
+        kwargs: Additional keyword arguments to pass to json.dumps.
 
     Returns:
         A string
     """
     from reflex.utils import serializers
 
-    return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
+    kwargs.setdefault("ensure_ascii", False)
+    kwargs.setdefault("default", serializers.serialize)
+
+    return json.dumps(obj, **kwargs)
 
 
 def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
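
A brief usage sketch of the extended signature (hypothetical calls, not part of the diff): the `setdefault` calls keep the serializer-aware defaults while letting callers forward or override any standard `json.dumps` keyword argument.

```python
from reflex.utils import format

# Defaults still apply: ensure_ascii=False, default=serializers.serialize.
format.json_dumps({"b": 2, "a": 1})

# Extra kwargs are forwarded straight to json.dumps.
format.json_dumps({"b": 2, "a": 1}, sort_keys=True)

# Because setdefault is used, explicit values win over the built-in defaults.
format.json_dumps({"emoji": "✨"}, ensure_ascii=True)
```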

+ 54 - 38
tests/units/test_state.py

@@ -1840,6 +1840,24 @@ async def test_state_manager_lock_expire_contend(
     assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
 
 
+class CopyingAsyncMock(AsyncMock):
+    """An AsyncMock, but deepcopy the args and kwargs first."""
+
+    def __call__(self, *args, **kwargs):
+        """Call the mock.
+
+        Args:
+            args: the arguments passed to the mock
+            kwargs: the keyword arguments passed to the mock
+
+        Returns:
+            The result of the mock call
+        """
+        args = copy.deepcopy(args)
+        kwargs = copy.deepcopy(kwargs)
+        return super().__call__(*args, **kwargs)
+
+
 @pytest.fixture(scope="function")
 def mock_app_simple(monkeypatch) -> rx.App:
     """Simple Mock app fixture.
@@ -1856,7 +1874,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
 
     setattr(app_module, CompileVars.APP, app)
     app.state = TestState
-    app.event_namespace.emit = AsyncMock()  # type: ignore
+    app.event_namespace.emit = CopyingAsyncMock()  # type: ignore
 
     def _mock_get_app(*args, **kwargs):
         return app_module
@@ -1960,21 +1978,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     mock_app.event_namespace.emit.assert_called_once()
     mcall = mock_app.event_namespace.emit.mock_calls[0]
     assert mcall.args[0] == str(SocketEvent.EVENT)
-    assert json.loads(mcall.args[1]) == dataclasses.asdict(
-        StateUpdate(
-            delta={
-                parent_state.get_full_name(): {
-                    "upper": "",
-                    "sum": 3.14,
-                },
-                grandchild_state.get_full_name(): {
-                    "value2": "42",
-                },
-                GrandchildState3.get_full_name(): {
-                    "computed": "",
-                },
-            }
-        )
+    assert mcall.args[1] == StateUpdate(
+        delta={
+            parent_state.get_full_name(): {
+                "upper": "",
+                "sum": 3.14,
+            },
+            grandchild_state.get_full_name(): {
+                "value2": "42",
+            },
+            GrandchildState3.get_full_name(): {
+                "computed": "",
+            },
+        }
     )
     assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
 
@@ -2156,51 +2172,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
 
-    first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
+    first_ws_message = emit_mock.mock_calls[0].args[1]
     assert (
-        first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
+        first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
         is not None
     )
-    assert first_ws_message == {
-        "delta": {
+    assert first_ws_message == StateUpdate(
+        delta={
             BackgroundTaskState.get_full_name(): {
                 "order": ["background_task:start"],
                 "computed_order": ["background_task:start"],
             }
         },
-        "events": [],
-        "final": True,
-    }
+        events=[],
+        final=True,
+    )
     for call in emit_mock.mock_calls[1:5]:
-        assert json.loads(call.args[1]) == {
-            "delta": {
+        assert call.args[1] == StateUpdate(
+            delta={
                 BackgroundTaskState.get_full_name(): {
                     "computed_order": ["background_task:start"],
                 }
             },
-            "events": [],
-            "final": True,
-        }
-    assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
-        "delta": {
+            events=[],
+            final=True,
+        )
+    assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+        delta={
             BackgroundTaskState.get_full_name(): {
                 "order": exp_order,
                 "computed_order": exp_order,
                 "dict_list": {},
             }
         },
-        "events": [],
-        "final": True,
-    }
-    assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
-        "delta": {
+        events=[],
+        final=True,
+    )
+    assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+        delta={
             BackgroundTaskState.get_full_name(): {
                 "computed_order": exp_order,
             },
         },
-        "events": [],
-        "final": True,
-    }
+        events=[],
+        final=True,
+    )
 
 
 @pytest.mark.asyncio