浏览代码

[REF-2265] ComponentState: scaffold for copying State per Component instance (#2923)

* [REF-2265] ComponentState: scaffold for copying State per Component instance

Define a base ComponentState which can be used to easily create copies of the
given State definition (Vars and EventHandlers) that are tied to a particular
instance of a Component (returned by get_component)

* Define `State` field on `Component` for typing compatibility.

This is an Optional field of Type[State] and is populated by ComponentState.

* Add integration/test_component_state.py

Create two independent counters and increment them separately

* Add unit test for ComponentState
Masen Furer 1 年之前
父节点
当前提交
5510eaf820

+ 107 - 0
integration/test_component_state.py

@@ -0,0 +1,107 @@
+"""Test that per-component state scaffold works and operates independently."""
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness
+
+from . import utils
+
+
+def ComponentStateApp():
+    """App using per component state."""
+    import reflex as rx
+
+    class MultiCounter(rx.ComponentState):
+        count: int = 0
+
+        def increment(self):
+            self.count += 1
+
+        @classmethod
+        def get_component(cls, *children, **props):
+            return rx.vstack(
+                *children,
+                rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
+                rx.button(
+                    "Increment",
+                    on_click=cls.increment,
+                    id=f"button-{props.get('id', 'default')}",
+                ),
+                **props,
+            )
+
+    app = rx.App(state=rx.State)  # noqa
+
+    @rx.page()
+    def index():
+        mc_a = MultiCounter.create(id="a")
+        mc_b = MultiCounter.create(id="b")
+        assert mc_a.State != mc_b.State
+        return rx.vstack(
+            mc_a,
+            mc_b,
+            rx.button(
+                "Inc A",
+                on_click=mc_a.State.increment,  # type: ignore
+                id="inc-a",
+            ),
+        )
+
+
+@pytest.fixture()
+def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
+    """Start ComponentStateApp app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=ComponentStateApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.mark.asyncio
+async def test_component_state_app(component_state_app: AppHarness):
+    """Increment counters independently.
+
+    Args:
+        component_state_app: harness for ComponentStateApp app
+    """
+    assert component_state_app.app_instance is not None, "app is not running"
+    driver = component_state_app.frontend()
+
+    ss = utils.SessionStorage(driver)
+    token = AppHarness._poll_for(lambda: ss.get("token") is not None)
+    assert token is not None
+
+    count_a = driver.find_element(By.ID, "count-a")
+    count_b = driver.find_element(By.ID, "count-b")
+    button_a = driver.find_element(By.ID, "button-a")
+    button_b = driver.find_element(By.ID, "button-b")
+    button_inc_a = driver.find_element(By.ID, "inc-a")
+
+    assert count_a.text == "0"
+
+    button_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
+
+    button_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
+
+    button_inc_a.click()
+    assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
+
+    assert count_b.text == "0"
+
+    button_b.click()
+    assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
+
+    button_b.click()
+    assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"

+ 11 - 0
integration/utils.py

@@ -38,6 +38,8 @@ class LocalStorage:
     https://stackoverflow.com/a/46361900
     https://stackoverflow.com/a/46361900
     """
     """
 
 
+    storage_key = "localStorage"
+
     def __init__(self, driver: WebDriver):
     def __init__(self, driver: WebDriver):
         """Initialize the class.
         """Initialize the class.
 
 
@@ -171,3 +173,12 @@ class LocalStorage:
             An iterator over the items in local storage.
             An iterator over the items in local storage.
         """
         """
         return iter(self.keys())
         return iter(self.keys())
+
+
+class SessionStorage(LocalStorage):
+    """Class to access session storage.
+
+    https://stackoverflow.com/a/46361900
+    """
+
+    storage_key = "sessionStorage"

+ 8 - 1
reflex/__init__.py

@@ -153,7 +153,14 @@ _MAPPING = {
     "reflex.model": ["model", "session", "Model"],
     "reflex.model": ["model", "session", "Model"],
     "reflex.page": ["page"],
     "reflex.page": ["page"],
     "reflex.route": ["route"],
     "reflex.route": ["route"],
-    "reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"],
+    "reflex.state": [
+        "state",
+        "var",
+        "Cookie",
+        "LocalStorage",
+        "ComponentState",
+        "State",
+    ],
     "reflex.style": ["style", "toggle_color_mode"],
     "reflex.style": ["style", "toggle_color_mode"],
     "reflex.testing": ["testing"],
     "reflex.testing": ["testing"],
     "reflex.utils": ["utils"],
     "reflex.utils": ["utils"],

+ 1 - 0
reflex/__init__.pyi

@@ -141,6 +141,7 @@ from reflex import state as state
 from reflex.state import var as var
 from reflex.state import var as var
 from reflex.state import Cookie as Cookie
 from reflex.state import Cookie as Cookie
 from reflex.state import LocalStorage as LocalStorage
 from reflex.state import LocalStorage as LocalStorage
+from reflex.state import ComponentState as ComponentState
 from reflex.state import State as State
 from reflex.state import State as State
 from reflex import style as style
 from reflex import style as style
 from reflex.style import toggle_color_mode as toggle_color_mode
 from reflex.style import toggle_color_mode as toggle_color_mode

+ 4 - 0
reflex/components/component.py

@@ -21,6 +21,7 @@ from typing import (
     Union,
     Union,
 )
 )
 
 
+import reflex.state
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler.templates import STATEFUL_COMPONENT
 from reflex.compiler.templates import STATEFUL_COMPONENT
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
@@ -214,6 +215,9 @@ class Component(BaseComponent, ABC):
     # When to memoize this component and its children.
     # When to memoize this component and its children.
     _memoization_mode: MemoizationMode = MemoizationMode()
     _memoization_mode: MemoizationMode = MemoizationMode()
 
 
+    # State class associated with this component instance
+    State: Optional[Type[reflex.state.State]] = None
+
     @classmethod
     @classmethod
     def __init_subclass__(cls, **kwargs):
     def __init_subclass__(cls, **kwargs):
         """Set default properties.
         """Set default properties.

+ 44 - 0
reflex/state.py

@@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections import defaultdict
 from types import FunctionType, MethodType
 from types import FunctionType, MethodType
 from typing import (
 from typing import (
+    TYPE_CHECKING,
     Any,
     Any,
     AsyncIterator,
     AsyncIterator,
     Callable,
     Callable,
@@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
 
+if TYPE_CHECKING:
+    from reflex.components.component import Component
+
+
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
 var = computed_var
 var = computed_var
 
 
@@ -1835,6 +1840,45 @@ class OnLoadInternalState(State):
         ]
         ]
 
 
 
 
+class ComponentState(Base):
+    """The base class for a State that is copied for each Component associated with it."""
+
+    _per_component_state_instance_count: ClassVar[int] = 0
+
+    @classmethod
+    def get_component(cls, *children, **props) -> "Component":
+        """Get the component instance.
+
+        Args:
+            children: The children of the component.
+            props: The props of the component.
+
+        Raises:
+            NotImplementedError: if the subclass does not override this method.
+        """
+        raise NotImplementedError(
+            f"{cls.__name__} must implement get_component to return the component instance."
+        )
+
+    @classmethod
+    def create(cls, *children, **props) -> "Component":
+        """Create a new instance of the Component.
+
+        Args:
+            children: The children of the component.
+            props: The props of the component.
+
+        Returns:
+            A new instance of the Component with an independent copy of the State.
+        """
+        cls._per_component_state_instance_count += 1
+        state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
+        component_state = type(state_cls_name, (cls, State), {})
+        component = component_state.get_component(*children, **props)
+        component.State = component_state
+        return component
+
+
 class StateProxy(wrapt.ObjectProxy):
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
     """Proxy of a state instance to control mutability of vars for a background task.
 
 

+ 1 - 0
scripts/pyi_generator.py

@@ -57,6 +57,7 @@ EXCLUDED_PROPS = [
     "_rename_props",
     "_rename_props",
     "_valid_children",
     "_valid_children",
     "_valid_parents",
     "_valid_parents",
+    "State",
 ]
 ]
 
 
 DEFAULT_TYPING_IMPORTS = {
 DEFAULT_TYPING_IMPORTS = {

+ 42 - 0
tests/components/test_component_state.py

@@ -0,0 +1,42 @@
+"""Ensure that Components returned by ComponentState.create have independent State classes."""
+
+import reflex as rx
+from reflex.components.base.bare import Bare
+
+
+def test_component_state():
+    """Create two components with independent state classes."""
+
+    class CS(rx.ComponentState):
+        count: int = 0
+
+        def increment(self):
+            self.count += 1
+
+        @classmethod
+        def get_component(cls, *children, **props):
+            return rx.el.div(
+                *children,
+                **props,
+            )
+
+    cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b")
+    assert isinstance(cs1, rx.Component)
+    assert isinstance(cs2, rx.Component)
+    assert cs1.State is not None
+    assert cs2.State is not None
+    assert cs1.State != cs2.State
+    assert issubclass(cs1.State, CS)
+    assert issubclass(cs1.State, rx.State)
+    assert issubclass(cs2.State, CS)
+    assert issubclass(cs2.State, rx.State)
+    assert CS._per_component_state_instance_count == 2
+    assert isinstance(cs1.State.increment, rx.event.EventHandler)
+    assert cs1.State.increment != cs2.State.increment
+
+    assert len(cs1.children) == 1
+    assert cs1.children[0].render() == Bare.create("{`a`}").render()
+    assert cs1.id == "a"
+    assert len(cs2.children) == 1
+    assert cs2.children[0].render() == Bare.create("{`b`}").render()
+    assert cs2.id == "b"