Browse Source

Feat: Add Session storage to store data on client storage (#3420)

Kelechi Ebiri 11 tháng trước cách đây
mục cha
commit
2b2cdf9847

+ 1 - 1
.pre-commit-config.yaml

@@ -31,4 +31,4 @@ repos:
         always_run: true
         always_run: true
         language: system
         language: system
         description: 'Update pyi files as needed'
         description: 'Update pyi files as needed'
-        entry: python scripts/make_pyi.py
+        entry: python3 scripts/make_pyi.py

+ 97 - 1
integration/test_client_storage.py

@@ -46,6 +46,11 @@ def ClientSide():
         l5: str = rx.LocalStorage(sync=True)
         l5: str = rx.LocalStorage(sync=True)
         l6: str = rx.LocalStorage(sync=True, name="l6")
         l6: str = rx.LocalStorage(sync=True, name="l6")
 
 
+        # Session storage
+        s1: str = rx.SessionStorage()
+        s2: rx.SessionStorage = "s2 default"  # type: ignore
+        s3: str = rx.SessionStorage(name="s3")
+
         def set_l6(self, my_param: str):
         def set_l6(self, my_param: str):
             self.l6 = my_param
             self.l6 = my_param
 
 
@@ -56,6 +61,7 @@ def ClientSide():
     class ClientSideSubSubState(ClientSideSubState):
     class ClientSideSubSubState(ClientSideSubState):
         c1s: str = rx.Cookie()
         c1s: str = rx.Cookie()
         l1s: str = rx.LocalStorage()
         l1s: str = rx.LocalStorage()
+        s1s: str = rx.SessionStorage()
 
 
         def set_var(self):
         def set_var(self):
             setattr(self, self.state_var, self.input_value)
             setattr(self, self.state_var, self.input_value)
@@ -103,8 +109,12 @@ def ClientSide():
             rx.box(ClientSideSubState.l4, id="l4"),
             rx.box(ClientSideSubState.l4, id="l4"),
             rx.box(ClientSideSubState.l5, id="l5"),
             rx.box(ClientSideSubState.l5, id="l5"),
             rx.box(ClientSideSubState.l6, id="l6"),
             rx.box(ClientSideSubState.l6, id="l6"),
+            rx.box(ClientSideSubState.s1, id="s1"),
+            rx.box(ClientSideSubState.s2, id="s2"),
+            rx.box(ClientSideSubState.s3, id="s3"),
             rx.box(ClientSideSubSubState.c1s, id="c1s"),
             rx.box(ClientSideSubSubState.c1s, id="c1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
+            rx.box(ClientSideSubSubState.s1s, id="s1s"),
         )
         )
 
 
     app = rx.App(state=rx.State)
     app = rx.App(state=rx.State)
@@ -162,6 +172,21 @@ def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None
     ls.clear()
     ls.clear()
 
 
 
 
+@pytest.fixture()
+def session_storage(driver: WebDriver) -> Generator[utils.SessionStorage, None, None]:
+    """Get an instance of the session storage helper.
+
+    Args:
+        driver: WebDriver instance.
+
+    Yields:
+        Session storage helper.
+    """
+    ss = utils.SessionStorage(driver)
+    yield ss
+    ss.clear()
+
+
 @pytest.fixture(autouse=True)
 @pytest.fixture(autouse=True)
 def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
 def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
     """Delete all cookies after each test.
     """Delete all cookies after each test.
@@ -190,7 +215,10 @@ def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_side_state(
 async def test_client_side_state(
-    client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
+    client_side: AppHarness,
+    driver: WebDriver,
+    local_storage: utils.LocalStorage,
+    session_storage: utils.SessionStorage,
 ):
 ):
     """Test client side state.
     """Test client side state.
 
 
@@ -198,6 +226,7 @@ async def test_client_side_state(
         client_side: harness for ClientSide app.
         client_side: harness for ClientSide app.
         driver: WebDriver instance.
         driver: WebDriver instance.
         local_storage: Local storage helper.
         local_storage: Local storage helper.
+        session_storage: Session storage helper.
     """
     """
     assert client_side.app_instance is not None
     assert client_side.app_instance is not None
     assert client_side.frontend_url is not None
     assert client_side.frontend_url is not None
@@ -251,8 +280,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
 
     # assert on defaults where present
     # assert on defaults where present
     assert c1.text == ""
     assert c1.text == ""
@@ -266,8 +299,12 @@ async def test_client_side_state(
     assert l2.text == "l2 default"
     assert l2.text == "l2 default"
     assert l3.text == ""
     assert l3.text == ""
     assert l4.text == "l4 default"
     assert l4.text == "l4 default"
+    assert s1.text == ""
+    assert s2.text == "s2 default"
+    assert s3.text == ""
     assert c1s.text == ""
     assert c1s.text == ""
     assert l1s.text == ""
     assert l1s.text == ""
+    assert s1s.text == ""
 
 
     # no cookies should be set yet!
     # no cookies should be set yet!
     assert not driver.get_cookies()
     assert not driver.get_cookies()
@@ -287,8 +324,12 @@ async def test_client_side_state(
     set_sub("l2", "l2 value")
     set_sub("l2", "l2 value")
     set_sub("l3", "l3 value")
     set_sub("l3", "l3 value")
     set_sub("l4", "l4 value")
     set_sub("l4", "l4 value")
+    set_sub("s1", "s1 value")
+    set_sub("s2", "s2 value")
+    set_sub("s3", "s3 value")
     set_sub_sub("c1s", "c1s value")
     set_sub_sub("c1s", "c1s value")
     set_sub_sub("l1s", "l1s value")
     set_sub_sub("l1s", "l1s value")
+    set_sub_sub("s1s", "s1s value")
 
 
     exp_cookies = {
     exp_cookies = {
         "state.client_side_state.client_side_sub_state.c1": {
         "state.client_side_state.client_side_sub_state.c1": {
@@ -405,6 +446,25 @@ async def test_client_side_state(
     )
     )
     assert not local_storage_items
     assert not local_storage_items
 
 
+    session_storage_items = session_storage.items()
+    session_storage_items.pop("token", None)
+    assert (
+        session_storage_items.pop("state.client_side_state.client_side_sub_state.s1")
+        == "s1 value"
+    )
+    assert (
+        session_storage_items.pop("state.client_side_state.client_side_sub_state.s2")
+        == "s2 value"
+    )
+    assert session_storage_items.pop("s3") == "s3 value"
+    assert (
+        session_storage_items.pop(
+            "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.s1s"
+        )
+        == "s1s value"
+    )
+    assert not session_storage_items
+
     assert c1.text == "c1 value"
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
     assert c2.text == "c2 value"
     assert c3.text == "c3 value"
     assert c3.text == "c3 value"
@@ -416,8 +476,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
 
     # navigate to the /foo route
     # navigate to the /foo route
     with utils.poll_for_navigation(driver):
     with utils.poll_for_navigation(driver):
@@ -435,8 +499,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
 
     assert c1.text == "c1 value"
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
     assert c2.text == "c2 value"
@@ -449,8 +517,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
 
     # reset the backend state to force refresh from client storage
     # reset the backend state to force refresh from client storage
     async with client_side.modify_state(f"{token}_state.client_side_state") as state:
     async with client_side.modify_state(f"{token}_state.client_side_state") as state:
@@ -475,8 +547,12 @@ async def test_client_side_state(
     l2 = driver.find_element(By.ID, "l2")
     l2 = driver.find_element(By.ID, "l2")
     l3 = driver.find_element(By.ID, "l3")
     l3 = driver.find_element(By.ID, "l3")
     l4 = driver.find_element(By.ID, "l4")
     l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
     c1s = driver.find_element(By.ID, "c1s")
     c1s = driver.find_element(By.ID, "c1s")
     l1s = driver.find_element(By.ID, "l1s")
     l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
 
 
     assert c1.text == "c1 value"
     assert c1.text == "c1 value"
     assert c2.text == "c2 value"
     assert c2.text == "c2 value"
@@ -489,8 +565,12 @@ async def test_client_side_state(
     assert l2.text == "l2 value"
     assert l2.text == "l2 value"
     assert l3.text == "l3 value"
     assert l3.text == "l3 value"
     assert l4.text == "l4 value"
     assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
     assert c1s.text == "c1s value"
     assert c1s.text == "c1s value"
     assert l1s.text == "l1s value"
     assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
 
 
     # make sure c5 cookie shows up on the `/foo` route
     # make sure c5 cookie shows up on the `/foo` route
     AppHarness._poll_for(
     AppHarness._poll_for(
@@ -525,6 +605,15 @@ async def test_client_side_state(
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
     assert l5.text == "l5 value"
 
 
+    # Set session storage values in the new tab
+    set_sub("s1", "other tab s1")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
+    assert AppHarness._poll_for(lambda: s1.text == "other tab s1")
+    assert s2.text == "s2 default"
+    assert s3.text == ""
+
     # Switch back to main window.
     # Switch back to main window.
     driver.switch_to.window(main_tab)
     driver.switch_to.window(main_tab)
 
 
@@ -534,6 +623,13 @@ async def test_client_side_state(
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
     assert l5.text == "l5 value"
 
 
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
+    assert AppHarness._poll_for(lambda: s1.text == "s1 value")
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
+
     # clear the cookie jar and local storage, ensure state reset to default
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()
     driver.delete_all_cookies()
     local_storage.clear()
     local_storage.clear()

+ 32 - 1
reflex/.templates/web/utils/state.js

@@ -185,6 +185,18 @@ export const applyEvent = async (event, socket) => {
     return false;
     return false;
   }
   }
 
 
+  if (event.name == "_clear_session_storage") {
+    sessionStorage.clear();
+    queueEvents(initialEvents(), socket);
+    return false;
+  }
+
+  if (event.name == "_remove_session_storage") {
+    sessionStorage.removeItem(event.payload.key);
+    queueEvents(initialEvents(), socket);
+    return false;
+  }
+
   if (event.name == "_set_clipboard") {
   if (event.name == "_set_clipboard") {
     const content = event.payload.content;
     const content = event.payload.content;
     navigator.clipboard.writeText(content);
     navigator.clipboard.writeText(content);
@@ -538,7 +550,18 @@ export const hydrateClientStorage = (client_storage) => {
       }
       }
     }
     }
   }
   }
-  if (client_storage.cookies || client_storage.local_storage) {
+  if (client_storage.session_storage && typeof window != "undefined") {
+    for (const state_key in client_storage.session_storage) {
+      const session_options = client_storage.session_storage[state_key];
+      const session_storage_value = sessionStorage.getItem(
+        session_options.name || state_key
+      );
+      if (session_storage_value != null) {
+        client_storage_values[state_key] = session_storage_value;
+      }
+    }
+  }
+  if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) {
     return client_storage_values;
     return client_storage_values;
   }
   }
   return {};
   return {};
@@ -578,7 +601,15 @@ const applyClientStorageDelta = (client_storage, delta) => {
       ) {
       ) {
         const options = client_storage.local_storage[state_key];
         const options = client_storage.local_storage[state_key];
         localStorage.setItem(options.name || state_key, delta[substate][key]);
         localStorage.setItem(options.name || state_key, delta[substate][key]);
+      } else if(
+        client_storage.session_storage &&
+        state_key in client_storage.session_storage &&
+        typeof window !== "undefined"
+      ) {
+        const session_options = client_storage.session_storage[state_key];
+        sessionStorage.setItem(session_options.name || state_key, delta[substate][key]);
       }
       }
+
     }
     }
   }
   }
 };
 };

+ 3 - 0
reflex/__init__.py

@@ -287,12 +287,14 @@ _MAPPING: dict = {
         "background",
         "background",
         "call_script",
         "call_script",
         "clear_local_storage",
         "clear_local_storage",
+        "clear_session_storage",
         "console_log",
         "console_log",
         "download",
         "download",
         "prevent_default",
         "prevent_default",
         "redirect",
         "redirect",
         "remove_cookie",
         "remove_cookie",
         "remove_local_storage",
         "remove_local_storage",
+        "remove_session_storage",
         "set_clipboard",
         "set_clipboard",
         "set_focus",
         "set_focus",
         "scroll_to",
         "scroll_to",
@@ -307,6 +309,7 @@ _MAPPING: dict = {
         "var",
         "var",
         "Cookie",
         "Cookie",
         "LocalStorage",
         "LocalStorage",
+        "SessionStorage",
         "ComponentState",
         "ComponentState",
         "State",
         "State",
     ],
     ],

+ 3 - 0
reflex/__init__.pyi

@@ -157,12 +157,14 @@ from .event import EventHandler as EventHandler
 from .event import background as background
 from .event import background as background
 from .event import call_script as call_script
 from .event import call_script as call_script
 from .event import clear_local_storage as clear_local_storage
 from .event import clear_local_storage as clear_local_storage
+from .event import clear_session_storage as clear_session_storage
 from .event import console_log as console_log
 from .event import console_log as console_log
 from .event import download as download
 from .event import download as download
 from .event import prevent_default as prevent_default
 from .event import prevent_default as prevent_default
 from .event import redirect as redirect
 from .event import redirect as redirect
 from .event import remove_cookie as remove_cookie
 from .event import remove_cookie as remove_cookie
 from .event import remove_local_storage as remove_local_storage
 from .event import remove_local_storage as remove_local_storage
+from .event import remove_session_storage as remove_session_storage
 from .event import set_clipboard as set_clipboard
 from .event import set_clipboard as set_clipboard
 from .event import set_focus as set_focus
 from .event import set_focus as set_focus
 from .event import scroll_to as scroll_to
 from .event import scroll_to as scroll_to
@@ -177,6 +179,7 @@ from .model import Model as Model
 from .state import var as var
 from .state import var as var
 from .state import Cookie as Cookie
 from .state import Cookie as Cookie
 from .state import LocalStorage as LocalStorage
 from .state import LocalStorage as LocalStorage
+from .state import SessionStorage as SessionStorage
 from .state import ComponentState as ComponentState
 from .state import ComponentState as ComponentState
 from .state import State as State
 from .state import State as State
 from .style import Style as Style
 from .style import Style as Style

+ 22 - 11
reflex/compiler/utils.py

@@ -28,7 +28,7 @@ from reflex.components.base import (
     Title,
     Title,
 )
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.components.component import Component, ComponentStyle, CustomComponent
-from reflex.state import BaseState, Cookie, LocalStorage
+from reflex.state import BaseState, Cookie, LocalStorage, SessionStorage
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.imports import ImportVar, ParsedImportDict
 from reflex.utils.imports import ImportVar, ParsedImportDict
@@ -158,8 +158,11 @@ def compile_state(state: Type[BaseState]) -> dict:
 
 
 def _compile_client_storage_field(
 def _compile_client_storage_field(
     field: ModelField,
     field: ModelField,
-) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]:
-    """Compile the given cookie or local_storage field.
+) -> tuple[
+    Type[Cookie] | Type[LocalStorage] | Type[SessionStorage] | None,
+    dict[str, Any] | None,
+]:
+    """Compile the given cookie, local_storage or session_storage field.
 
 
     Args:
     Args:
         field: The possible cookie field to compile.
         field: The possible cookie field to compile.
@@ -167,7 +170,7 @@ def _compile_client_storage_field(
     Returns:
     Returns:
         A dictionary of the compiled cookie or None if the field is not cookie-like.
         A dictionary of the compiled cookie or None if the field is not cookie-like.
     """
     """
-    for field_type in (Cookie, LocalStorage):
+    for field_type in (Cookie, LocalStorage, SessionStorage):
         if isinstance(field.default, field_type):
         if isinstance(field.default, field_type):
             cs_obj = field.default
             cs_obj = field.default
         elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
         elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
@@ -180,7 +183,7 @@ def _compile_client_storage_field(
 
 
 def _compile_client_storage_recursive(
 def _compile_client_storage_recursive(
     state: Type[BaseState],
     state: Type[BaseState],
-) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
+) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]:
     """Compile the client-side storage for the given state recursively.
     """Compile the client-side storage for the given state recursively.
 
 
     Args:
     Args:
@@ -191,10 +194,12 @@ def _compile_client_storage_recursive(
             (
             (
                 cookies: dict[str, dict],
                 cookies: dict[str, dict],
                 local_storage: dict[str, dict[str, str]]
                 local_storage: dict[str, dict[str, str]]
-            )
+                session_storage: dict[str, dict[str, str]]
+            ).
     """
     """
     cookies = {}
     cookies = {}
     local_storage = {}
     local_storage = {}
+    session_storage = {}
     state_name = state.get_full_name()
     state_name = state.get_full_name()
     for name, field in state.__fields__.items():
     for name, field in state.__fields__.items():
         if name in state.inherited_vars:
         if name in state.inherited_vars:
@@ -206,15 +211,20 @@ def _compile_client_storage_recursive(
             cookies[state_key] = options
             cookies[state_key] = options
         elif field_type is LocalStorage:
         elif field_type is LocalStorage:
             local_storage[state_key] = options
             local_storage[state_key] = options
+        elif field_type is SessionStorage:
+            session_storage[state_key] = options
         else:
         else:
             continue
             continue
     for substate in state.get_substates():
     for substate in state.get_substates():
-        substate_cookies, substate_local_storage = _compile_client_storage_recursive(
-            substate
-        )
+        (
+            substate_cookies,
+            substate_local_storage,
+            substate_session_storage,
+        ) = _compile_client_storage_recursive(substate)
         cookies.update(substate_cookies)
         cookies.update(substate_cookies)
         local_storage.update(substate_local_storage)
         local_storage.update(substate_local_storage)
-    return cookies, local_storage
+        session_storage.update(substate_session_storage)
+    return cookies, local_storage, session_storage
 
 
 
 
 def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
 def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
@@ -226,10 +236,11 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
     Returns:
     Returns:
         A dictionary of the compiled client-side storage info.
         A dictionary of the compiled client-side storage info.
     """
     """
-    cookies, local_storage = _compile_client_storage_recursive(state)
+    cookies, local_storage, session_storage = _compile_client_storage_recursive(state)
     return {
     return {
         constants.COOKIES: cookies,
         constants.COOKIES: cookies,
         constants.LOCAL_STORAGE: local_storage,
         constants.LOCAL_STORAGE: local_storage,
+        constants.SESSION_STORAGE: session_storage,
     }
     }
 
 
 
 

+ 2 - 0
reflex/constants/__init__.py

@@ -10,6 +10,7 @@ from .base import (
     REFLEX_VAR_CLOSING_TAG,
     REFLEX_VAR_CLOSING_TAG,
     REFLEX_VAR_OPENING_TAG,
     REFLEX_VAR_OPENING_TAG,
     RELOAD_CONFIG,
     RELOAD_CONFIG,
+    SESSION_STORAGE,
     SKIP_COMPILE_ENV_VAR,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
     ColorMode,
     Dirs,
     Dirs,
@@ -88,6 +89,7 @@ __ALL__ = [
     Imports,
     Imports,
     IS_WINDOWS,
     IS_WINDOWS,
     LOCAL_STORAGE,
     LOCAL_STORAGE,
+    SESSION_STORAGE,
     LogLevel,
     LogLevel,
     MemoizationDisposition,
     MemoizationDisposition,
     MemoizationMode,
     MemoizationMode,

+ 1 - 0
reflex/constants/base.py

@@ -178,6 +178,7 @@ class Ping(SimpleNamespace):
 # Keys in the client_side_storage dict
 # Keys in the client_side_storage dict
 COOKIES = "cookies"
 COOKIES = "cookies"
 LOCAL_STORAGE = "local_storage"
 LOCAL_STORAGE = "local_storage"
+SESSION_STORAGE = "session_storage"
 
 
 # If this env var is set to "yes", App.compile will be a no-op
 # If this env var is set to "yes", App.compile will be a no-op
 SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
 SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"

+ 28 - 0
reflex/event.py

@@ -617,6 +617,34 @@ def remove_local_storage(key: str) -> EventSpec:
     )
     )
 
 
 
 
+def clear_session_storage() -> EventSpec:
+    """Set a value in the session storage on the frontend.
+
+    Returns:
+        EventSpec: An event to clear the session storage.
+    """
+    return server_side(
+        "_clear_session_storage",
+        get_fn_signature(clear_session_storage),
+    )
+
+
+def remove_session_storage(key: str) -> EventSpec:
+    """Set a value in the session storage on the frontend.
+
+    Args:
+        key: The key identifying the variable in the session storage to remove.
+
+    Returns:
+        EventSpec: An event to remove an item based on the provided key in session storage.
+    """
+    return server_side(
+        "_remove_session_storage",
+        get_fn_signature(remove_session_storage),
+        key=key,
+    )
+
+
 def set_clipboard(content: str) -> EventSpec:
 def set_clipboard(content: str) -> EventSpec:
     """Set the text in content in the clipboard.
     """Set the text in content in the clipboard.
 
 

+ 32 - 0
reflex/state.py

@@ -2835,6 +2835,38 @@ class LocalStorage(ClientStorageBase, str):
         return inst
         return inst
 
 
 
 
+class SessionStorage(ClientStorageBase, str):
+    """Represents a state Var that is stored in sessionStorage in the browser."""
+
+    name: str | None
+
+    def __new__(
+        cls,
+        object: Any = "",
+        encoding: str | None = None,
+        errors: str | None = None,
+        /,
+        name: str | None = None,
+    ) -> "SessionStorage":
+        """Create a client-side sessionStorage (str).
+
+        Args:
+            object: The initial object.
+            encoding: The encoding to use.
+            errors: The error handling scheme to use
+            name: The name of the storage on the client side
+
+        Returns:
+            The client-side sessionStorage object.
+        """
+        if encoding or errors:
+            inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict")
+        else:
+            inst = super().__new__(cls, object)
+        inst.name = name
+        return inst
+
+
 class MutableProxy(wrapt.ObjectProxy):
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""
     """A proxy for a mutable object that tracks changes."""