Explorar o código

feat: Synchronizing localStorage between tabs using browser events (#2533)

* feat: Synchronizing localStorage between tabs using browser events

* test_client_storage: Test sync'd local storage vars

* update_vars_internal: generic handler to apply var changes to state tree

Apply fully qualified var names to each substate they are associated with. This
allows consistent updates to arbitrary state vars without having to know their
"setter" arguments, in case the user has overwritted the `set_x` name.

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
abulvenz hai 1 ano
pai
achega
980834605b

+ 55 - 10
integration/test_client_storage.py

@@ -41,6 +41,13 @@ def ClientSide():
         l3: str = rx.LocalStorage(name="l3")
         l3: str = rx.LocalStorage(name="l3")
         l4: str = rx.LocalStorage("l4 default")
         l4: str = rx.LocalStorage("l4 default")
 
 
+        # Sync'd local storage
+        l5: str = rx.LocalStorage(sync=True)
+        l6: str = rx.LocalStorage(sync=True, name="l6")
+
+        def set_l6(self, my_param: str):
+            self.l6 = my_param
+
         def set_var(self):
         def set_var(self):
             setattr(self, self.state_var, self.input_value)
             setattr(self, self.state_var, self.input_value)
             self.state_var = self.input_value = ""
             self.state_var = self.input_value = ""
@@ -93,6 +100,8 @@ def ClientSide():
             rx.box(ClientSideSubState.l2, id="l2"),
             rx.box(ClientSideSubState.l2, id="l2"),
             rx.box(ClientSideSubState.l3, id="l3"),
             rx.box(ClientSideSubState.l3, id="l3"),
             rx.box(ClientSideSubState.l4, id="l4"),
             rx.box(ClientSideSubState.l4, id="l4"),
+            rx.box(ClientSideSubState.l5, id="l5"),
+            rx.box(ClientSideSubState.l6, id="l6"),
             rx.box(ClientSideSubSubState.c1s, id="c1s"),
             rx.box(ClientSideSubSubState.c1s, id="c1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
         )
         )
@@ -191,33 +200,44 @@ async def test_client_side_state(
     """
     """
     assert client_side.app_instance is not None
     assert client_side.app_instance is not None
     assert client_side.frontend_url is not None
     assert client_side.frontend_url is not None
-    token_input = driver.find_element(By.ID, "token")
-    assert token_input
 
 
-    # wait for the backend connection to send the token
-    token = client_side.poll_for_value(token_input)
-    assert token is not None
+    def poll_for_token():
+        token_input = driver.find_element(By.ID, "token")
+        assert token_input
 
 
-    # get a reference to the cookie manipulation form
-    state_var_input = driver.find_element(By.ID, "state_var")
-    input_value_input = driver.find_element(By.ID, "input_value")
-    set_sub_state_button = driver.find_element(By.ID, "set_sub_state")
-    set_sub_sub_state_button = driver.find_element(By.ID, "set_sub_sub_state")
+        # wait for the backend connection to send the token
+        token = client_side.poll_for_value(token_input)
+        assert token is not None
+        return token
 
 
     def set_sub(var: str, value: str):
     def set_sub(var: str, value: str):
+        # Get a reference to the cookie manipulation form.
+        state_var_input = driver.find_element(By.ID, "state_var")
+        input_value_input = driver.find_element(By.ID, "input_value")
+        set_sub_state_button = driver.find_element(By.ID, "set_sub_state")
         AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "")
+
+        # Set the values.
         state_var_input.send_keys(var)
         state_var_input.send_keys(var)
         input_value_input.send_keys(value)
         input_value_input.send_keys(value)
         set_sub_state_button.click()
         set_sub_state_button.click()
 
 
     def set_sub_sub(var: str, value: str):
     def set_sub_sub(var: str, value: str):
+        # Get a reference to the cookie manipulation form.
+        state_var_input = driver.find_element(By.ID, "state_var")
+        input_value_input = driver.find_element(By.ID, "input_value")
+        set_sub_sub_state_button = driver.find_element(By.ID, "set_sub_sub_state")
         AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "")
         AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "")
+
+        # Set the values.
         state_var_input.send_keys(var)
         state_var_input.send_keys(var)
         input_value_input.send_keys(value)
         input_value_input.send_keys(value)
         set_sub_sub_state_button.click()
         set_sub_sub_state_button.click()
 
 
+    token = poll_for_token()
+
     # get a reference to all cookie and local storage elements
     # get a reference to all cookie and local storage elements
     c1 = driver.find_element(By.ID, "c1")
     c1 = driver.find_element(By.ID, "c1")
     c2 = driver.find_element(By.ID, "c2")
     c2 = driver.find_element(By.ID, "c2")
@@ -485,6 +505,31 @@ async def test_client_side_state(
         "value": "c5%20value",
         "value": "c5%20value",
     }
     }
 
 
+    # Open a new tab to check that sync'd local storage is working
+    main_tab = driver.window_handles[0]
+    driver.switch_to.new_window("window")
+    driver.get(client_side.frontend_url)
+
+    # New tab should have a different state token.
+    assert poll_for_token() != token
+
+    # Set values and check them in the new tab.
+    set_sub("l5", "l5 value")
+    set_sub("l6", "l6 value")
+    l5 = driver.find_element(By.ID, "l5")
+    l6 = driver.find_element(By.ID, "l6")
+    assert l5.text == "l5 value"
+    assert l6.text == "l6 value"
+
+    # Switch back to main window.
+    driver.switch_to.window(main_tab)
+
+    # The values should have updated automatically.
+    l5 = driver.find_element(By.ID, "l5")
+    l6 = driver.find_element(By.ID, "l6")
+    assert l5.text == "l5 value"
+    assert l6.text == "l6 value"
+
     # clear the cookie jar and local storage, ensure state reset to default
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()
     driver.delete_all_cookies()
     local_storage.clear()
     local_storage.clear()

+ 8 - 2
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -23,13 +23,19 @@ export const clientStorage = {}
 {% endif %}
 {% endif %}
 
 
 {% if state_name %}
 {% if state_name %}
-export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')]
+export const state_name = "{{state_name}}"
+export const onLoadInternalEvent = () => [
+    Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}),
+    Event('{{state_name}}.{{const.on_load_internal}}')
+]
 
 
 export const initialEvents = () => [
 export const initialEvents = () => [
-    Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
+    Event('{{state_name}}.{{const.hydrate}}'),
     ...onLoadInternalEvent()
     ...onLoadInternalEvent()
 ]
 ]
 {% else %}
 {% else %}
+export const state_name = undefined
+
 export const onLoadInternalEvent = () => []
 export const onLoadInternalEvent = () => []
 
 
 export const initialEvents = () => []
 export const initialEvents = () => []

+ 34 - 7
reflex/.templates/web/utils/state.js

@@ -6,7 +6,7 @@ import env from "/env.json";
 import Cookies from "universal-cookie";
 import Cookies from "universal-cookie";
 import { useEffect, useReducer, useRef, useState } from "react";
 import { useEffect, useReducer, useRef, useState } from "react";
 import Router, { useRouter } from "next/router";
 import Router, { useRouter } from "next/router";
-import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js"
+import { initialEvents, initialState, onLoadInternalEvent, state_name } from "utils/context.js"
 
 
 // Endpoint URLs.
 // Endpoint URLs.
 const EVENTURL = env.EVENT
 const EVENTURL = env.EVENT
@@ -441,17 +441,14 @@ export const Event = (name, payload = {}, handler = null) => {
  * @returns payload dict of client storage values
  * @returns payload dict of client storage values
  */
  */
 export const hydrateClientStorage = (client_storage) => {
 export const hydrateClientStorage = (client_storage) => {
-  const client_storage_values = {
-    "cookies": {},
-    "local_storage": {}
-  }
+  const client_storage_values = {}
   if (client_storage.cookies) {
   if (client_storage.cookies) {
     for (const state_key in client_storage.cookies) {
     for (const state_key in client_storage.cookies) {
       const cookie_options = client_storage.cookies[state_key]
       const cookie_options = client_storage.cookies[state_key]
       const cookie_name = cookie_options.name || state_key
       const cookie_name = cookie_options.name || state_key
       const cookie_value = cookies.get(cookie_name)
       const cookie_value = cookies.get(cookie_name)
       if (cookie_value !== undefined) {
       if (cookie_value !== undefined) {
-        client_storage_values.cookies[state_key] = cookies.get(cookie_name)
+        client_storage_values[state_key] = cookies.get(cookie_name)
       }
       }
     }
     }
   }
   }
@@ -460,7 +457,7 @@ export const hydrateClientStorage = (client_storage) => {
       const options = client_storage.local_storage[state_key]
       const options = client_storage.local_storage[state_key]
       const local_storage_value = localStorage.getItem(options.name || state_key)
       const local_storage_value = localStorage.getItem(options.name || state_key)
       if (local_storage_value !== null) {
       if (local_storage_value !== null) {
-        client_storage_values.local_storage[state_key] = local_storage_value
+        client_storage_values[state_key] = local_storage_value
       }
       }
     }
     }
   }
   }
@@ -568,6 +565,36 @@ export const useEventLoop = (
     }
     }
   })
   })
 
 
+
+  // localStorage event handling
+  useEffect(() => {
+    const storage_to_state_map = {};
+
+    if (client_storage.local_storage && typeof window !== "undefined") {
+      for (const state_key in client_storage.local_storage) {
+        const options = client_storage.local_storage[state_key];
+        if (options.sync) {
+          const local_storage_value_key = options.name || state_key;
+          storage_to_state_map[local_storage_value_key] = state_key;
+        }
+      }
+    }
+
+    // e is StorageEvent
+    const handleStorage = (e) => {
+      if (storage_to_state_map[e.key]) {
+        const vars = {}
+        vars[storage_to_state_map[e.key]] = e.newValue
+        const event = Event(`${state_name}.update_vars_internal`, {vars: vars})
+        addEvents([event], e);
+      }
+    };
+
+    window.addEventListener("storage", handleStorage);
+    return () => window.removeEventListener("storage", handleStorage);
+  });
+
+
   // Route after the initial page hydration.
   // Route after the initial page hydration.
   useEffect(() => {
   useEffect(() => {
     const change_complete = () => addEvents(onLoadInternalEvent())
     const change_complete = () => addEvents(onLoadInternalEvent())

+ 1 - 0
reflex/compiler/templates.py

@@ -41,6 +41,7 @@ class ReflexJinjaEnvironment(Environment):
             "use_color_mode": constants.ColorMode.USE,
             "use_color_mode": constants.ColorMode.USE,
             "hydrate": constants.CompileVars.HYDRATE,
             "hydrate": constants.CompileVars.HYDRATE,
             "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
             "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
+            "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
         }
         }
 
 
 
 

+ 2 - 0
reflex/constants/compiler.py

@@ -60,6 +60,8 @@ class CompileVars(SimpleNamespace):
     TO_EVENT = "Event"
     TO_EVENT = "Event"
     # The name of the internal on_load event.
     # The name of the internal on_load event.
     ON_LOAD_INTERNAL = "on_load_internal"
     ON_LOAD_INTERNAL = "on_load_internal"
+    # The name of the internal event to update generic state vars.
+    UPDATE_VARS_INTERNAL = "update_vars_internal"
 
 
 
 
 class PageNames(SimpleNamespace):
 class PageNames(SimpleNamespace):

+ 0 - 8
reflex/middleware/hydrate_middleware.py

@@ -39,14 +39,6 @@ class HydrateMiddleware(Middleware):
         # Mark state as not hydrated (until on_loads are complete)
         # Mark state as not hydrated (until on_loads are complete)
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
 
-        # Apply client side storage values to state
-        for storage_type in (constants.COOKIES, constants.LOCAL_STORAGE):
-            if storage_type in event.payload:
-                for key, value in event.payload[storage_type].items():
-                    state_name, _, var_name = key.rpartition(".")
-                    var_state = state.get_substate(state_name.split("."))
-                    setattr(var_state, var_name, value)
-
         # Get the initial state.
         # Get the initial state.
         delta = format.format_state(state.dict())
         delta = format.format_state(state.dict())
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness

+ 21 - 0
reflex/state.py

@@ -1405,6 +1405,23 @@ class State(BaseState):
             type(self).set_is_hydrated(True),  # type: ignore
             type(self).set_is_hydrated(True),  # type: ignore
         ]
         ]
 
 
+    def update_vars_internal(self, vars: dict[str, Any]) -> None:
+        """Apply updates to fully qualified state vars.
+
+        The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
+        and each value will be set on the appropriate substate instance.
+
+        This function is primarily used to apply cookie and local storage
+        updates from the frontend to the appropriate substate.
+
+        Args:
+            vars: The fully qualified vars and values to update.
+        """
+        for var, value in vars.items():
+            state_name, _, var_name = var.rpartition(".")
+            var_state = self.get_substate(state_name.split("."))
+            setattr(var_state, var_name, value)
+
 
 
 class StateProxy(wrapt.ObjectProxy):
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
     """Proxy of a state instance to control mutability of vars for a background task.
@@ -1949,6 +1966,7 @@ class LocalStorage(ClientStorageBase, str):
     """Represents a state Var that is stored in localStorage in the browser."""
     """Represents a state Var that is stored in localStorage in the browser."""
 
 
     name: str | None
     name: str | None
+    sync: bool = False
 
 
     def __new__(
     def __new__(
         cls,
         cls,
@@ -1957,6 +1975,7 @@ class LocalStorage(ClientStorageBase, str):
         errors: str | None = None,
         errors: str | None = None,
         /,
         /,
         name: str | None = None,
         name: str | None = None,
+        sync: bool = False,
     ) -> "LocalStorage":
     ) -> "LocalStorage":
         """Create a client-side localStorage (str).
         """Create a client-side localStorage (str).
 
 
@@ -1965,6 +1984,7 @@ class LocalStorage(ClientStorageBase, str):
             encoding: The encoding to use.
             encoding: The encoding to use.
             errors: The error handling scheme to use.
             errors: The error handling scheme to use.
             name: The name of the storage key on the client side.
             name: The name of the storage key on the client side.
+            sync: Whether changes should be propagated to other tabs.
 
 
         Returns:
         Returns:
             The client-side localStorage object.
             The client-side localStorage object.
@@ -1974,6 +1994,7 @@ class LocalStorage(ClientStorageBase, str):
         else:
         else:
             inst = super().__new__(cls, object)
             inst = super().__new__(cls, object)
         inst.name = name
         inst.name = name
+        inst.sync = sync
         return inst
         return inst