瀏覽代碼

DRAFT PR - Added code for computed backend vars (#2540)

* added code for computed backend vars

* fixed formatting issues

* fix small bug

* fixes ruff issue

* fixed black issue

* augment test for backend computed var

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
wassaf shahzad 1 年之前
父節點
當前提交
0a18eaa28b
共有 2 個文件被更改,包括 60 次插入8 次删除
  1. 40 7
      reflex/state.py
  2. 20 1
      tests/test_state.py

+ 40 - 7
reflex/state.py

@@ -294,7 +294,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         self._init_event_handlers()
 
         # Create a fresh copy of the backend variables for this instance
-        self._backend_vars = copy.deepcopy(self.backend_vars)
+        self._backend_vars = copy.deepcopy(
+            {
+                name: item
+                for name, item in self.backend_vars.items()
+                if name not in self.computed_vars
+            }
+        )
 
     def _init_event_handlers(self, state: BaseState | None = None):
         """Initialize event handlers.
@@ -330,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         return f"{self.__class__.__name__}({self.dict()})"
 
+    @classmethod
+    def _get_computed_vars(cls) -> list[ComputedVar]:
+        """Helper function to get all computed vars of a instance.
+
+        Returns:
+            A list of computed vars.
+        """
+        return [
+            v
+            for mixin in cls.__mro__
+            if mixin is cls or not issubclass(mixin, (BaseState, ABC))
+            for v in mixin.__dict__.values()
+            if isinstance(v, ComputedVar)
+        ]
+
     @classmethod
     def __init_subclass__(cls, **kwargs):
         """Do some magic for the subclass initialization.
@@ -376,6 +397,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             # Track this new subclass in the parent state's subclasses set.
             parent_state.class_subclasses.add(cls)
 
+        # Get computed vars.
+        computed_vars = cls._get_computed_vars()
+
         new_backend_vars = {
             name: value
             for name, value in cls.__dict__.items()
@@ -383,9 +407,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             and name not in RESERVED_BACKEND_VAR_NAMES
             and name not in cls.inherited_backend_vars
             and not isinstance(value, FunctionType)
+            and not isinstance(value, ComputedVar)
         }
 
-        cls.backend_vars = {**cls.inherited_backend_vars, **new_backend_vars}
+        # Get backend computed vars
+        backend_computed_vars = {
+            v._var_name: v._var_set_state(cls)
+            for v in computed_vars
+            if types.is_backend_variable(v._var_name, cls)
+            and v._var_name not in cls.inherited_backend_vars
+        }
+
+        cls.backend_vars = {
+            **cls.inherited_backend_vars,
+            **new_backend_vars,
+            **backend_computed_vars,
+        }
 
         # Set the base and computed vars.
         cls.base_vars = {
@@ -395,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for f in cls.get_fields().values()
             if f.name not in cls.get_skip_vars()
         }
-        cls.computed_vars = {
-            v._var_name: v._var_set_state(cls)
-            for v in cls.__dict__.values()
-            if isinstance(v, ComputedVar)
-        }
+        cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
         cls.vars = {
             **cls.inherited_vars,
             **cls.base_vars,

+ 20 - 1
tests/test_state.py

@@ -955,6 +955,24 @@ class InterdependentState(BaseState):
         """
         return self.v1x2 * 2  # type: ignore
 
+    @rx.cached_var
+    def _v3(self) -> int:
+        """Depends on backend var _v2.
+
+        Returns:
+            The value of the backend variable.
+        """
+        return self._v2
+
+    @rx.cached_var
+    def v3x2(self) -> int:
+        """Depends on ComputedVar _v3.
+
+        Returns:
+            ComputedVar _v3 multiplied by 2
+        """
+        return self._v3 * 2
+
 
 @pytest.fixture
 def interdependent_state() -> BaseState:
@@ -1003,8 +1021,9 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
     """
     interdependent_state._v2 = 2
     assert interdependent_state.get_delta() == {
-        interdependent_state.get_full_name(): {"v2x2": 4},
+        interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
     }
+    assert "_v3" in InterdependentState.backend_vars
 
 
 def test_per_state_backend_var(interdependent_state):