瀏覽代碼

json patch

Khaleel Al-Adhami 2 月之前
父節點
當前提交
c3e6644221

文件差異過大導致無法顯示
+ 71 - 124
poetry.lock


+ 1 - 0
pyproject.toml

@@ -48,6 +48,7 @@ twine = ">=4.0.0,<7.0"
 tomlkit = ">=0.12.4,<1.0"
 lazy_loader = ">=0.4"
 typing_extensions = ">=4.6.0"
+jsonpatch = "^1.33"
 
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.1.2,<9.0"

+ 29 - 18
reflex/.templates/web/utils/state.js

@@ -16,6 +16,7 @@ import {
 } from "$/utils/context.js";
 import debounce from "$/utils/helpers/debounce";
 import throttle from "$/utils/helpers/throttle";
+import { applyPatch } from "fast-json-patch/index.mjs";
 
 // Endpoint URLs.
 const EVENTURL = env.EVENT;
@@ -227,8 +228,8 @@ export const applyEvent = async (event, socket) => {
       a.href = eval?.(
         event.payload.url.replace(
           "getBackendURL(env.UPLOAD)",
-          `"${getBackendURL(env.UPLOAD)}"`,
-        ),
+          `"${getBackendURL(env.UPLOAD)}"`
+        )
       );
     }
     a.download = event.payload.filename;
@@ -341,7 +342,7 @@ export const applyRestEvent = async (event, socket) => {
       event.payload.files,
       event.payload.upload_id,
       event.payload.on_upload_progress,
-      socket,
+      socket
     );
     return false;
   }
@@ -408,7 +409,7 @@ export const connect = async (
   dispatch,
   transports,
   setConnectErrors,
-  client_storage = {},
+  client_storage = {}
 ) => {
   // Get backend URL object from the endpoint.
   const endpoint = getBackendURL(EVENTURL);
@@ -464,10 +465,20 @@ export const connect = async (
     window.removeEventListener("pagehide", pagehideHandler);
   });
 
+  const last_substate_info = {};
+
   // On each received message, queue the updates and events.
   socket.current.on("event", async (update) => {
     for (const substate in update.delta) {
-      dispatch[substate](update.delta[substate]);
+      console.log(last_substate_info[substate]);
+      const new_substate_info = update.delta[substate].__patch
+        ? applyPatch(
+            last_substate_info[substate],
+            update.delta[substate].__patch
+          ).newDocument
+        : update.delta[substate];
+      last_substate_info[substate] = new_substate_info;
+      dispatch[substate](new_substate_info);
     }
     applyClientStorageDelta(client_storage, update.delta);
     event_processing = !update.final;
@@ -499,7 +510,7 @@ export const uploadFiles = async (
   files,
   upload_id,
   on_upload_progress,
-  socket,
+  socket
 ) => {
   // return if there's no file to upload
   if (files === undefined || files.length === 0) {
@@ -604,7 +615,7 @@ export const Event = (
   name,
   payload = {},
   event_actions = {},
-  handler = null,
+  handler = null
 ) => {
   return { name, payload, handler, event_actions };
 };
@@ -631,7 +642,7 @@ export const hydrateClientStorage = (client_storage) => {
     for (const state_key in client_storage.local_storage) {
       const options = client_storage.local_storage[state_key];
       const local_storage_value = localStorage.getItem(
-        options.name || state_key,
+        options.name || state_key
       );
       if (local_storage_value !== null) {
         client_storage_values[state_key] = local_storage_value;
@@ -642,7 +653,7 @@ export const hydrateClientStorage = (client_storage) => {
     for (const state_key in client_storage.session_storage) {
       const session_options = client_storage.session_storage[state_key];
       const session_storage_value = sessionStorage.getItem(
-        session_options.name || state_key,
+        session_options.name || state_key
       );
       if (session_storage_value != null) {
         client_storage_values[state_key] = session_storage_value;
@@ -667,7 +678,7 @@ export const hydrateClientStorage = (client_storage) => {
 const applyClientStorageDelta = (client_storage, delta) => {
   // find the main state and check for is_hydrated
   const unqualified_states = Object.keys(delta).filter(
-    (key) => key.split(".").length === 1,
+    (key) => key.split(".").length === 1
   );
   if (unqualified_states.length === 1) {
     const main_state = delta[unqualified_states[0]];
@@ -701,7 +712,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
         const session_options = client_storage.session_storage[state_key];
         sessionStorage.setItem(
           session_options.name || state_key,
-          delta[substate][key],
+          delta[substate][key]
         );
       }
     }
@@ -721,7 +732,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
 export const useEventLoop = (
   dispatch,
   initial_events = () => [],
-  client_storage = {},
+  client_storage = {}
 ) => {
   const socket = useRef(null);
   const router = useRouter();
@@ -735,7 +746,7 @@ export const useEventLoop = (
 
     event_actions = events.reduce(
       (acc, e) => ({ ...acc, ...e.event_actions }),
-      event_actions ?? {},
+      event_actions ?? {}
     );
 
     const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
@@ -763,7 +774,7 @@ export const useEventLoop = (
       debounce(
         combined_name,
         () => queueEvents(events, socket),
-        event_actions.debounce,
+        event_actions.debounce
       );
     } else {
       queueEvents(events, socket);
@@ -782,7 +793,7 @@ export const useEventLoop = (
             query,
             asPath,
           }))(router),
-        })),
+        }))
       );
       sentHydrate.current = true;
     }
@@ -828,7 +839,7 @@ export const useEventLoop = (
           dispatch,
           ["websocket"],
           setConnectErrors,
-          client_storage,
+          client_storage
         );
       }
     }
@@ -876,7 +887,7 @@ export const useEventLoop = (
         vars[storage_to_state_map[e.key]] = e.newValue;
         const event = Event(
           `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`,
-          { vars: vars },
+          { vars: vars }
         );
         addEvents([event], e);
       }
@@ -969,7 +980,7 @@ export const getRefValues = (refs) => {
   return refs.map((ref) =>
     ref.current
       ? ref.current.value || ref.current.getAttribute("aria-valuenow")
-      : null,
+      : null
   );
 };
 

+ 1 - 1
reflex/app.py

@@ -1407,7 +1407,7 @@ class App(MiddlewareMixin, LifespanMixin):
         async with self.state_manager.modify_state(token) as state:
             # No other event handler can modify the state while in this context.
             yield state
-            delta = state.get_delta()
+            delta = state.get_delta(token=token)
             if delta:
                 # When the state is modified reset dirty status and emit the delta to the frontend.
                 state._clean()

+ 4 - 4
reflex/compiler/utils.py

@@ -27,7 +27,7 @@ from reflex.components.base import (
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
-from reflex.state import BaseState, _resolve_delta
+from reflex.state import BaseState, StateDelta, _resolve_delta
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.exec import is_in_app_harness
@@ -187,7 +187,7 @@ def compile_state(state: Type[BaseState]) -> dict:
     Returns:
         A dictionary of the compiled state.
     """
-    initial_state = state(_reflex_internal_init=True).dict(initial=True)
+    initial_state = StateDelta(state(_reflex_internal_init=True).dict(initial=True))
     try:
         _ = asyncio.get_running_loop()
     except RuntimeError:
@@ -202,10 +202,10 @@ def compile_state(state: Type[BaseState]) -> dict:
                 console.warn(
                     f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
                 )
-                return resolved_initial_state
+                return resolved_initial_state.data
 
     # Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
-    return asyncio.run(_resolve_delta(initial_state))
+    return asyncio.run(_resolve_delta(initial_state)).data
 
 
 def _compile_client_storage_field(

+ 3 - 0
reflex/config.py

@@ -718,6 +718,9 @@ class EnvironmentVariables:
     # Used by flexgen to enumerate the pages.
     REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
 
+    # Use the JSON patch format for websocket messages.
+    REFLEX_USE_JSON_PATCH: EnvVar[bool] = env_var(False)
+
 
 environment = EnvironmentVariables()
 

+ 1 - 0
reflex/constants/installer.py

@@ -188,6 +188,7 @@ class PackageJson(SimpleNamespace):
         "react-dom": "19.0.0",
         "react-focus-lock": "2.13.6",
         "socket.io-client": "4.8.1",
+        "fast-json-patch": "3.1.1",
         "universal-cookie": "7.2.2",
     }
     DEV_DEPENDENCIES = {

+ 8 - 2
reflex/middleware/hydrate_middleware.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
 from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
-from reflex.state import BaseState, StateUpdate, _resolve_delta
+from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta
 
 if TYPE_CHECKING:
     from reflex.app import App
@@ -42,7 +42,13 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
         # Get the initial state.
-        delta = await _resolve_delta(state.dict())
+        delta = await _resolve_delta(
+            StateDelta(
+                state.dict(),
+                reflex_delta_token=state.router.session.client_token,
+                flush=True,
+            )
+        )
         # since a full dict was captured, clean any dirtiness
         state._clean()
 

+ 135 - 6
reflex/state.py

@@ -40,6 +40,7 @@ from typing import (
 
 import pydantic.v1 as pydantic
 import wrapt
+from jsonpatch import make_patch
 from pydantic import BaseModel as BaseModelV2
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import validator
@@ -108,10 +109,135 @@ if TYPE_CHECKING:
     from reflex.components.component import Component
 
 
-Delta = dict[str, Any]
 var = computed_var
 
 
+@dataclasses.dataclass
+class StateDelta:
+    """A dictionary representing the state delta."""
+
+    data: dict[str, Any] = dataclasses.field(default_factory=dict)
+    reflex_delta_token: str | None = dataclasses.field(default=None)
+    flush: bool = dataclasses.field(default=False)
+
+    def __getitem__(self, key: str) -> Any:
+        """Get the item from the delta.
+
+        Args:
+            key: The key to get.
+
+        Returns:
+            The item from the delta.
+        """
+        return self.data[key]
+
+    def __setitem__(self, key: str, value: Any):
+        """Set the item in the delta.
+
+        Args:
+            key: The key to set.
+            value: The value to set.
+        """
+        self.data[key] = value
+
+    def __delitem__(self, key: str):
+        """Delete the item from the delta.
+
+        Args:
+            key: The key to delete.
+        """
+        del self.data[key]
+
+    def __iter__(self) -> Any:
+        """Iterate over the delta.
+
+        Returns:
+            The iterator over the delta.
+        """
+        return iter(self.data)
+
+    def __len__(self) -> int:
+        """Get the length of the delta.
+
+        Returns:
+            The length of the delta.
+        """
+        return len(self.data)
+
+    def __contains__(self, key: str) -> bool:
+        """Check if the delta contains the key.
+
+        Args:
+            key: The key to check.
+
+        Returns:
+            Whether the delta contains the key.
+        """
+        return key in self.data
+
+    def keys(self):
+        """Get the keys of the delta.
+
+        Returns:
+            The keys of the delta.
+        """
+        return self.data.keys()
+
+    def __reversed__(self):
+        """Reverse the delta.
+
+        Returns:
+            The reversed delta.
+        """
+        return reversed(self.data)
+
+    def values(self):
+        """Get the values of the delta.
+
+        Returns:
+            The values of the delta.
+        """
+        return self.data.values()
+
+    def items(self):
+        """Get the items of the delta.
+
+        Returns:
+            The items of the delta.
+        """
+        return self.data.items()
+
+
+LAST_DELTA_CACHE: dict[str, StateDelta] = {}
+
+
+@serializer(to=dict)
+def serialize_state_delta(delta: StateDelta) -> dict[str, Any]:
+    """Serialize the state delta.
+
+    Args:
+        delta: The state delta to serialize.
+
+    Returns:
+        The serialized state delta.
+    """
+    if delta.reflex_delta_token is not None and environment.REFLEX_USE_JSON_PATCH.get():
+        full_delta = {}
+        for state_name, new_state_value in delta.items():
+            new_state_value = json.loads(format.json_dumps(new_state_value))
+            key = delta.reflex_delta_token + state_name
+            previous_delta = LAST_DELTA_CACHE.get(key)
+            LAST_DELTA_CACHE[key] = new_state_value
+            if previous_delta is not None and not delta.flush:
+                full_delta[state_name] = {
+                    "__patch": make_patch(previous_delta, new_state_value).patch
+                }
+            else:
+                full_delta[state_name] = new_state_value
+        return full_delta
+    return delta.data
+
+
 if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
     # If the state is this large, it's considered a performance issue.
     TOO_LARGE_SERIALIZED_STATE = environment.REFLEX_STATE_SIZE_LIMIT.get() * 1024
@@ -306,7 +432,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     )
 
 
-async def _resolve_delta(delta: Delta) -> Delta:
+async def _resolve_delta(delta: StateDelta) -> StateDelta:
     """Await all coroutines in the delta.
 
     Args:
@@ -1679,7 +1805,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         try:
             # Get the delta after processing the event.
-            delta = await _resolve_delta(state.get_delta())
+            delta = await _resolve_delta(state.get_delta(token=token))
             state._clean()
 
             return StateUpdate(
@@ -1888,9 +2014,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if include_backend or not self.computed_vars[cvar]._backend
         }
 
-    def get_delta(self) -> Delta:
+    def get_delta(self, *, token: str | None = None) -> StateDelta:
         """Get the delta for the state.
 
+        Args:
+            token: The reflex delta
+
         Returns:
             The delta for the state.
         """
@@ -1922,7 +2051,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             delta.update(substates[substate].get_delta())
 
         # Return the delta.
-        return delta
+        return StateDelta(delta, reflex_delta_token=token)
 
     def _mark_dirty(self):
         """Mark the substate and all parent states as dirty."""
@@ -2753,7 +2882,7 @@ class StateUpdate:
     """A state update sent to the frontend."""
 
     # The state delta.
-    delta: Delta = dataclasses.field(default_factory=dict)
+    delta: StateDelta = dataclasses.field(default_factory=StateDelta)
 
     # Events to be added to the event queue.
     events: list[Event] = dataclasses.field(default_factory=list)

+ 30 - 21
tests/units/test_app.py

@@ -41,6 +41,7 @@ from reflex.state import (
     OnLoadInternalState,
     RouterData,
     State,
+    StateDelta,
     StateManagerDisk,
     StateManagerMemory,
     StateManagerRedis,
@@ -1050,14 +1051,16 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         update = await process_coro.__anext__()
         # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    arg_name: exp_val,
-                    f"comp_{arg_name}": exp_val,
-                    constants.CompileVars.IS_HYDRATED: False,
-                    "router": exp_router,
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        arg_name: exp_val,
+                        f"comp_{arg_name}": exp_val,
+                        constants.CompileVars.IS_HYDRATED: False,
+                        "router": exp_router,
+                    }
                 }
-            },
+            ),
             events=[
                 _dynamic_state_event(
                     name="on_load",
@@ -1093,11 +1096,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         on_load_update = await process_coro.__anext__()
         assert on_load_update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "loaded": exp_index + 1,
-                },
-            },
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "loaded": exp_index + 1,
+                    },
+                }
+            ),
             events=[],
         )
         # complete the processing
@@ -1114,11 +1119,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         on_set_is_hydrated_update = await process_coro.__anext__()
         assert on_set_is_hydrated_update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "is_hydrated": True,
-                },
-            },
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "is_hydrated": True,
+                    },
+                }
+            ),
             events=[],
         )
         # complete the processing
@@ -1135,11 +1142,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         update = await process_coro.__anext__()
         assert update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "counter": exp_index + 1,
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "counter": exp_index + 1,
+                    }
                 }
-            },
+            ),
             events=[],
         )
         # complete the processing

+ 52 - 39
tests/units/test_state.py

@@ -34,6 +34,7 @@ from reflex.state import (
     OnLoadInternalState,
     RouterData,
     State,
+    StateDelta,
     StateManager,
     StateManagerDisk,
     StateManagerMemory,
@@ -2012,14 +2013,16 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     mcall = mock_app.event_namespace.emit.mock_calls[0]  # pyright: ignore [reportFunctionMemberAccess]
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
-        delta={
-            grandchild_state.get_full_name(): {
-                "value2": "42",
-            },
-            GrandchildState3.get_full_name(): {
-                "computed": "",
-            },
-        }
+        delta=StateDelta(
+            {
+                grandchild_state.get_full_name(): {
+                    "value2": "42",
+                },
+                GrandchildState3.get_full_name(): {
+                    "computed": "",
+                },
+            }
+        )
     )
     assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
 
@@ -2174,18 +2177,20 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     ):
         # other task returns delta
         assert update == StateUpdate(
-            delta={
-                BackgroundTaskState.get_full_name(): {
-                    "order": [
-                        "background_task:start",
-                        "other",
-                    ],
-                    "computed_order": [
-                        "background_task:start",
-                        "other",
-                    ],
+            delta=StateDelta(
+                {
+                    BackgroundTaskState.get_full_name(): {
+                        "order": [
+                            "background_task:start",
+                            "other",
+                        ],
+                        "computed_order": [
+                            "background_task:start",
+                            "other",
+                        ],
+                    }
                 }
-            }
+            )
         )
 
     # Explicit wait for background tasks
@@ -2216,42 +2221,50 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         is not None
     )
     assert first_ws_message == StateUpdate(
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "order": ["background_task:start"],
-                "computed_order": ["background_task:start"],
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "order": ["background_task:start"],
+                    "computed_order": ["background_task:start"],
+                }
             }
-        },
+        ),
         events=[],
         final=True,
     )
     for call in emit_mock.mock_calls[1:5]:  # pyright: ignore [reportFunctionMemberAccess]
         assert call.args[1] == StateUpdate(
-            delta={
-                BackgroundTaskState.get_full_name(): {
-                    "computed_order": ["background_task:start"],
+            delta=StateDelta(
+                {
+                    BackgroundTaskState.get_full_name(): {
+                        "computed_order": ["background_task:start"],
+                    }
                 }
-            },
+            ),
             events=[],
             final=True,
         )
     assert emit_mock.mock_calls[-2].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "order": exp_order,
-                "computed_order": exp_order,
-                "dict_list": {},
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "order": exp_order,
+                    "computed_order": exp_order,
+                    "dict_list": {},
+                }
             }
-        },
+        ),
         events=[],
         final=True,
     )
     assert emit_mock.mock_calls[-1].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "computed_order": exp_order,
-            },
-        },
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "computed_order": exp_order,
+                },
+            }
+        ),
         events=[],
         final=True,
     )

部分文件因文件數量過多而無法顯示