Prechádzať zdrojové kódy

Feat: Add Session storage to store data on client storage (#3420)

Kelechi Ebiri 11 mesiacov pred
rodič
commit
2b2cdf9847

+ 1 - 1
.pre-commit-config.yaml

@@ -31,4 +31,4 @@ repos:
         always_run: true
         language: system
         description: 'Update pyi files as needed'
-        entry: python scripts/make_pyi.py
+        entry: python3 scripts/make_pyi.py

+ 97 - 1
integration/test_client_storage.py

@@ -46,6 +46,11 @@ def ClientSide():
         l5: str = rx.LocalStorage(sync=True)
         l6: str = rx.LocalStorage(sync=True, name="l6")
 
+        # Session storage
+        s1: str = rx.SessionStorage()
+        s2: rx.SessionStorage = "s2 default"  # type: ignore
+        s3: str = rx.SessionStorage(name="s3")
+
         def set_l6(self, my_param: str):
             self.l6 = my_param
 
@@ -56,6 +61,7 @@ def ClientSide():
     class ClientSideSubSubState(ClientSideSubState):
         c1s: str = rx.Cookie()
         l1s: str = rx.LocalStorage()
+        s1s: str = rx.SessionStorage()
 
         def set_var(self):
             setattr(self, self.state_var, self.input_value)
@@ -103,8 +109,12 @@ def ClientSide():
             rx.box(ClientSideSubState.l4, id="l4"),
             rx.box(ClientSideSubState.l5, id="l5"),
             rx.box(ClientSideSubState.l6, id="l6"),
+            rx.box(ClientSideSubState.s1, id="s1"),
+            rx.box(ClientSideSubState.s2, id="s2"),
+            rx.box(ClientSideSubState.s3, id="s3"),
             rx.box(ClientSideSubSubState.c1s, id="c1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
+            rx.box(ClientSideSubSubState.s1s, id="s1s"),
         )
 
     app = rx.App(state=rx.State)
@@ -162,6 +172,21 @@ def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None
     ls.clear()
 
 
+@pytest.fixture()
+def session_storage(driver: WebDriver) -> Generator[utils.SessionStorage, None, None]:
+    """Get an instance of the session storage helper.
+
+    Args:
+        driver: WebDriver instance.
+
+    Yields:
+        Session storage helper.
+    """
+    ss = utils.SessionStorage(driver)
+    yield ss
+    ss.clear()
+
+
 @pytest.fixture(autouse=True)
 def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
     """Delete all cookies after each test.
@@ -190,7 +215,10 @@ def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
 
 @pytest.mark.asyncio
 async def test_client_side_state(
-    client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
+    client_side: AppHarness,
+    driver: WebDriver,
+    local_storage: utils.LocalStorage,
+    session_storage: utils.SessionStorage,
 ):
     """Test client side state.
 
@@ -198,6 +226,7 @@ async def test_client_side_state(
         client_side: harness for ClientSide app.
         driver: WebDriver instance.
         local_storage: Local storage helper.
+        session_storage: Session storage helper.
     """
     assert client_side.app_instance is not None
     assert client_side.frontend_url is not None
@@ -251,8 +280,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
     # assert on defaults where present
     assert c1.text == ""
@@ -266,8 +299,12 @@ async def test_client_side_state(
     assert l2.text == "l2 default"
     assert l3.text == ""
     assert l4.text == "l4 default"
+    assert s1.text == ""
+    assert s2.text == "s2 default"
+    assert s3.text == ""
     assert c1s.text == ""
     assert l1s.text == ""
+    assert s1s.text == ""
 
     # no cookies should be set yet!
     assert not driver.get_cookies()
@@ -287,8 +324,12 @@ async def test_client_side_state(
     set_sub("l2", "l2 value")
     set_sub("l3", "l3 value")
     set_sub("l4", "l4 value")
+    set_sub("s1", "s1 value")
+    set_sub("s2", "s2 value")
+    set_sub("s3", "s3 value")
     set_sub_sub("c1s", "c1s value")
     set_sub_sub("l1s", "l1s value")
+    set_sub_sub("s1s", "s1s value")
 
     exp_cookies = {
         "state.client_side_state.client_side_sub_state.c1": {
@@ -405,6 +446,25 @@ async def test_client_side_state(
     )
     assert not local_storage_items
 
+    session_storage_items = session_storage.items()
+    session_storage_items.pop("token", None)
+    assert (
+        session_storage_items.pop("state.client_side_state.client_side_sub_state.s1")
+        == "s1 value"
+    )
+    assert (
+        session_storage_items.pop("state.client_side_state.client_side_sub_state.s2")
+        == "s2 value"
+    )
+    assert session_storage_items.pop("s3") == "s3 value"
+    assert (
+        session_storage_items.pop(
+            "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.s1s"
+        )
+        == "s1s value"
+    )
+    assert not session_storage_items
+
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
     assert c3.text == "c3 value"
@@ -416,8 +476,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
     # navigate to the /foo route
     with utils.poll_for_navigation(driver):
@@ -435,8 +499,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
@@ -449,8 +517,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
     # reset the backend state to force refresh from client storage
     async with client_side.modify_state(f"{token}_state.client_side_state") as state:
@@ -475,8 +547,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
@@ -489,8 +565,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
     # make sure c5 cookie shows up on the `/foo` route
     AppHarness._poll_for(
@@ -525,6 +605,15 @@ async def test_client_side_state(
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
 
+    # Set session storage values in the new tab
+    set_sub("s1", "other tab s1")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
+    assert AppHarness._poll_for(lambda: s1.text == "other tab s1")
+    assert s2.text == "s2 default"
+    assert s3.text == ""
+
     # Switch back to main window.
     driver.switch_to.window(main_tab)
 
@@ -534,6 +623,13 @@ async def test_client_side_state(
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
 
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
+    assert AppHarness._poll_for(lambda: s1.text == "s1 value")
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
+
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()
     local_storage.clear()

+ 32 - 1
reflex/.templates/web/utils/state.js

@@ -185,6 +185,18 @@ export const applyEvent = async (event, socket) => {
     return false;
   }
 
+  if (event.name == "_clear_session_storage") {
+    sessionStorage.clear();
+    queueEvents(initialEvents(), socket);
+    return false;
+  }
+
+  if (event.name == "_remove_session_storage") {
+    sessionStorage.removeItem(event.payload.key);
+    queueEvents(initialEvents(), socket);
+    return false;
+  }
+
   if (event.name == "_set_clipboard") {
     const content = event.payload.content;
     navigator.clipboard.writeText(content);
@@ -538,7 +550,18 @@ export const hydrateClientStorage = (client_storage) => {
       }
     }
   }
-  if (client_storage.cookies || client_storage.local_storage) {
+  if (client_storage.session_storage && typeof window != "undefined") {
+    for (const state_key in client_storage.session_storage) {
+      const session_options = client_storage.session_storage[state_key];
+      const session_storage_value = sessionStorage.getItem(
+        session_options.name || state_key
+      );
+      if (session_storage_value != null) {
+        client_storage_values[state_key] = session_storage_value;
+      }
+    }
+  }
+  if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) {
     return client_storage_values;
   }
   return {};
@@ -578,7 +601,15 @@ const applyClientStorageDelta = (client_storage, delta) => {
       ) {
         const options = client_storage.local_storage[state_key];
         localStorage.setItem(options.name || state_key, delta[substate][key]);
+      } else if(
+        client_storage.session_storage &&
+        state_key in client_storage.session_storage &&
+        typeof window !== "undefined"
+      ) {
+        const session_options = client_storage.session_storage[state_key];
+        sessionStorage.setItem(session_options.name || state_key, delta[substate][key]);
       }
+
     }
   }
 };

+ 3 - 0
reflex/__init__.py

@@ -287,12 +287,14 @@ _MAPPING: dict = {
         "background",
         "call_script",
         "clear_local_storage",
+        "clear_session_storage",
         "console_log",
         "download",
         "prevent_default",
         "redirect",
         "remove_cookie",
         "remove_local_storage",
+        "remove_session_storage",
         "set_clipboard",
         "set_focus",
         "scroll_to",
@@ -307,6 +309,7 @@ _MAPPING: dict = {
         "var",
         "Cookie",
         "LocalStorage",
+        "SessionStorage",
         "ComponentState",
         "State",
     ],

+ 3 - 0
reflex/__init__.pyi

@@ -157,12 +157,14 @@ from .event import EventHandler as EventHandler
 from .event import background as background
 from .event import call_script as call_script
 from .event import clear_local_storage as clear_local_storage
+from .event import clear_session_storage as clear_session_storage
 from .event import console_log as console_log
 from .event import download as download
 from .event import prevent_default as prevent_default
 from .event import redirect as redirect
 from .event import remove_cookie as remove_cookie
 from .event import remove_local_storage as remove_local_storage
+from .event import remove_session_storage as remove_session_storage
 from .event import set_clipboard as set_clipboard
 from .event import set_focus as set_focus
 from .event import scroll_to as scroll_to
@@ -177,6 +179,7 @@ from .model import Model as Model
 from .state import var as var
 from .state import Cookie as Cookie
 from .state import LocalStorage as LocalStorage
+from .state import SessionStorage as SessionStorage
 from .state import ComponentState as ComponentState
 from .state import State as State
 from .style import Style as Style

+ 22 - 11
reflex/compiler/utils.py

@@ -28,7 +28,7 @@ from reflex.components.base import (
     Title,
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
-from reflex.state import BaseState, Cookie, LocalStorage
+from reflex.state import BaseState, Cookie, LocalStorage, SessionStorage
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.imports import ImportVar, ParsedImportDict
@@ -158,8 +158,11 @@ def compile_state(state: Type[BaseState]) -> dict:
 
 def _compile_client_storage_field(
     field: ModelField,
-) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]:
-    """Compile the given cookie or local_storage field.
+) -> tuple[
+    Type[Cookie] | Type[LocalStorage] | Type[SessionStorage] | None,
+    dict[str, Any] | None,
+]:
+    """Compile the given cookie, local_storage or session_storage field.
 
     Args:
         field: The possible cookie field to compile.
@@ -167,7 +170,7 @@ def _compile_client_storage_field(
     Returns:
         A dictionary of the compiled cookie or None if the field is not cookie-like.
     """
-    for field_type in (Cookie, LocalStorage):
+    for field_type in (Cookie, LocalStorage, SessionStorage):
         if isinstance(field.default, field_type):
             cs_obj = field.default
         elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
@@ -180,7 +183,7 @@ def _compile_client_storage_field(
 
 def _compile_client_storage_recursive(
     state: Type[BaseState],
-) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
+) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]:
     """Compile the client-side storage for the given state recursively.
 
     Args:
@@ -191,10 +194,12 @@ def _compile_client_storage_recursive(
             (
                 cookies: dict[str, dict],
                 local_storage: dict[str, dict[str, str]]
-            )
+                session_storage: dict[str, dict[str, str]]
+            ).
     """
     cookies = {}
     local_storage = {}
+    session_storage = {}
     state_name = state.get_full_name()
     for name, field in state.__fields__.items():
         if name in state.inherited_vars:
@@ -206,15 +211,20 @@ def _compile_client_storage_recursive(
             cookies[state_key] = options
         elif field_type is LocalStorage:
             local_storage[state_key] = options
+        elif field_type is SessionStorage:
+            session_storage[state_key] = options
         else:
             continue
     for substate in state.get_substates():
-        substate_cookies, substate_local_storage = _compile_client_storage_recursive(
-            substate
-        )
+        (
+            substate_cookies,
+            substate_local_storage,
+            substate_session_storage,
+        ) = _compile_client_storage_recursive(substate)
         cookies.update(substate_cookies)
         local_storage.update(substate_local_storage)
-    return cookies, local_storage
+        session_storage.update(substate_session_storage)
+    return cookies, local_storage, session_storage
 
 
 def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
@@ -226,10 +236,11 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
     Returns:
         A dictionary of the compiled client-side storage info.
     """
-    cookies, local_storage = _compile_client_storage_recursive(state)
+    cookies, local_storage, session_storage = _compile_client_storage_recursive(state)
     return {
         constants.COOKIES: cookies,
         constants.LOCAL_STORAGE: local_storage,
+        constants.SESSION_STORAGE: session_storage,
     }
 
 

+ 2 - 0
reflex/constants/__init__.py

@@ -10,6 +10,7 @@ from .base import (
     REFLEX_VAR_CLOSING_TAG,
     REFLEX_VAR_OPENING_TAG,
     RELOAD_CONFIG,
+    SESSION_STORAGE,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
     Dirs,
@@ -88,6 +89,7 @@ __ALL__ = [
     Imports,
     IS_WINDOWS,
     LOCAL_STORAGE,
+    SESSION_STORAGE,
     LogLevel,
     MemoizationDisposition,
     MemoizationMode,

+ 1 - 0
reflex/constants/base.py

@@ -178,6 +178,7 @@ class Ping(SimpleNamespace):
 # Keys in the client_side_storage dict
 COOKIES = "cookies"
 LOCAL_STORAGE = "local_storage"
+SESSION_STORAGE = "session_storage"
 
 # If this env var is set to "yes", App.compile will be a no-op
 SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"

+ 28 - 0
reflex/event.py

@@ -617,6 +617,34 @@ def remove_local_storage(key: str) -> EventSpec:
     )
 
 
+def clear_session_storage() -> EventSpec:
+    """Set a value in the session storage on the frontend.
+
+    Returns:
+        EventSpec: An event to clear the session storage.
+    """
+    return server_side(
+        "_clear_session_storage",
+        get_fn_signature(clear_session_storage),
+    )
+
+
+def remove_session_storage(key: str) -> EventSpec:
+    """Set a value in the session storage on the frontend.
+
+    Args:
+        key: The key identifying the variable in the session storage to remove.
+
+    Returns:
+        EventSpec: An event to remove an item based on the provided key in session storage.
+    """
+    return server_side(
+        "_remove_session_storage",
+        get_fn_signature(remove_session_storage),
+        key=key,
+    )
+
+
 def set_clipboard(content: str) -> EventSpec:
     """Set the text in content in the clipboard.
 

+ 32 - 0
reflex/state.py

@@ -2835,6 +2835,38 @@ class LocalStorage(ClientStorageBase, str):
         return inst
 
 
+class SessionStorage(ClientStorageBase, str):
+    """Represents a state Var that is stored in sessionStorage in the browser."""
+
+    name: str | None
+
+    def __new__(
+        cls,
+        object: Any = "",
+        encoding: str | None = None,
+        errors: str | None = None,
+        /,
+        name: str | None = None,
+    ) -> "SessionStorage":
+        """Create a client-side sessionStorage (str).
+
+        Args:
+            object: The initial object.
+            encoding: The encoding to use.
+            errors: The error handling scheme to use
+            name: The name of the storage on the client side
+
+        Returns:
+            The client-side sessionStorage object.
+        """
+        if encoding or errors:
+            inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict")
+        else:
+            inst = super().__new__(cls, object)
+        inst.name = name
+        return inst
+
+
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""