瀏覽代碼

Track var dependencies in comprehensions and nested functions (#1728)

Masen Furer 1 年之前
父節點
當前提交
b44c2176e0
共有 2 個文件被更改,包括 103 次插入8 次删除
  1. 34 8
      reflex/vars.py
  2. 69 0
      tests/test_state.py

+ 34 - 8
reflex/vars.py

@@ -7,7 +7,7 @@ import json
 import random
 import string
 from abc import ABC
-from types import FunctionType
+from types import CodeType, FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -983,16 +983,19 @@ class ComputedVar(Var, property):
     def deps(
         self,
         objclass: Type,
-        obj: FunctionType | None = None,
+        obj: FunctionType | CodeType | None = None,
+        self_name: Optional[str] = None,
     ) -> set[str]:
         """Determine var dependencies of this ComputedVar.
 
         Save references to attributes accessed on "self".  Recursively called
-        when the function makes a method call on "self".
+        when the function makes a method call on "self" or define comprehensions
+        or nested functions that may reference "self".
 
         Args:
             objclass: the class obj this ComputedVar is attached to.
             obj: the object to disassemble (defaults to the fget function).
+            self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
 
         Returns:
             A set of variable names accessed by the given obj.
@@ -1010,25 +1013,48 @@ class ComputedVar(Var, property):
             # unbox EventHandler
             obj = cast(FunctionType, obj.fn)  # type: ignore
 
-        try:
-            self_name = obj.__code__.co_varnames[0]
-        except (AttributeError, IndexError):
-            # cannot reference self if method takes no args
+        if self_name is None and isinstance(obj, FunctionType):
+            try:
+                # the first argument to the function is the name of "self" arg
+                self_name = obj.__code__.co_varnames[0]
+            except (AttributeError, IndexError):
+                self_name = None
+        if self_name is None:
+            # cannot reference attributes on self if method takes no args
             return set()
         self_is_top_of_stack = False
         for instruction in dis.get_instructions(obj):
-            if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
+            if (
+                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
+                and instruction.argval == self_name
+            ):
+                # bytecode loaded the class instance to the top of stack, next load instruction
+                # is referencing an attribute on self
                 self_is_top_of_stack = True
                 continue
             if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
+                # direct attribute access
                 d.add(instruction.argval)
             elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
+                # method call on self
                 d.update(
                     self.deps(
                         objclass=objclass,
                         obj=getattr(objclass, instruction.argval),
                     )
                 )
+            elif instruction.opname == "LOAD_CONST" and isinstance(
+                instruction.argval, CodeType
+            ):
+                # recurse into nested functions / comprehensions, which can reference
+                # instance attributes from the outer scope
+                d.update(
+                    self.deps(
+                        objclass=objclass,
+                        obj=instruction.argval,
+                        self_name=self_name,
+                    )
+                )
             self_is_top_of_stack = False
         return d
 

+ 69 - 0
tests/test_state.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 from typing import Dict, List
 
@@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     assert s.x == 45
 
 
+def test_computed_var_dependencies():
+    """Test that a ComputedVar correctly tracks its dependencies."""
+
+    class ComputedState(State):
+        v: int = 0
+        w: int = 0
+        x: int = 0
+        y: List[int] = [1, 2, 3]
+        _z: List[int] = [1, 2, 3]
+
+        @rx.cached_var
+        def comp_v(self) -> int:
+            """Direct access.
+
+            Returns:
+                The value of self.v.
+            """
+            return self.v
+
+        @rx.cached_var
+        def comp_w(self):
+            """Nested lambda.
+
+            Returns:
+                A lambda that returns the value of self.w.
+            """
+            return lambda: self.w
+
+        @rx.cached_var
+        def comp_x(self):
+            """Nested function.
+
+            Returns:
+                A function that returns the value of self.x.
+            """
+
+            def _():
+                return self.x
+
+            return _
+
+        @rx.cached_var
+        def comp_y(self) -> List[int]:
+            """Comprehension iterating over attribute.
+
+            Returns:
+                A list of the values of self.y.
+            """
+            return [round(y) for y in self.y]
+
+        @rx.cached_var
+        def comp_z(self) -> List[bool]:
+            """Comprehension accesses attribute.
+
+            Returns:
+                A list of whether the values 0-4 are in self._z.
+            """
+            return [z in self._z for z in range(5)]
+
+    cs = ComputedState()
+    assert cs.computed_var_dependencies["v"] == {"comp_v"}
+    assert cs.computed_var_dependencies["w"] == {"comp_w"}
+    assert cs.computed_var_dependencies["x"] == {"comp_x"}
+    assert cs.computed_var_dependencies["y"] == {"comp_y"}
+    assert cs.computed_var_dependencies["_z"] == {"comp_z"}
+
+
 def test_backend_method():
     """A method with leading underscore should be callable from event handler."""