Przeglądaj źródła

fix: do not allow instantiation of State mixins (#4347)

* fix: do not allow instantiation of State mixins
Closes #4343

* improve error message for ComponentState mixins

* fix typo

Co-authored-by: Masen Furer <m_github@0x26.net>

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 6 miesięcy temu
rodzic
commit
1f9a17539c

+ 22 - 0
reflex/state.py

@@ -87,6 +87,7 @@ from reflex.utils.exceptions import (
     ImmutableStateError,
     ImmutableStateError,
     InvalidStateManagerMode,
     InvalidStateManagerMode,
     LockExpiredError,
     LockExpiredError,
+    ReflexRuntimeError,
     SetUndefinedStateVarError,
     SetUndefinedStateVarError,
     StateSchemaMismatchError,
     StateSchemaMismatchError,
 )
 )
@@ -387,6 +388,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 "State classes should not be instantiated directly in a Reflex app. "
                 "State classes should not be instantiated directly in a Reflex app. "
                 "See https://reflex.dev/docs/state/ for further information."
                 "See https://reflex.dev/docs/state/ for further information."
             )
             )
+        if type(self)._mixin:
+            raise ReflexRuntimeError(
+                f"{type(self).__name__} is a state mixin and cannot be instantiated directly."
+            )
         kwargs["parent_state"] = parent_state
         kwargs["parent_state"] = parent_state
         super().__init__()
         super().__init__()
         for name, value in kwargs.items():
         for name, value in kwargs.items():
@@ -2367,6 +2372,23 @@ class ComponentState(State, mixin=True):
     # The number of components created from this class.
     # The number of components created from this class.
     _per_component_state_instance_count: ClassVar[int] = 0
     _per_component_state_instance_count: ClassVar[int] = 0
 
 
+    def __init__(self, *args, **kwargs):
+        """Do not allow direct initialization of the ComponentState.
+
+        Args:
+            *args: The args to pass to the State init method.
+            **kwargs: The kwargs to pass to the State init method.
+
+        Raises:
+            ReflexRuntimeError: If the ComponentState is initialized directly.
+        """
+        if type(self)._mixin:
+            raise ReflexRuntimeError(
+                f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. "
+                + "Use the `create` method to create a new instance and access the state via the `State` attribute."
+            )
+        super().__init__(*args, **kwargs)
+
     @classmethod
     @classmethod
     def __init_subclass__(cls, mixin: bool = True, **kwargs):
     def __init_subclass__(cls, mixin: bool = True, **kwargs):
         """Overwrite mixin default to True.
         """Overwrite mixin default to True.

+ 21 - 0
tests/units/components/test_component_state.py

@@ -1,7 +1,10 @@
 """Ensure that Components returned by ComponentState.create have independent State classes."""
 """Ensure that Components returned by ComponentState.create have independent State classes."""
 
 
+import pytest
+
 import reflex as rx
 import reflex as rx
 from reflex.components.base.bare import Bare
 from reflex.components.base.bare import Bare
+from reflex.utils.exceptions import ReflexRuntimeError
 
 
 
 
 def test_component_state():
 def test_component_state():
@@ -40,3 +43,21 @@ def test_component_state():
     assert len(cs2.children) == 1
     assert len(cs2.children) == 1
     assert cs2.children[0].render() == Bare.create("b").render()
     assert cs2.children[0].render() == Bare.create("b").render()
     assert cs2.id == "b"
     assert cs2.id == "b"
+
+
+def test_init_component_state() -> None:
+    """Ensure that ComponentState subclasses cannot be instantiated directly."""
+
+    class CS(rx.ComponentState):
+        @classmethod
+        def get_component(cls, *children, **props):
+            return rx.el.div()
+
+    with pytest.raises(ReflexRuntimeError):
+        CS()
+
+    class SubCS(CS):
+        pass
+
+    with pytest.raises(ReflexRuntimeError):
+        SubCS()

+ 17 - 1
tests/units/test_state.py

@@ -43,7 +43,7 @@ from reflex.state import (
 )
 )
 from reflex.testing import chdir
 from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
-from reflex.utils.exceptions import SetUndefinedStateVarError
+from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
 from reflex.utils.format import json_dumps
 from reflex.utils.format import json_dumps
 from reflex.vars.base import Var, computed_var
 from reflex.vars.base import Var, computed_var
 from tests.units.states.mutation import MutableSQLAModel, MutableTestState
 from tests.units.states.mutation import MutableSQLAModel, MutableTestState
@@ -3441,3 +3441,19 @@ def test_get_value():
             "bar": "foo",
             "bar": "foo",
         }
         }
     }
     }
+
+
+def test_init_mixin() -> None:
+    """Ensure that State mixins can not be instantiated directly."""
+
+    class Mixin(BaseState, mixin=True):
+        pass
+
+    with pytest.raises(ReflexRuntimeError):
+        Mixin()
+
+    class SubMixin(Mixin, mixin=True):
+        pass
+
+    with pytest.raises(ReflexRuntimeError):
+        SubMixin()