Browse Source

vars: unbox EventHandler and functools.partial for dep analysis (#1305)

When calculating the variable dependencies of a cached_var, reach into objects
with .func or .fn attributes and perform analysis on those.

Fix #1303
Masen Furer 1 year ago
parent
commit
20e2a25c9a
2 changed files with 48 additions and 2 deletions
  1. 10 2
      reflex/vars.py
  2. 38 0
      tests/test_state.py

+ 10 - 2
reflex/vars.py

@@ -869,10 +869,18 @@ class ComputedVar(Var, property):
                 obj = cast(FunctionType, self.fget)
             else:
                 return set()
-        if not obj.__code__.co_varnames:
+        with contextlib.suppress(AttributeError):
+            # unbox functools.partial
+            obj = cast(FunctionType, obj.func)  # type: ignore
+        with contextlib.suppress(AttributeError):
+            # unbox EventHandler
+            obj = cast(FunctionType, obj.fn)  # type: ignore
+
+        try:
+            self_name = obj.__code__.co_varnames[0]
+        except (AttributeError, IndexError):
             # cannot reference self if method takes no args
             return set()
-        self_name = obj.__code__.co_varnames[0]
         self_is_top_of_stack = False
         for instruction in dis.get_instructions(obj):
             if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:

+ 38 - 0
tests/test_state.py

@@ -1,3 +1,4 @@
+import functools
 from typing import Dict, List
 
 import pytest
@@ -1089,6 +1090,43 @@ def test_computed_var_depends_on_parent_non_cached():
     assert counter == 6
 
 
+@pytest.mark.parametrize("use_partial", [True, False])
+def test_cached_var_depends_on_event_handler(use_partial: bool):
+    """A cached_var that calls an event handler calculates deps correctly.
+
+    Args:
+        use_partial: if true, replace the EventHandler with functools.partial
+    """
+    counter = 0
+
+    class HandlerState(State):
+        x: int = 42
+
+        def handler(self):
+            self.x = self.x + 1
+
+        @rx.cached_var
+        def cached_x_side_effect(self) -> int:
+            self.handler()
+            nonlocal counter
+            counter += 1
+            return counter
+
+    if use_partial:
+        HandlerState.handler = functools.partial(HandlerState.handler.fn)
+        assert isinstance(HandlerState.handler, functools.partial)
+    else:
+        assert isinstance(HandlerState.handler, EventHandler)
+
+    s = HandlerState()
+    assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
+    assert s.cached_x_side_effect == 1
+    assert s.x == 43
+    s.handler()
+    assert s.cached_x_side_effect == 2
+    assert s.x == 45
+
+
 def test_backend_method():
     """A method with leading underscore should be callable from event handler."""