Explorar el Código

StateProxy rebinds functools.partial and methods that are bound to the proxied State (#1853)

Masen Furer hace 1 año
padre
commit
83d7a044fe
Se han modificado 2 ficheros con 77 adiciones y 2 borrados
  1. 12 1
      reflex/state.py
  2. 65 1
      tests/test_state.py

+ 12 - 1
reflex/state.py

@@ -12,7 +12,7 @@ import urllib.parse
 import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from types import FunctionType
+from types import FunctionType, MethodType
 from typing import (
     Any,
     AsyncIterator,
@@ -1177,6 +1177,17 @@ class StateProxy(wrapt.ObjectProxy):
                 state=self,  # type: ignore
                 field_name=value._self_field_name,
             )
+        if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
+            # Rebind event handler to the proxy instance
+            value = functools.partial(
+                value.func,
+                self,
+                *value.args[1:],
+                **value.keywords,
+            )
+        if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
+            # Rebind methods to the proxy instance
+            value = type(value)(value.__func__, self)  # type: ignore
         return value
 
     def __setattr__(self, name: str, value: Any) -> None:

+ 65 - 1
tests/test_state.py

@@ -1699,6 +1699,14 @@ class BackgroundTaskState(State):
             # Even nested access to mutables raises an exception.
             self.dict_list["foo"].append(42)
 
+        with pytest.raises(ImmutableStateError):
+            # Direct calling another handler that modifies state raises an exception.
+            self.other()
+
+        with pytest.raises(ImmutableStateError):
+            # Calling other methods that modify state raises an exception.
+            self._private_method()
+
         # wait for some other event to happen
         while len(self.order) == 1:
             await asyncio.sleep(0.01)
@@ -1707,6 +1715,22 @@ class BackgroundTaskState(State):
 
         async with self:
             self.order.append("background_task:stop")
+            self.other()  # direct calling event handlers works in context
+            self._private_method()
+
+    @rx.background
+    async def background_task_reset(self):
+        """A background task that resets the state."""
+        with pytest.raises(ImmutableStateError):
+            # Resetting the state should be explicitly blocked.
+            self.reset()
+
+        async with self:
+            self.order.append("foo")
+            self.reset()
+        assert not self.order
+        async with self:
+            self.order.append("reset")
 
     @rx.background
     async def background_task_generator(self):
@@ -1721,6 +1745,10 @@ class BackgroundTaskState(State):
         """Some other event that updates the state."""
         self.order.append("other")
 
+    def _private_method(self):
+        """Some private method that updates the state."""
+        self.order.append("private")
+
     async def bad_chain1(self):
         """Test that a background task cannot be chained."""
         await self.background_task()
@@ -1755,7 +1783,6 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     ):
         # background task returns empty update immediately
         assert update == StateUpdate()
-    assert len(mock_app.background_tasks) == 1
 
     # wait for the coroutine to start
     await asyncio.sleep(0.5 if CI else 0.1)
@@ -1795,6 +1822,43 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         "background_task:start",
         "other",
         "background_task:stop",
+        "other",
+        "private",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_background_task_reset(mock_app: rx.App, token: str):
+    """Test that a background task calling reset is protected by the state proxy.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    router_data = {"query": {}}
+    mock_app.state_manager.state = mock_app.state = BackgroundTaskState
+    async for update in rx.app.process(  # type: ignore
+        mock_app,
+        Event(
+            token=token,
+            name=f"{BackgroundTaskState.get_name()}.background_task_reset",
+            router_data=router_data,
+            payload={},
+        ),
+        sid="",
+        headers={},
+        client_ip="",
+    ):
+        # background task returns empty update immediately
+        assert update == StateUpdate()
+
+    # Explicit wait for background tasks
+    for task in tuple(mock_app.background_tasks):
+        await task
+    assert not mock_app.background_tasks
+
+    assert (await mock_app.state_manager.get_state(token)).order == [
+        "reset",
     ]