Bladeren bron

Event Handlers should not shadow base state methods (#1543)

Elijah Ahianyo 1 jaar geleden
bovenliggende
commit
2fa087a0fa
4 gewijzigde bestanden met toevoegingen van 76 en 25 verwijderingen
  1. 1 1
      reflex/middleware/hydrate_middleware.py
  2. 51 14
      reflex/state.py
  3. 1 1
      reflex/testing.py
  4. 23 9
      tests/test_state.py

+ 1 - 1
reflex/middleware/hydrate_middleware.py

@@ -40,7 +40,7 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.IS_HYDRATED, False)
         delta = format.format_state({state.get_name(): state.dict()})
         # since a full dict was captured, clean any dirtiness
-        state.clean()
+        state._clean()
 
         # Get the route for on_load events.
         route = event.router_data.get(constants.RouteVar.PATH, "")

+ 51 - 14
reflex/state.py

@@ -105,7 +105,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Setup the substates.
         for substate in self.get_substates():
             self.substates[substate.get_name()] = substate(parent_state=self)
-
         # Convert the event handlers to functions.
         for name, event_handler in self.event_handlers.items():
             fn = functools.partial(event_handler.fn, self)
@@ -154,7 +153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             if types._issubclass(field.type_, Union[List, Dict]):
                 setattr(self, field.name, value_in_rx_data)
 
-        self.clean()
+        self._clean()
 
     def _reassign_field(self, field_name: str):
         """Reassign the given field.
@@ -186,6 +185,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
         """
         super().__init_subclass__(**kwargs)
+        # Event handlers should not shadow builtin state methods.
+        cls._check_overridden_methods()
 
         # Get the parent vars.
         parent_state = cls.get_parent_state()
@@ -238,6 +239,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             cls.event_handlers[name] = handler
             setattr(cls, name, handler)
 
+    @classmethod
+    def _check_overridden_methods(cls):
+        """Check for shadow methods and raise error if any.
+
+        Raises:
+            NameError: When an event handler shadows an inbuilt state method.
+        """
+        overridden_methods = set()
+        state_base_functions = cls._get_base_functions()
+        for name, method in inspect.getmembers(cls, inspect.isfunction):
+            # Check if the method is overridden and not a dunder method
+            if (
+                not name.startswith("__")
+                and method.__name__ in state_base_functions
+                and state_base_functions[method.__name__] != method
+            ):
+                overridden_methods.add(method.__name__)
+
+        for method_name in overridden_methods:
+            raise NameError(
+                f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead"
+            )
+
     @classmethod
     def get_skip_vars(cls) -> Set[str]:
         """Get the vars to skip when serializing.
@@ -444,6 +468,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             field.required = False
             field.default = default_value
 
+    @staticmethod
+    def _get_base_functions() -> Dict[str, FunctionType]:
+        """Get all functions of the state class excluding dunder methods.
+
+        Returns:
+            The functions of rx.State class as a dict.
+        """
+        return {
+            func[0]: func[1]
+            for func in inspect.getmembers(State, predicate=inspect.isfunction)
+            if not func[0].startswith("__")
+        }
+
     def get_token(self) -> str:
         """Return the token of the client associated with this state.
 
@@ -598,7 +635,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         if types.is_backend_variable(name) and name != "_backend_vars":
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
-            self.mark_dirty()
+            self._mark_dirty()
             return
 
         # Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
@@ -611,12 +648,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Add the var to the dirty list.
         if name in self.vars or name in self.computed_var_dependencies:
             self.dirty_vars.add(name)
-            self.mark_dirty()
+            self._mark_dirty()
 
         # For now, handle router_data updates as a special case
         if name == constants.ROUTER_DATA:
             self.dirty_vars.add(name)
-            self.mark_dirty()
+            self._mark_dirty()
             # propagate router_data updates down the state tree
             for substate in self.substates.values():
                 setattr(substate, name, value)
@@ -685,7 +722,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         )
 
         # Clean the state before processing the event.
-        self.clean()
+        self._clean()
 
         # Run the event generator and return state updates.
         async for events, final in event_iter:
@@ -699,7 +736,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             yield StateUpdate(delta=delta, events=events, final=final)
 
             # Clean the state to prepare for the next event.
-            self.clean()
+            self._clean()
 
     async def _process_event(
         self, handler: EventHandler, state: State, payload: Dict
@@ -806,7 +843,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         # Apply dirty variables down into substates
         self.dirty_vars.update(self._always_dirty_computed_vars())
-        self.mark_dirty()
+        self._mark_dirty()
 
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # and always dirty computed vars (cache=False)
@@ -835,7 +872,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Return the delta.
         return delta
 
-    def mark_dirty(self):
+    def _mark_dirty(self):
         """Mark the substate and all parent states as dirty."""
         state_name = self.get_name()
         if (
@@ -843,7 +880,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             and state_name not in self.parent_state.dirty_substates
         ):
             self.parent_state.dirty_substates.add(self.get_name())
-            self.parent_state.mark_dirty()
+            self.parent_state._mark_dirty()
 
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
@@ -856,13 +893,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
                 self.dirty_substates.add(substate_name)
                 substate = substates[substate_name]
                 substate.dirty_vars.add(var)
-                substate.mark_dirty()
+                substate._mark_dirty()
 
-    def clean(self):
+    def _clean(self):
         """Reset the dirty vars."""
         # Recursively clean the substates.
         for substate in self.dirty_substates:
-            self.substates[substate].clean()
+            self.substates[substate]._clean()
 
         # Clean this state.
         self.dirty_vars = set()
@@ -882,7 +919,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             # Apply dirty variables down into substates to allow never-cached ComputedVar to
             # trigger recalculation of dependent vars
             self.dirty_vars.update(self._always_dirty_computed_vars())
-            self.mark_dirty()
+            self._mark_dirty()
 
         base_vars = {
             prop_name: self.get_value(getattr(self, prop_name))

+ 1 - 1
reflex/testing.py

@@ -365,7 +365,7 @@ class AppHarness:
             delta = state.get_delta()
             if delta:
                 update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
-                state.clean()
+                state._clean()
                 # Emit the event.
                 pending.append(
                     event_ns.emit(

+ 23 - 9
tests/test_state.py

@@ -498,7 +498,7 @@ def test_set_dirty_var(test_state):
     assert test_state.dirty_vars == {"num1", "num2", "sum"}
 
     # Cleaning the state should remove all dirty vars.
-    test_state.clean()
+    test_state._clean()
     assert test_state.dirty_vars == set()
 
 
@@ -524,7 +524,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
     assert child_state.dirty_substates == set()
 
     # Cleaning the parent state should remove the dirty substate.
-    test_state.clean()
+    test_state._clean()
     assert test_state.dirty_substates == set()
     assert child_state.dirty_vars == set()
 
@@ -534,7 +534,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
     assert test_state.dirty_substates == {"child_state"}
 
     # Cleaning the middle state should keep the parent state dirty.
-    child_state.clean()
+    child_state._clean()
     assert test_state.dirty_substates == {"child_state"}
     assert child_state.dirty_substates == set()
     assert grandchild_state.dirty_vars == set()
@@ -626,7 +626,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state": {"value": "HI", "count": 24},
     }
-    test_state.clean()
+    test_state._clean()
 
     # Test with the granchild state.
     assert grandchild_state.value2 == ""
@@ -1044,23 +1044,23 @@ def test_computed_var_cached_depends_on_non_cached():
     cs = ComputedState()
     assert cs.dirty_vars == set()
     assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
-    cs.clean()
+    cs._clean()
     assert cs.dirty_vars == set()
     assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
-    cs.clean()
+    cs._clean()
     assert cs.dirty_vars == set()
     cs.v = 1
     assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
     assert cs.get_delta() == {
         cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
     }
-    cs.clean()
+    cs._clean()
     assert cs.dirty_vars == set()
     assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
-    cs.clean()
+    cs._clean()
     assert cs.dirty_vars == set()
     assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
-    cs.clean()
+    cs._clean()
     assert cs.dirty_vars == set()
 
 
@@ -1191,3 +1191,17 @@ def test_setattr_of_mutable_types(mutable_state):
     assert isinstance(hashmap["mod_third_key"], ReflexDict)
 
     assert isinstance(test_set, ReflexSet)
+
+
+def test_error_on_state_method_shadow():
+    """Test that an error is thrown when an event handler shadows a state method."""
+    with pytest.raises(NameError) as err:
+
+        class InvalidTest(rx.State):
+            def reset(self):
+                pass
+
+    assert (
+        err.value.args[0]
+        == f"The event handler name `reset` shadows a builtin State method; use a different name instead"
+    )