Browse Source

[REF-2265] ComponentState: scaffold for copying State per Component instance (#2923)

* [REF-2265] ComponentState: scaffold for copying State per Component instance

Define a base ComponentState which can be used to easily create copies of the
given State definition (Vars and EventHandlers) that are tied to a particular
instance of a Component (returned by get_component)

* Define `State` field on `Component` for typing compatibility.

This is an Optional field of Type[State] and is populated by ComponentState.

* Add integration/test_component_state.py

Create two independent counters and increment them separately

* Add unit test for ComponentState
Masen Furer 1 year ago
parent
commit
5510eaf820

+ 107 - 0
integration/test_component_state.py

@@ -0,0 +1,107 @@
+"""Test that per-component state scaffold works and operates independently."""
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness
+
+from . import utils
+
+
+def ComponentStateApp():
+    """App using per component state."""
+    import reflex as rx
+
+    class MultiCounter(rx.ComponentState):
+        count: int = 0
+
+        def increment(self):
+            self.count += 1
+
+        @classmethod
+        def get_component(cls, *children, **props):
+            return rx.vstack(
+                *children,
+                rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
+                rx.button(
+                    "Increment",
+                    on_click=cls.increment,
+                    id=f"button-{props.get('id', 'default')}",
+                ),
+                **props,
+            )
+
+    app = rx.App(state=rx.State)  # noqa
+
+    @rx.page()
+    def index():
+        mc_a = MultiCounter.create(id="a")
+        mc_b = MultiCounter.create(id="b")
+        assert mc_a.State != mc_b.State
+        return rx.vstack(
+            mc_a,
+            mc_b,
+            rx.button(
+                "Inc A",
+                on_click=mc_a.State.increment,  # type: ignore
+                id="inc-a",
+            ),
+        )
+
+
+@pytest.fixture()
+def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
+    """Start ComponentStateApp app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=ComponentStateApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.mark.asyncio
+async def test_component_state_app(component_state_app: AppHarness):
+    """Increment counters independently.
+
+    Args:
+        component_state_app: harness for ComponentStateApp app
+    """
+    assert component_state_app.app_instance is not None, "app is not running"
+    driver = component_state_app.frontend()
+
+    ss = utils.SessionStorage(driver)
+    token = AppHarness._poll_for(lambda: ss.get("token") is not None)
+    assert token is not None
+
+    count_a = driver.find_element(By.ID, "count-a")
+    count_b = driver.find_element(By.ID, "count-b")
+    button_a = driver.find_element(By.ID, "button-a")
+    button_b = driver.find_element(By.ID, "button-b")
+    button_inc_a = driver.find_element(By.ID, "inc-a")
+
+    assert count_a.text == "0"
+
+    button_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
+
+    button_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
+
+    button_inc_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
+
+    assert count_b.text == "0"
+
+    button_b.click()
+    assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
+
+    button_b.click()
+    assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"

+ 11 - 0
integration/utils.py

@@ -38,6 +38,8 @@ class LocalStorage:
     https://stackoverflow.com/a/46361900
     """
 
+    storage_key = "localStorage"
+
     def __init__(self, driver: WebDriver):
         """Initialize the class.
 
@@ -171,3 +173,12 @@ class LocalStorage:
             An iterator over the items in local storage.
         """
         return iter(self.keys())
+
+
+class SessionStorage(LocalStorage):
+    """Class to access session storage.
+
+    https://stackoverflow.com/a/46361900
+    """
+
+    storage_key = "sessionStorage"

+ 8 - 1
reflex/__init__.py

@@ -153,7 +153,14 @@ _MAPPING = {
     "reflex.model": ["model", "session", "Model"],
     "reflex.page": ["page"],
     "reflex.route": ["route"],
-    "reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"],
+    "reflex.state": [
+        "state",
+        "var",
+        "Cookie",
+        "LocalStorage",
+        "ComponentState",
+        "State",
+    ],
     "reflex.style": ["style", "toggle_color_mode"],
     "reflex.testing": ["testing"],
     "reflex.utils": ["utils"],

+ 1 - 0
reflex/__init__.pyi

@@ -141,6 +141,7 @@ from reflex import state as state
 from reflex.state import var as var
 from reflex.state import Cookie as Cookie
 from reflex.state import LocalStorage as LocalStorage
+from reflex.state import ComponentState as ComponentState
 from reflex.state import State as State
 from reflex import style as style
 from reflex.style import toggle_color_mode as toggle_color_mode

+ 4 - 0
reflex/components/component.py

@@ -21,6 +21,7 @@ from typing import (
     Union,
 )
 
+import reflex.state
 from reflex.base import Base
 from reflex.compiler.templates import STATEFUL_COMPONENT
 from reflex.components.tags import Tag
@@ -214,6 +215,9 @@ class Component(BaseComponent, ABC):
     # When to memoize this component and its children.
     _memoization_mode: MemoizationMode = MemoizationMode()
 
+    # State class associated with this component instance
+    State: Optional[Type[reflex.state.State]] = None
+
     @classmethod
     def __init_subclass__(cls, **kwargs):
         """Set default properties.

+ 44 - 0
reflex/state.py

@@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from types import FunctionType, MethodType
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
+if TYPE_CHECKING:
+    from reflex.components.component import Component
+
+
 Delta = Dict[str, Any]
 var = computed_var
 
@@ -1835,6 +1840,45 @@ class OnLoadInternalState(State):
         ]
 
 
+class ComponentState(Base):
+    """The base class for a State that is copied for each Component associated with it."""
+
+    _per_component_state_instance_count: ClassVar[int] = 0
+
+    @classmethod
+    def get_component(cls, *children, **props) -> "Component":
+        """Get the component instance.
+
+        Args:
+            children: The children of the component.
+            props: The props of the component.
+
+        Raises:
+            NotImplementedError: if the subclass does not override this method.
+        """
+        raise NotImplementedError(
+            f"{cls.__name__} must implement get_component to return the component instance."
+        )
+
+    @classmethod
+    def create(cls, *children, **props) -> "Component":
+        """Create a new instance of the Component.
+
+        Args:
+            children: The children of the component.
+            props: The props of the component.
+
+        Returns:
+            A new instance of the Component with an independent copy of the State.
+        """
+        cls._per_component_state_instance_count += 1
+        state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
+        component_state = type(state_cls_name, (cls, State), {})
+        component = component_state.get_component(*children, **props)
+        component.State = component_state
+        return component
+
+
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
 

+ 1 - 0
scripts/pyi_generator.py

@@ -57,6 +57,7 @@ EXCLUDED_PROPS = [
     "_rename_props",
     "_valid_children",
     "_valid_parents",
+    "State",
 ]
 
 DEFAULT_TYPING_IMPORTS = {

+ 42 - 0
tests/components/test_component_state.py

@@ -0,0 +1,42 @@
+"""Ensure that Components returned by ComponentState.create have independent State classes."""
+
+import reflex as rx
+from reflex.components.base.bare import Bare
+
+
+def test_component_state():
+    """Create two components with independent state classes."""
+
+    class CS(rx.ComponentState):
+        count: int = 0
+
+        def increment(self):
+            self.count += 1
+
+        @classmethod
+        def get_component(cls, *children, **props):
+            return rx.el.div(
+                *children,
+                **props,
+            )
+
+    cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b")
+    assert isinstance(cs1, rx.Component)
+    assert isinstance(cs2, rx.Component)
+    assert cs1.State is not None
+    assert cs2.State is not None
+    assert cs1.State != cs2.State
+    assert issubclass(cs1.State, CS)
+    assert issubclass(cs1.State, rx.State)
+    assert issubclass(cs2.State, CS)
+    assert issubclass(cs2.State, rx.State)
+    assert CS._per_component_state_instance_count == 2
+    assert isinstance(cs1.State.increment, rx.event.EventHandler)
+    assert cs1.State.increment != cs2.State.increment
+
+    assert len(cs1.children) == 1
+    assert cs1.children[0].render() == Bare.create("{`a`}").render()
+    assert cs1.id == "a"
+    assert len(cs2.children) == 1
+    assert cs2.children[0].render() == Bare.create("{`b`}").render()
+    assert cs2.id == "b"