Browse Source

Fix get_parent_state and get_root_state when using `mixin=True` (#4976)

If a state class inherits directly from a mixin class, find the proper parent
class by looking through the `mro()` until a non-mixin class is found.
Masen Furer 2 months ago
parent
commit
d45e569737
2 changed files with 41 additions and 1 deletions
  1. 8 1
      reflex/state.py
  2. 33 0
      tests/units/test_state.py

+ 8 - 1
reflex/state.py

@@ -905,7 +905,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         ]
         if len(parent_states) >= 2:
             raise ValueError(f"Only one parent state is allowed {parent_states}.")
-        return parent_states[0] if len(parent_states) == 1 else None
+        # The first non-mixin state in the mro is our parent.
+        for base in cls.mro()[1:]:
+            if base._mixin or not issubclass(base, BaseState):
+                continue
+            if base is BaseState:
+                break
+            return base
+        return None  # No known parent
 
     @classmethod
     @functools.lru_cache()

+ 33 - 0
tests/units/test_state.py

@@ -3424,6 +3424,18 @@ class ChildUsesMixinState(UsesMixinState):
     pass
 
 
+class ChildMixinState(ChildUsesMixinState, mixin=True):
+    """A mixin state that inherits from a concrete state that uses mixins."""
+
+    pass
+
+
+class GrandchildUsesMixinState(ChildMixinState):
+    """A grandchild state that uses the mixin state."""
+
+    pass
+
+
 def test_mixin_state() -> None:
     """Test that a mixin state works correctly."""
     assert "num" in UsesMixinState.base_vars
@@ -3438,6 +3450,9 @@ def test_mixin_state() -> None:
         is not UsesMixinState.backend_vars["_backend_no_default"]
     )
 
+    assert UsesMixinState.get_parent_state() == State
+    assert UsesMixinState.get_root_state() == State
+
 
 def test_child_mixin_state() -> None:
     """Test that mixin vars are only applied to the highest state in the hierarchy."""
@@ -3447,6 +3462,24 @@ def test_child_mixin_state() -> None:
     assert "computed" in ChildUsesMixinState.inherited_vars
     assert "computed" not in ChildUsesMixinState.computed_vars
 
+    assert ChildUsesMixinState.get_parent_state() == UsesMixinState
+    assert ChildUsesMixinState.get_root_state() == State
+
+
+def test_grandchild_mixin_state() -> None:
+    """Test that a mixin can inherit from a concrete state class."""
+    assert "num" in GrandchildUsesMixinState.inherited_vars
+    assert "num" not in GrandchildUsesMixinState.base_vars
+
+    assert "computed" in GrandchildUsesMixinState.inherited_vars
+    assert "computed" not in GrandchildUsesMixinState.computed_vars
+
+    assert ChildMixinState.get_parent_state() == ChildUsesMixinState
+    assert ChildMixinState.get_root_state() == State
+
+    assert GrandchildUsesMixinState.get_parent_state() == ChildUsesMixinState
+    assert GrandchildUsesMixinState.get_root_state() == State
+
 
 def test_assignment_to_undeclared_vars():
     """Test that an attribute error is thrown when undeclared vars are set."""