Quellcode durchsuchen

fix: async default setters break setvar (#4169)

* fix: async default setters break setvar

* fix unit test
benedikt-bartscher vor 6 Monaten
Ursprung
Commit
c0ed8b7d91
3 geänderte Dateien mit 30 neuen und 1 gelöschten Zeilen
  1. 11 1
      reflex/state.py
  2. 18 0
      tests/units/test_state.py
  3. 1 0
      tests/units/utils/test_format.py

+ 11 - 1
reflex/state.py

@@ -220,6 +220,7 @@ class EventHandlerSetVar(EventHandler):
         Raises:
             AttributeError: If the given Var name does not exist on the state.
             EventHandlerValueError: If the given Var name is not a str
+            NotImplementedError: If the setter for the given Var is async
         """
         from reflex.utils.exceptions import EventHandlerValueError
 
@@ -228,11 +229,20 @@ class EventHandlerSetVar(EventHandler):
                 raise EventHandlerValueError(
                     f"Var name must be passed as a string, got {args[0]!r}"
                 )
+
+            handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None)
+
             # Check that the requested Var setter exists on the State at compile time.
-            if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
+            if handler is None:
                 raise AttributeError(
                     f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
                 )
+
+            if asyncio.iscoroutinefunction(handler.fn):
+                raise NotImplementedError(
+                    f"Setter for {args[0]} is async, which is not supported."
+                )
+
         return super().__call__(*args)
 
 

+ 18 - 0
tests/units/test_state.py

@@ -106,6 +106,7 @@ class TestState(BaseState):
     fig: Figure = Figure()
     dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00")
     _backend: int = 0
+    asynctest: int = 0
 
     @ComputedVar
     def sum(self) -> float:
@@ -129,6 +130,14 @@ class TestState(BaseState):
         """Do something."""
         pass
 
+    async def set_asynctest(self, value: int):
+        """Set the asynctest value. Intentionally overwrite the default setter with an async one.
+
+        Args:
+            value: The new value.
+        """
+        self.asynctest = value
+
 
 class ChildState(TestState):
     """A child state fixture."""
@@ -313,6 +322,7 @@ def test_class_vars(test_state):
         "upper",
         "fig",
         "dt",
+        "asynctest",
     }
 
 
@@ -733,6 +743,7 @@ def test_reset(test_state, child_state):
         "mapping",
         "dt",
         "_backend",
+        "asynctest",
     }
 
     # The dirty vars should be reset.
@@ -3179,6 +3190,13 @@ async def test_setvar(mock_app: rx.App, token: str):
         TestState.setvar(42, 42)
 
 
+@pytest.mark.asyncio
+async def test_setvar_async_setter():
+    """Test that overridden async setters raise Exception when used with setvar."""
+    with pytest.raises(NotImplementedError):
+        TestState.setvar("asynctest", 42)
+
+
 @pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
 @pytest.mark.parametrize(
     "expiration_kwargs, expected_values",

+ 1 - 0
tests/units/utils/test_format.py

@@ -601,6 +601,7 @@ formatted_router = {
                     "sum": 3.14,
                     "upper": "",
                     "router": formatted_router,
+                    "asynctest": 0,
                 },
                 ChildState.get_full_name(): {
                     "count": 23,