Browse Source

typed mixins and ComponentState (#3196)

* typed mixins

* implicit mixin=True kwarg for ComponentState subclasses

* fix: always init other subclasses

* adjust tests: all mixins support base vars now
benedikt-bartscher 1 year ago
parent
commit
d96baac7d9
3 changed files with 75 additions and 29 deletions
  1. 1 0
      integration/test_component_state.py
  2. 44 22
      integration/test_state_inheritance.py
  3. 30 7
      reflex/state.py

+ 1 - 0
integration/test_component_state.py

@@ -1,4 +1,5 @@
 """Test that per-component state scaffold works and operates independently."""
+
 from typing import Generator
 
 import pytest

+ 44 - 22
integration/test_state_inheritance.py

@@ -45,17 +45,15 @@ def StateInheritance():
     """Test that state inheritance works as expected."""
     import reflex as rx
 
-    class ChildMixin:
-        # mixin basevars only work with pydantic/rx.Base models
-        #  child_mixin: str = "child_mixin"
+    class ChildMixin(rx.State, mixin=True):
+        child_mixin: str = "child_mixin"
 
         @rx.var
         def computed_child_mixin(self) -> str:
             return "computed_child_mixin"
 
-    class Mixin(ChildMixin):
-        # mixin basevars only work with pydantic/rx.Base models
-        #  mixin: str = "mixin"
+    class Mixin(ChildMixin, mixin=True):
+        mixin: str = "mixin"
 
         @rx.var
         def computed_mixin(self) -> str:
@@ -64,7 +62,7 @@ def StateInheritance():
         def on_click_mixin(self):
             return rx.call_script("alert('clicked')")
 
-    class OtherMixin(rx.Base):
+    class OtherMixin(rx.State, mixin=True):
         other_mixin: str = "other_mixin"
         other_mixin_clicks: int = 0
 
@@ -78,7 +76,7 @@ def StateInheritance():
                 f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}"
             )
 
-    class Base1(rx.State, Mixin):
+    class Base1(Mixin, rx.State):
         _base1: str = "_base1"
         base1: str = "base1"
 
@@ -122,14 +120,15 @@ def StateInheritance():
 
     def index() -> rx.Component:
         return rx.vstack(
-            rx.chakra.input(
+            rx.input(
                 id="token", value=Base1.router.session.client_token, is_read_only=True
             ),
-            # Base 1
+            # Base 1 (Mixin, ChildMixin)
             rx.heading(Base1.computed_mixin, id="base1-computed_mixin"),
             rx.heading(Base1.computed_basevar, id="base1-computed_basevar"),
-            rx.heading(Base1.computed_child_mixin, id="base1-child-mixin"),
+            rx.heading(Base1.computed_child_mixin, id="base1-computed-child-mixin"),
             rx.heading(Base1.base1, id="base1-base1"),
+            rx.heading(Base1.child_mixin, id="base1-child-mixin"),
             rx.button(
                 "Base1.on_click_mixin",
                 on_click=Base1.on_click_mixin,  # type: ignore
@@ -138,31 +137,33 @@ def StateInheritance():
             rx.heading(
                 Base1.computed_backend_vars_base1, id="base1-computed_backend_vars"
             ),
-            # Base 2
+            # Base 2 (no mixins)
             rx.heading(Base2.computed_basevar, id="base2-computed_basevar"),
             rx.heading(Base2.base2, id="base2-base2"),
             rx.heading(
                 Base2.computed_backend_vars_base2, id="base2-computed_backend_vars"
             ),
-            # Child 1
+            # Child 1 (Mixin, ChildMixin, OtherMixin)
             rx.heading(Child1.computed_basevar, id="child1-computed_basevar"),
             rx.heading(Child1.computed_mixin, id="child1-computed_mixin"),
             rx.heading(Child1.computed_other_mixin, id="child1-other-mixin"),
-            rx.heading(Child1.computed_child_mixin, id="child1-child-mixin"),
+            rx.heading(Child1.computed_child_mixin, id="child1-computed-child-mixin"),
             rx.heading(Child1.base1, id="child1-base1"),
             rx.heading(Child1.other_mixin, id="child1-other_mixin"),
+            rx.heading(Child1.child_mixin, id="child1-child-mixin"),
             rx.button(
                 "Child1.on_click_other_mixin",
                 on_click=Child1.on_click_other_mixin,  # type: ignore
                 id="child1-other-mixin-btn",
             ),
-            # Child 2
+            # Child 2 (Mixin, ChildMixin, OtherMixin)
             rx.heading(Child2.computed_basevar, id="child2-computed_basevar"),
             rx.heading(Child2.computed_mixin, id="child2-computed_mixin"),
             rx.heading(Child2.computed_other_mixin, id="child2-other-mixin"),
-            rx.heading(Child2.computed_child_mixin, id="child2-child-mixin"),
+            rx.heading(Child2.computed_child_mixin, id="child2-computed-child-mixin"),
             rx.heading(Child2.base2, id="child2-base2"),
             rx.heading(Child2.other_mixin, id="child2-other_mixin"),
+            rx.heading(Child2.child_mixin, id="child2-child-mixin"),
             rx.button(
                 "Child2.on_click_mixin",
                 on_click=Child2.on_click_mixin,  # type: ignore
@@ -173,15 +174,16 @@ def StateInheritance():
                 on_click=Child2.on_click_other_mixin,  # type: ignore
                 id="child2-other-mixin-btn",
             ),
-            # Child 3
+            # Child 3 (Mixin, ChildMixin, OtherMixin)
             rx.heading(Child3.computed_basevar, id="child3-computed_basevar"),
             rx.heading(Child3.computed_mixin, id="child3-computed_mixin"),
             rx.heading(Child3.computed_other_mixin, id="child3-other-mixin"),
             rx.heading(Child3.computed_childvar, id="child3-computed_childvar"),
-            rx.heading(Child3.computed_child_mixin, id="child3-child-mixin"),
+            rx.heading(Child3.computed_child_mixin, id="child3-computed-child-mixin"),
             rx.heading(Child3.child3, id="child3-child3"),
             rx.heading(Child3.base2, id="child3-base2"),
             rx.heading(Child3.other_mixin, id="child3-other_mixin"),
+            rx.heading(Child3.child_mixin, id="child3-child-mixin"),
             rx.button(
                 "Child3.on_click_mixin",
                 on_click=Child3.on_click_mixin,  # type: ignore
@@ -282,7 +284,9 @@ def test_state_inheritance(
     base1_computed_basevar = driver.find_element(By.ID, "base1-computed_basevar")
     assert base1_computed_basevar.text == "computed_basevar1"
 
-    base1_computed_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
+    base1_computed_child_mixin = driver.find_element(
+        By.ID, "base1-computed-child-mixin"
+    )
     assert base1_computed_child_mixin.text == "computed_child_mixin"
 
     base1_base1 = driver.find_element(By.ID, "base1-base1")
@@ -293,6 +297,9 @@ def test_state_inheritance(
     )
     assert base1_computed_backend_vars.text == "_base1"
 
+    base1_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
+    assert base1_child_mixin.text == "child_mixin"
+
     # Base 2
     base2_computed_basevar = driver.find_element(By.ID, "base2-computed_basevar")
     assert base2_computed_basevar.text == "computed_basevar2"
@@ -315,7 +322,9 @@ def test_state_inheritance(
     child1_computed_other_mixin = driver.find_element(By.ID, "child1-other-mixin")
     assert child1_computed_other_mixin.text == "other_mixin"
 
-    child1_computed_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
+    child1_computed_child_mixin = driver.find_element(
+        By.ID, "child1-computed-child-mixin"
+    )
     assert child1_computed_child_mixin.text == "computed_child_mixin"
 
     child1_base1 = driver.find_element(By.ID, "child1-base1")
@@ -324,6 +333,9 @@ def test_state_inheritance(
     child1_other_mixin = driver.find_element(By.ID, "child1-other_mixin")
     assert child1_other_mixin.text == "other_mixin"
 
+    child1_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
+    assert child1_child_mixin.text == "child_mixin"
+
     # Child 2
     child2_computed_basevar = driver.find_element(By.ID, "child2-computed_basevar")
     assert child2_computed_basevar.text == "computed_basevar2"
@@ -334,7 +346,9 @@ def test_state_inheritance(
     child2_computed_other_mixin = driver.find_element(By.ID, "child2-other-mixin")
     assert child2_computed_other_mixin.text == "other_mixin"
 
-    child2_computed_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
+    child2_computed_child_mixin = driver.find_element(
+        By.ID, "child2-computed-child-mixin"
+    )
     assert child2_computed_child_mixin.text == "computed_child_mixin"
 
     child2_base2 = driver.find_element(By.ID, "child2-base2")
@@ -343,6 +357,9 @@ def test_state_inheritance(
     child2_other_mixin = driver.find_element(By.ID, "child2-other_mixin")
     assert child2_other_mixin.text == "other_mixin"
 
+    child2_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
+    assert child2_child_mixin.text == "child_mixin"
+
     # Child 3
     child3_computed_basevar = driver.find_element(By.ID, "child3-computed_basevar")
     assert child3_computed_basevar.text == "computed_basevar2"
@@ -356,7 +373,9 @@ def test_state_inheritance(
     child3_computed_childvar = driver.find_element(By.ID, "child3-computed_childvar")
     assert child3_computed_childvar.text == "computed_childvar"
 
-    child3_computed_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
+    child3_computed_child_mixin = driver.find_element(
+        By.ID, "child3-computed-child-mixin"
+    )
     assert child3_computed_child_mixin.text == "computed_child_mixin"
 
     child3_child3 = driver.find_element(By.ID, "child3-child3")
@@ -368,6 +387,9 @@ def test_state_inheritance(
     child3_other_mixin = driver.find_element(By.ID, "child3-other_mixin")
     assert child3_other_mixin.text == "other_mixin"
 
+    child3_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
+    assert child3_child_mixin.text == "child_mixin"
+
     child3_computed_backend_vars = driver.find_element(
         By.ID, "child3-computed_backend_vars"
     )

+ 30 - 7
reflex/state.py

@@ -359,6 +359,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Whether the state has ever been touched since instantiation.
     _was_touched: bool = False
 
+    # Whether this state class is a mixin and should not be instantiated.
+    _mixin: ClassVar[bool] = False
+
     # A special event handler for setting base vars.
     setvar: ClassVar[EventHandler]
 
@@ -428,17 +431,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         return [
             v
-            for mixin in cls.__mro__
-            if mixin is cls or not issubclass(mixin, (BaseState, ABC))
+            for mixin in cls._mixins() + [cls]
             for v in mixin.__dict__.values()
             if isinstance(v, ComputedVar)
         ]
 
     @classmethod
-    def __init_subclass__(cls, **kwargs):
+    def __init_subclass__(cls, mixin: bool = False, **kwargs):
         """Do some magic for the subclass initialization.
 
         Args:
+            mixin: Whether the subclass is a mixin and should not be initialized.
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
 
         Raises:
@@ -447,6 +450,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         from reflex.utils.exceptions import StateValueError
 
         super().__init_subclass__(**kwargs)
+
+        cls._mixin = mixin
+        if mixin:
+            return
+
         # Event handlers should not shadow builtin state methods.
         cls._check_overridden_methods()
         # Computed vars should not shadow builtin state props.
@@ -618,8 +626,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return [
             mixin
             for mixin in cls.__mro__
-            if not issubclass(mixin, (BaseState, ABC))
-            and mixin not in [pydantic.BaseModel, Base]
+            if (
+                mixin not in [pydantic.BaseModel, Base, cls]
+                and issubclass(mixin, BaseState)
+                and mixin._mixin is True
+            )
         ]
 
     @classmethod
@@ -742,7 +753,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         parent_states = [
             base
             for base in cls.__bases__
-            if types._issubclass(base, BaseState) and base is not BaseState
+            if issubclass(base, BaseState) and base is not BaseState and not base._mixin
         ]
         assert len(parent_states) < 2, "Only one parent state is allowed."
         return parent_states[0] if len(parent_states) == 1 else None  # type: ignore
@@ -1833,7 +1844,7 @@ class OnLoadInternalState(State):
         ]
 
 
-class ComponentState(Base):
+class ComponentState(State, mixin=True):
     """Base class to allow for the creation of a state instance per component.
 
     This allows for the bundling of UI and state logic into a single class,
@@ -1875,6 +1886,18 @@ class ComponentState(Base):
     # The number of components created from this class.
     _per_component_state_instance_count: ClassVar[int] = 0
 
+    @classmethod
+    def __init_subclass__(cls, mixin: bool = False, **kwargs):
+        """Overwrite mixin default to True.
+
+        Args:
+            mixin: Whether the subclass is a mixin and should not be initialized.
+            **kwargs: The kwargs to pass to the pydantic init_subclass method.
+        """
+        if ComponentState in cls.__bases__:
+            mixin = True
+        super().__init_subclass__(mixin=mixin, **kwargs)
+
     @classmethod
     def get_component(cls, *children, **props) -> "Component":
         """Get the component instance.