Bladeren bron

json patch

Khaleel Al-Adhami 2 maanden geleden
bovenliggende
commit
c3e6644221

File diff suppressed because it is too large
+ 71 - 124
poetry.lock


+ 1 - 0
pyproject.toml

@@ -48,6 +48,7 @@ twine = ">=4.0.0,<7.0"
 tomlkit = ">=0.12.4,<1.0"
 tomlkit = ">=0.12.4,<1.0"
 lazy_loader = ">=0.4"
 lazy_loader = ">=0.4"
 typing_extensions = ">=4.6.0"
 typing_extensions = ">=4.6.0"
+jsonpatch = "^1.33"
 
 
 [tool.poetry.group.dev.dependencies]
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.1.2,<9.0"
 pytest = ">=7.1.2,<9.0"

+ 29 - 18
reflex/.templates/web/utils/state.js

@@ -16,6 +16,7 @@ import {
 } from "$/utils/context.js";
 } from "$/utils/context.js";
 import debounce from "$/utils/helpers/debounce";
 import debounce from "$/utils/helpers/debounce";
 import throttle from "$/utils/helpers/throttle";
 import throttle from "$/utils/helpers/throttle";
+import { applyPatch } from "fast-json-patch/index.mjs";
 
 
 // Endpoint URLs.
 // Endpoint URLs.
 const EVENTURL = env.EVENT;
 const EVENTURL = env.EVENT;
@@ -227,8 +228,8 @@ export const applyEvent = async (event, socket) => {
       a.href = eval?.(
       a.href = eval?.(
         event.payload.url.replace(
         event.payload.url.replace(
           "getBackendURL(env.UPLOAD)",
           "getBackendURL(env.UPLOAD)",
-          `"${getBackendURL(env.UPLOAD)}"`,
-        ),
+          `"${getBackendURL(env.UPLOAD)}"`
+        )
       );
       );
     }
     }
     a.download = event.payload.filename;
     a.download = event.payload.filename;
@@ -341,7 +342,7 @@ export const applyRestEvent = async (event, socket) => {
       event.payload.files,
       event.payload.files,
       event.payload.upload_id,
       event.payload.upload_id,
       event.payload.on_upload_progress,
       event.payload.on_upload_progress,
-      socket,
+      socket
     );
     );
     return false;
     return false;
   }
   }
@@ -408,7 +409,7 @@ export const connect = async (
   dispatch,
   dispatch,
   transports,
   transports,
   setConnectErrors,
   setConnectErrors,
-  client_storage = {},
+  client_storage = {}
 ) => {
 ) => {
   // Get backend URL object from the endpoint.
   // Get backend URL object from the endpoint.
   const endpoint = getBackendURL(EVENTURL);
   const endpoint = getBackendURL(EVENTURL);
@@ -464,10 +465,20 @@ export const connect = async (
     window.removeEventListener("pagehide", pagehideHandler);
     window.removeEventListener("pagehide", pagehideHandler);
   });
   });
 
 
+  const last_substate_info = {};
+
   // On each received message, queue the updates and events.
   // On each received message, queue the updates and events.
   socket.current.on("event", async (update) => {
   socket.current.on("event", async (update) => {
     for (const substate in update.delta) {
     for (const substate in update.delta) {
-      dispatch[substate](update.delta[substate]);
+      console.log(last_substate_info[substate]);
+      const new_substate_info = update.delta[substate].__patch
+        ? applyPatch(
+            last_substate_info[substate],
+            update.delta[substate].__patch
+          ).newDocument
+        : update.delta[substate];
+      last_substate_info[substate] = new_substate_info;
+      dispatch[substate](new_substate_info);
     }
     }
     applyClientStorageDelta(client_storage, update.delta);
     applyClientStorageDelta(client_storage, update.delta);
     event_processing = !update.final;
     event_processing = !update.final;
@@ -499,7 +510,7 @@ export const uploadFiles = async (
   files,
   files,
   upload_id,
   upload_id,
   on_upload_progress,
   on_upload_progress,
-  socket,
+  socket
 ) => {
 ) => {
   // return if there's no file to upload
   // return if there's no file to upload
   if (files === undefined || files.length === 0) {
   if (files === undefined || files.length === 0) {
@@ -604,7 +615,7 @@ export const Event = (
   name,
   name,
   payload = {},
   payload = {},
   event_actions = {},
   event_actions = {},
-  handler = null,
+  handler = null
 ) => {
 ) => {
   return { name, payload, handler, event_actions };
   return { name, payload, handler, event_actions };
 };
 };
@@ -631,7 +642,7 @@ export const hydrateClientStorage = (client_storage) => {
     for (const state_key in client_storage.local_storage) {
     for (const state_key in client_storage.local_storage) {
       const options = client_storage.local_storage[state_key];
       const options = client_storage.local_storage[state_key];
       const local_storage_value = localStorage.getItem(
       const local_storage_value = localStorage.getItem(
-        options.name || state_key,
+        options.name || state_key
       );
       );
       if (local_storage_value !== null) {
       if (local_storage_value !== null) {
         client_storage_values[state_key] = local_storage_value;
         client_storage_values[state_key] = local_storage_value;
@@ -642,7 +653,7 @@ export const hydrateClientStorage = (client_storage) => {
     for (const state_key in client_storage.session_storage) {
     for (const state_key in client_storage.session_storage) {
       const session_options = client_storage.session_storage[state_key];
       const session_options = client_storage.session_storage[state_key];
       const session_storage_value = sessionStorage.getItem(
       const session_storage_value = sessionStorage.getItem(
-        session_options.name || state_key,
+        session_options.name || state_key
       );
       );
       if (session_storage_value != null) {
       if (session_storage_value != null) {
         client_storage_values[state_key] = session_storage_value;
         client_storage_values[state_key] = session_storage_value;
@@ -667,7 +678,7 @@ export const hydrateClientStorage = (client_storage) => {
 const applyClientStorageDelta = (client_storage, delta) => {
 const applyClientStorageDelta = (client_storage, delta) => {
   // find the main state and check for is_hydrated
   // find the main state and check for is_hydrated
   const unqualified_states = Object.keys(delta).filter(
   const unqualified_states = Object.keys(delta).filter(
-    (key) => key.split(".").length === 1,
+    (key) => key.split(".").length === 1
   );
   );
   if (unqualified_states.length === 1) {
   if (unqualified_states.length === 1) {
     const main_state = delta[unqualified_states[0]];
     const main_state = delta[unqualified_states[0]];
@@ -701,7 +712,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
         const session_options = client_storage.session_storage[state_key];
         const session_options = client_storage.session_storage[state_key];
         sessionStorage.setItem(
         sessionStorage.setItem(
           session_options.name || state_key,
           session_options.name || state_key,
-          delta[substate][key],
+          delta[substate][key]
         );
         );
       }
       }
     }
     }
@@ -721,7 +732,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
 export const useEventLoop = (
 export const useEventLoop = (
   dispatch,
   dispatch,
   initial_events = () => [],
   initial_events = () => [],
-  client_storage = {},
+  client_storage = {}
 ) => {
 ) => {
   const socket = useRef(null);
   const socket = useRef(null);
   const router = useRouter();
   const router = useRouter();
@@ -735,7 +746,7 @@ export const useEventLoop = (
 
 
     event_actions = events.reduce(
     event_actions = events.reduce(
       (acc, e) => ({ ...acc, ...e.event_actions }),
       (acc, e) => ({ ...acc, ...e.event_actions }),
-      event_actions ?? {},
+      event_actions ?? {}
     );
     );
 
 
     const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
     const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
@@ -763,7 +774,7 @@ export const useEventLoop = (
       debounce(
       debounce(
         combined_name,
         combined_name,
         () => queueEvents(events, socket),
         () => queueEvents(events, socket),
-        event_actions.debounce,
+        event_actions.debounce
       );
       );
     } else {
     } else {
       queueEvents(events, socket);
       queueEvents(events, socket);
@@ -782,7 +793,7 @@ export const useEventLoop = (
             query,
             query,
             asPath,
             asPath,
           }))(router),
           }))(router),
-        })),
+        }))
       );
       );
       sentHydrate.current = true;
       sentHydrate.current = true;
     }
     }
@@ -828,7 +839,7 @@ export const useEventLoop = (
           dispatch,
           dispatch,
           ["websocket"],
           ["websocket"],
           setConnectErrors,
           setConnectErrors,
-          client_storage,
+          client_storage
         );
         );
       }
       }
     }
     }
@@ -876,7 +887,7 @@ export const useEventLoop = (
         vars[storage_to_state_map[e.key]] = e.newValue;
         vars[storage_to_state_map[e.key]] = e.newValue;
         const event = Event(
         const event = Event(
           `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`,
           `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`,
-          { vars: vars },
+          { vars: vars }
         );
         );
         addEvents([event], e);
         addEvents([event], e);
       }
       }
@@ -969,7 +980,7 @@ export const getRefValues = (refs) => {
   return refs.map((ref) =>
   return refs.map((ref) =>
     ref.current
     ref.current
       ? ref.current.value || ref.current.getAttribute("aria-valuenow")
       ? ref.current.value || ref.current.getAttribute("aria-valuenow")
-      : null,
+      : null
   );
   );
 };
 };
 
 

+ 1 - 1
reflex/app.py

@@ -1407,7 +1407,7 @@ class App(MiddlewareMixin, LifespanMixin):
         async with self.state_manager.modify_state(token) as state:
         async with self.state_manager.modify_state(token) as state:
             # No other event handler can modify the state while in this context.
             # No other event handler can modify the state while in this context.
             yield state
             yield state
-            delta = state.get_delta()
+            delta = state.get_delta(token=token)
             if delta:
             if delta:
                 # When the state is modified reset dirty status and emit the delta to the frontend.
                 # When the state is modified reset dirty status and emit the delta to the frontend.
                 state._clean()
                 state._clean()

+ 4 - 4
reflex/compiler/utils.py

@@ -27,7 +27,7 @@ from reflex.components.base import (
 )
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
-from reflex.state import BaseState, _resolve_delta
+from reflex.state import BaseState, StateDelta, _resolve_delta
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.exec import is_in_app_harness
 from reflex.utils.exec import is_in_app_harness
@@ -187,7 +187,7 @@ def compile_state(state: Type[BaseState]) -> dict:
     Returns:
     Returns:
         A dictionary of the compiled state.
         A dictionary of the compiled state.
     """
     """
-    initial_state = state(_reflex_internal_init=True).dict(initial=True)
+    initial_state = StateDelta(state(_reflex_internal_init=True).dict(initial=True))
     try:
     try:
         _ = asyncio.get_running_loop()
         _ = asyncio.get_running_loop()
     except RuntimeError:
     except RuntimeError:
@@ -202,10 +202,10 @@ def compile_state(state: Type[BaseState]) -> dict:
                 console.warn(
                 console.warn(
                     f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
                     f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
                 )
                 )
-                return resolved_initial_state
+                return resolved_initial_state.data
 
 
     # Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
     # Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
-    return asyncio.run(_resolve_delta(initial_state))
+    return asyncio.run(_resolve_delta(initial_state)).data
 
 
 
 
 def _compile_client_storage_field(
 def _compile_client_storage_field(

+ 3 - 0
reflex/config.py

@@ -718,6 +718,9 @@ class EnvironmentVariables:
     # Used by flexgen to enumerate the pages.
     # Used by flexgen to enumerate the pages.
     REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
     REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
 
 
+    # Use the JSON patch format for websocket messages.
+    REFLEX_USE_JSON_PATCH: EnvVar[bool] = env_var(False)
+
 
 
 environment = EnvironmentVariables()
 environment = EnvironmentVariables()
 
 

+ 1 - 0
reflex/constants/installer.py

@@ -188,6 +188,7 @@ class PackageJson(SimpleNamespace):
         "react-dom": "19.0.0",
         "react-dom": "19.0.0",
         "react-focus-lock": "2.13.6",
         "react-focus-lock": "2.13.6",
         "socket.io-client": "4.8.1",
         "socket.io-client": "4.8.1",
+        "fast-json-patch": "3.1.1",
         "universal-cookie": "7.2.2",
         "universal-cookie": "7.2.2",
     }
     }
     DEV_DEPENDENCIES = {
     DEV_DEPENDENCIES = {

+ 8 - 2
reflex/middleware/hydrate_middleware.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
 from reflex import constants
 from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.middleware.middleware import Middleware
-from reflex.state import BaseState, StateUpdate, _resolve_delta
+from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.app import App
     from reflex.app import App
@@ -42,7 +42,13 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
 
         # Get the initial state.
         # Get the initial state.
-        delta = await _resolve_delta(state.dict())
+        delta = await _resolve_delta(
+            StateDelta(
+                state.dict(),
+                reflex_delta_token=state.router.session.client_token,
+                flush=True,
+            )
+        )
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness
         state._clean()
         state._clean()
 
 

+ 135 - 6
reflex/state.py

@@ -40,6 +40,7 @@ from typing import (
 
 
 import pydantic.v1 as pydantic
 import pydantic.v1 as pydantic
 import wrapt
 import wrapt
+from jsonpatch import make_patch
 from pydantic import BaseModel as BaseModelV2
 from pydantic import BaseModel as BaseModelV2
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import validator
 from pydantic.v1 import validator
@@ -108,10 +109,135 @@ if TYPE_CHECKING:
     from reflex.components.component import Component
     from reflex.components.component import Component
 
 
 
 
-Delta = dict[str, Any]
 var = computed_var
 var = computed_var
 
 
 
 
+@dataclasses.dataclass
+class StateDelta:
+    """A dictionary representing the state delta."""
+
+    data: dict[str, Any] = dataclasses.field(default_factory=dict)
+    reflex_delta_token: str | None = dataclasses.field(default=None)
+    flush: bool = dataclasses.field(default=False)
+
+    def __getitem__(self, key: str) -> Any:
+        """Get the item from the delta.
+
+        Args:
+            key: The key to get.
+
+        Returns:
+            The item from the delta.
+        """
+        return self.data[key]
+
+    def __setitem__(self, key: str, value: Any):
+        """Set the item in the delta.
+
+        Args:
+            key: The key to set.
+            value: The value to set.
+        """
+        self.data[key] = value
+
+    def __delitem__(self, key: str):
+        """Delete the item from the delta.
+
+        Args:
+            key: The key to delete.
+        """
+        del self.data[key]
+
+    def __iter__(self) -> Any:
+        """Iterate over the delta.
+
+        Returns:
+            The iterator over the delta.
+        """
+        return iter(self.data)
+
+    def __len__(self) -> int:
+        """Get the length of the delta.
+
+        Returns:
+            The length of the delta.
+        """
+        return len(self.data)
+
+    def __contains__(self, key: str) -> bool:
+        """Check if the delta contains the key.
+
+        Args:
+            key: The key to check.
+
+        Returns:
+            Whether the delta contains the key.
+        """
+        return key in self.data
+
+    def keys(self):
+        """Get the keys of the delta.
+
+        Returns:
+            The keys of the delta.
+        """
+        return self.data.keys()
+
+    def __reversed__(self):
+        """Reverse the delta.
+
+        Returns:
+            The reversed delta.
+        """
+        return reversed(self.data)
+
+    def values(self):
+        """Get the values of the delta.
+
+        Returns:
+            The values of the delta.
+        """
+        return self.data.values()
+
+    def items(self):
+        """Get the items of the delta.
+
+        Returns:
+            The items of the delta.
+        """
+        return self.data.items()
+
+
+LAST_DELTA_CACHE: dict[str, StateDelta] = {}
+
+
+@serializer(to=dict)
+def serialize_state_delta(delta: StateDelta) -> dict[str, Any]:
+    """Serialize the state delta.
+
+    Args:
+        delta: The state delta to serialize.
+
+    Returns:
+        The serialized state delta.
+    """
+    if delta.reflex_delta_token is not None and environment.REFLEX_USE_JSON_PATCH.get():
+        full_delta = {}
+        for state_name, new_state_value in delta.items():
+            new_state_value = json.loads(format.json_dumps(new_state_value))
+            key = delta.reflex_delta_token + state_name
+            previous_delta = LAST_DELTA_CACHE.get(key)
+            LAST_DELTA_CACHE[key] = new_state_value
+            if previous_delta is not None and not delta.flush:
+                full_delta[state_name] = {
+                    "__patch": make_patch(previous_delta, new_state_value).patch
+                }
+            else:
+                full_delta[state_name] = new_state_value
+        return full_delta
+    return delta.data
+
+
 if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
 if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
     # If the state is this large, it's considered a performance issue.
     # If the state is this large, it's considered a performance issue.
     TOO_LARGE_SERIALIZED_STATE = environment.REFLEX_STATE_SIZE_LIMIT.get() * 1024
     TOO_LARGE_SERIALIZED_STATE = environment.REFLEX_STATE_SIZE_LIMIT.get() * 1024
@@ -306,7 +432,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     )
     )
 
 
 
 
-async def _resolve_delta(delta: Delta) -> Delta:
+async def _resolve_delta(delta: StateDelta) -> StateDelta:
     """Await all coroutines in the delta.
     """Await all coroutines in the delta.
 
 
     Args:
     Args:
@@ -1679,7 +1805,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         try:
         try:
             # Get the delta after processing the event.
             # Get the delta after processing the event.
-            delta = await _resolve_delta(state.get_delta())
+            delta = await _resolve_delta(state.get_delta(token=token))
             state._clean()
             state._clean()
 
 
             return StateUpdate(
             return StateUpdate(
@@ -1888,9 +2014,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if include_backend or not self.computed_vars[cvar]._backend
             if include_backend or not self.computed_vars[cvar]._backend
         }
         }
 
 
-    def get_delta(self) -> Delta:
+    def get_delta(self, *, token: str | None = None) -> StateDelta:
         """Get the delta for the state.
         """Get the delta for the state.
 
 
+        Args:
+            token: The reflex delta
+
         Returns:
         Returns:
             The delta for the state.
             The delta for the state.
         """
         """
@@ -1922,7 +2051,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             delta.update(substates[substate].get_delta())
             delta.update(substates[substate].get_delta())
 
 
         # Return the delta.
         # Return the delta.
-        return delta
+        return StateDelta(delta, reflex_delta_token=token)
 
 
     def _mark_dirty(self):
     def _mark_dirty(self):
         """Mark the substate and all parent states as dirty."""
         """Mark the substate and all parent states as dirty."""
@@ -2753,7 +2882,7 @@ class StateUpdate:
     """A state update sent to the frontend."""
     """A state update sent to the frontend."""
 
 
     # The state delta.
     # The state delta.
-    delta: Delta = dataclasses.field(default_factory=dict)
+    delta: StateDelta = dataclasses.field(default_factory=StateDelta)
 
 
     # Events to be added to the event queue.
     # Events to be added to the event queue.
     events: list[Event] = dataclasses.field(default_factory=list)
     events: list[Event] = dataclasses.field(default_factory=list)

+ 30 - 21
tests/units/test_app.py

@@ -41,6 +41,7 @@ from reflex.state import (
     OnLoadInternalState,
     OnLoadInternalState,
     RouterData,
     RouterData,
     State,
     State,
+    StateDelta,
     StateManagerDisk,
     StateManagerDisk,
     StateManagerMemory,
     StateManagerMemory,
     StateManagerRedis,
     StateManagerRedis,
@@ -1050,14 +1051,16 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         update = await process_coro.__anext__()
         update = await process_coro.__anext__()
         # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
         # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
         assert update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    arg_name: exp_val,
-                    f"comp_{arg_name}": exp_val,
-                    constants.CompileVars.IS_HYDRATED: False,
-                    "router": exp_router,
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        arg_name: exp_val,
+                        f"comp_{arg_name}": exp_val,
+                        constants.CompileVars.IS_HYDRATED: False,
+                        "router": exp_router,
+                    }
                 }
                 }
-            },
+            ),
             events=[
             events=[
                 _dynamic_state_event(
                 _dynamic_state_event(
                     name="on_load",
                     name="on_load",
@@ -1093,11 +1096,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         on_load_update = await process_coro.__anext__()
         on_load_update = await process_coro.__anext__()
         assert on_load_update == StateUpdate(
         assert on_load_update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "loaded": exp_index + 1,
-                },
-            },
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "loaded": exp_index + 1,
+                    },
+                }
+            ),
             events=[],
             events=[],
         )
         )
         # complete the processing
         # complete the processing
@@ -1114,11 +1119,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         on_set_is_hydrated_update = await process_coro.__anext__()
         on_set_is_hydrated_update = await process_coro.__anext__()
         assert on_set_is_hydrated_update == StateUpdate(
         assert on_set_is_hydrated_update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "is_hydrated": True,
-                },
-            },
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "is_hydrated": True,
+                    },
+                }
+            ),
             events=[],
             events=[],
         )
         )
         # complete the processing
         # complete the processing
@@ -1135,11 +1142,13 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         update = await process_coro.__anext__()
         update = await process_coro.__anext__()
         assert update == StateUpdate(
         assert update == StateUpdate(
-            delta={
-                state.get_name(): {
-                    "counter": exp_index + 1,
+            delta=StateDelta(
+                {
+                    state.get_name(): {
+                        "counter": exp_index + 1,
+                    }
                 }
                 }
-            },
+            ),
             events=[],
             events=[],
         )
         )
         # complete the processing
         # complete the processing

+ 52 - 39
tests/units/test_state.py

@@ -34,6 +34,7 @@ from reflex.state import (
     OnLoadInternalState,
     OnLoadInternalState,
     RouterData,
     RouterData,
     State,
     State,
+    StateDelta,
     StateManager,
     StateManager,
     StateManagerDisk,
     StateManagerDisk,
     StateManagerMemory,
     StateManagerMemory,
@@ -2012,14 +2013,16 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     mcall = mock_app.event_namespace.emit.mock_calls[0]  # pyright: ignore [reportFunctionMemberAccess]
     mcall = mock_app.event_namespace.emit.mock_calls[0]  # pyright: ignore [reportFunctionMemberAccess]
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
     assert mcall.args[1] == StateUpdate(
-        delta={
-            grandchild_state.get_full_name(): {
-                "value2": "42",
-            },
-            GrandchildState3.get_full_name(): {
-                "computed": "",
-            },
-        }
+        delta=StateDelta(
+            {
+                grandchild_state.get_full_name(): {
+                    "value2": "42",
+                },
+                GrandchildState3.get_full_name(): {
+                    "computed": "",
+                },
+            }
+        )
     )
     )
     assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
     assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
 
 
@@ -2174,18 +2177,20 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     ):
     ):
         # other task returns delta
         # other task returns delta
         assert update == StateUpdate(
         assert update == StateUpdate(
-            delta={
-                BackgroundTaskState.get_full_name(): {
-                    "order": [
-                        "background_task:start",
-                        "other",
-                    ],
-                    "computed_order": [
-                        "background_task:start",
-                        "other",
-                    ],
+            delta=StateDelta(
+                {
+                    BackgroundTaskState.get_full_name(): {
+                        "order": [
+                            "background_task:start",
+                            "other",
+                        ],
+                        "computed_order": [
+                            "background_task:start",
+                            "other",
+                        ],
+                    }
                 }
                 }
-            }
+            )
         )
         )
 
 
     # Explicit wait for background tasks
     # Explicit wait for background tasks
@@ -2216,42 +2221,50 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         is not None
         is not None
     )
     )
     assert first_ws_message == StateUpdate(
     assert first_ws_message == StateUpdate(
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "order": ["background_task:start"],
-                "computed_order": ["background_task:start"],
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "order": ["background_task:start"],
+                    "computed_order": ["background_task:start"],
+                }
             }
             }
-        },
+        ),
         events=[],
         events=[],
         final=True,
         final=True,
     )
     )
     for call in emit_mock.mock_calls[1:5]:  # pyright: ignore [reportFunctionMemberAccess]
     for call in emit_mock.mock_calls[1:5]:  # pyright: ignore [reportFunctionMemberAccess]
         assert call.args[1] == StateUpdate(
         assert call.args[1] == StateUpdate(
-            delta={
-                BackgroundTaskState.get_full_name(): {
-                    "computed_order": ["background_task:start"],
+            delta=StateDelta(
+                {
+                    BackgroundTaskState.get_full_name(): {
+                        "computed_order": ["background_task:start"],
+                    }
                 }
                 }
-            },
+            ),
             events=[],
             events=[],
             final=True,
             final=True,
         )
         )
     assert emit_mock.mock_calls[-2].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
     assert emit_mock.mock_calls[-2].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "order": exp_order,
-                "computed_order": exp_order,
-                "dict_list": {},
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "order": exp_order,
+                    "computed_order": exp_order,
+                    "dict_list": {},
+                }
             }
             }
-        },
+        ),
         events=[],
         events=[],
         final=True,
         final=True,
     )
     )
     assert emit_mock.mock_calls[-1].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
     assert emit_mock.mock_calls[-1].args[1] == StateUpdate(  # pyright: ignore [reportFunctionMemberAccess]
-        delta={
-            BackgroundTaskState.get_full_name(): {
-                "computed_order": exp_order,
-            },
-        },
+        delta=StateDelta(
+            {
+                BackgroundTaskState.get_full_name(): {
+                    "computed_order": exp_order,
+                },
+            }
+        ),
         events=[],
         events=[],
         final=True,
         final=True,
     )
     )

Some files were not shown because too many files changed in this diff