Pārlūkot izejas kodu

add computed backend vars (#3573)

* add computed backend vars

* finish computed backend vars, add tests

* fix token for AppHarness with redis state manager

* fix timing issues

* add unit tests for computed backend vars

* automagically mark cvs with _ prefix as backend var

* fully migrate backend computed vars

* rename is_backend_variable to is_backend_base_variable

* add integration test for implicit backend cv, adjust comments

* replace expensive backend var check at runtime

* keep stuff together

* simplify backend var check method, consistent naming, improve test typing

* fix: do not convert properties to cvs

* add test for property

* fix cached_properties with _ prefix in state cls
benedikt-bartscher 11 mēneši atpakaļ
vecāks
revīzija
b7651e214b

+ 36 - 1
integration/test_computed_vars.py

@@ -26,6 +26,16 @@ def ComputedVars():
         def count1(self) -> int:
         def count1(self) -> int:
             return self.count
             return self.count
 
 
+        # cached backend var with dep on count
+        @rx.var(cache=True, interval=15, backend=True)
+        def count1_backend(self) -> int:
+            return self.count
+
+        # same as above but implicit backend with `_` prefix
+        @rx.var(cache=True, interval=15)
+        def _count1_backend(self) -> int:
+            return self.count
+
         # explicit disabled auto_deps
         # explicit disabled auto_deps
         @rx.var(interval=15, cache=True, auto_deps=False)
         @rx.var(interval=15, cache=True, auto_deps=False)
         def count3(self) -> int:
         def count3(self) -> int:
@@ -70,6 +80,10 @@ def ComputedVars():
                 rx.text(State.count, id="count"),
                 rx.text(State.count, id="count"),
                 rx.text("count1:"),
                 rx.text("count1:"),
                 rx.text(State.count1, id="count1"),
                 rx.text(State.count1, id="count1"),
+                rx.text("count1_backend:"),
+                rx.text(State.count1_backend, id="count1_backend"),
+                rx.text("_count1_backend:"),
+                rx.text(State._count1_backend, id="_count1_backend"),
                 rx.text("count3:"),
                 rx.text("count3:"),
                 rx.text(State.count3, id="count3"),
                 rx.text(State.count3, id="count3"),
                 rx.text("depends_on_count:"),
                 rx.text("depends_on_count:"),
@@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str:
     return token
     return token
 
 
 
 
-def test_computed_vars(
+@pytest.mark.asyncio
+async def test_computed_vars(
     computed_vars: AppHarness,
     computed_vars: AppHarness,
     driver: WebDriver,
     driver: WebDriver,
     token: str,
     token: str,
@@ -168,6 +183,20 @@ def test_computed_vars(
     """
     """
     assert computed_vars.app_instance is not None
     assert computed_vars.app_instance is not None
 
 
+    token = f"{token}_state.state"
+    state = (await computed_vars.get_state(token)).substates["state"]
+    assert state is not None
+    assert state.count1_backend == 0
+    assert state._count1_backend == 0
+
+    # test that backend var is not rendered
+    count1_backend = driver.find_element(By.ID, "count1_backend")
+    assert count1_backend
+    assert count1_backend.text == ""
+    _count1_backend = driver.find_element(By.ID, "_count1_backend")
+    assert _count1_backend
+    assert _count1_backend.text == ""
+
     count = driver.find_element(By.ID, "count")
     count = driver.find_element(By.ID, "count")
     assert count
     assert count
     assert count.text == "0"
     assert count.text == "0"
@@ -207,6 +236,12 @@ def test_computed_vars(
         computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
         computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
         == "1"
         == "1"
     )
     )
+    state = (await computed_vars.get_state(token)).substates["state"]
+    assert state is not None
+    assert state.count1_backend == 1
+    assert count1_backend.text == ""
+    assert state._count1_backend == 1
+    assert _count1_backend.text == ""
 
 
     mark_dirty.click()
     mark_dirty.click()
     with pytest.raises(TimeoutError):
     with pytest.raises(TimeoutError):

+ 22 - 20
reflex/state.py

@@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Vars inherited by the parent state.
     # Vars inherited by the parent state.
     inherited_vars: ClassVar[Dict[str, Var]] = {}
     inherited_vars: ClassVar[Dict[str, Var]] = {}
 
 
-    # Backend vars that are never sent to the client.
+    # Backend base vars that are never sent to the client.
     backend_vars: ClassVar[Dict[str, Any]] = {}
     backend_vars: ClassVar[Dict[str, Any]] = {}
 
 
-    # Backend vars inherited
+    # Backend base vars inherited
     inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
     inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
 
 
     # The event handlers.
     # The event handlers.
@@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # The routing path that triggered the state
     # The routing path that triggered the state
     router_data: Dict[str, Any] = {}
     router_data: Dict[str, Any] = {}
 
 
-    # Per-instance copy of backend variable values
+    # Per-instance copy of backend base variable values
     _backend_vars: Dict[str, Any] = {}
     _backend_vars: Dict[str, Any] = {}
 
 
     # The router data for the current page
     # The router data for the current page
@@ -492,21 +492,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         new_backend_vars = {
         new_backend_vars = {
             name: value
             name: value
             for name, value in cls.__dict__.items()
             for name, value in cls.__dict__.items()
-            if types.is_backend_variable(name, cls)
-        }
-
-        # Get backend computed vars
-        backend_computed_vars = {
-            v._var_name: v._var_set_state(cls)
-            for v in computed_vars
-            if types.is_backend_variable(v._var_name, cls)
-            and v._var_name not in cls.inherited_backend_vars
+            if types.is_backend_base_variable(name, cls)
         }
         }
 
 
         cls.backend_vars = {
         cls.backend_vars = {
             **cls.inherited_backend_vars,
             **cls.inherited_backend_vars,
             **new_backend_vars,
             **new_backend_vars,
-            **backend_computed_vars,
         }
         }
 
 
         # Set the base and computed vars.
         # Set the base and computed vars.
@@ -548,7 +539,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._var_name] = newcv
                     cls.computed_vars[newcv._var_name] = newcv
                     cls.vars[newcv._var_name] = newcv
                     cls.vars[newcv._var_name] = newcv
                     continue
                     continue
-                if types.is_backend_variable(name, mixin):
+                if types.is_backend_base_variable(name, mixin):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                     continue
                 if events.get(name) is not None:
                 if events.get(name) is not None:
@@ -1087,7 +1078,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             setattr(self.parent_state, name, value)
             setattr(self.parent_state, name, value)
             return
             return
 
 
-        if types.is_backend_variable(name, type(self)):
+        if name in self.backend_vars:
             self._backend_vars.__setitem__(name, value)
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
             self.dirty_vars.add(name)
             self._mark_dirty()
             self._mark_dirty()
@@ -1538,11 +1529,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if self.computed_vars[cvar].needs_update(instance=self)
             if self.computed_vars[cvar].needs_update(instance=self)
         )
         )
 
 
-    def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
+    def _dirty_computed_vars(
+        self, from_vars: set[str] | None = None, include_backend: bool = True
+    ) -> set[str]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
         """Determine ComputedVars that need to be recalculated based on the given vars.
 
 
         Args:
         Args:
             from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
             from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
+            include_backend: whether to include backend vars in the calculation.
 
 
         Returns:
         Returns:
             Set of computed vars to include in the delta.
             Set of computed vars to include in the delta.
@@ -1551,6 +1545,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             cvar
             cvar
             for dirty_var in from_vars or self.dirty_vars
             for dirty_var in from_vars or self.dirty_vars
             for cvar in self._computed_var_dependencies[dirty_var]
             for cvar in self._computed_var_dependencies[dirty_var]
+            if include_backend or not self.computed_vars[cvar]._backend
         )
         )
 
 
     @classmethod
     @classmethod
@@ -1586,19 +1581,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         self.dirty_vars.update(self._always_dirty_computed_vars)
         self.dirty_vars.update(self._always_dirty_computed_vars)
         self._mark_dirty()
         self._mark_dirty()
 
 
+        frontend_computed_vars: set[str] = {
+            name for name, cv in self.computed_vars.items() if not cv._backend
+        }
+
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # and always dirty computed vars (cache=False)
         # and always dirty computed vars (cache=False)
         delta_vars = (
         delta_vars = (
             self.dirty_vars.intersection(self.base_vars)
             self.dirty_vars.intersection(self.base_vars)
-            .union(self.dirty_vars.intersection(self.computed_vars))
-            .union(self._dirty_computed_vars())
+            .union(self.dirty_vars.intersection(frontend_computed_vars))
+            .union(self._dirty_computed_vars(include_backend=False))
             .union(self._always_dirty_computed_vars)
             .union(self._always_dirty_computed_vars)
         )
         )
 
 
         subdelta = {
         subdelta = {
             prop: getattr(self, prop)
             prop: getattr(self, prop)
             for prop in delta_vars
             for prop in delta_vars
-            if not types.is_backend_variable(prop, type(self))
+            if not types.is_backend_base_variable(prop, type(self))
         }
         }
         if len(subdelta) > 0:
         if len(subdelta) > 0:
             delta[self.get_full_name()] = subdelta
             delta[self.get_full_name()] = subdelta
@@ -1727,12 +1726,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     else self.get_value(getattr(self, prop_name))
                     else self.get_value(getattr(self, prop_name))
                 )
                 )
                 for prop_name, cv in self.computed_vars.items()
                 for prop_name, cv in self.computed_vars.items()
+                if not cv._backend
             }
             }
         elif include_computed:
         elif include_computed:
             computed_vars = {
             computed_vars = {
                 # Include the computed vars.
                 # Include the computed vars.
                 prop_name: self.get_value(getattr(self, prop_name))
                 prop_name: self.get_value(getattr(self, prop_name))
-                for prop_name in self.computed_vars
+                for prop_name, cv in self.computed_vars.items()
+                if not cv._backend
             }
             }
         else:
         else:
             computed_vars = {}
             computed_vars = {}
@@ -1745,6 +1746,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for v in self.substates.values()
             for v in self.substates.values()
         ]:
         ]:
             d.update(substate_d)
             d.update(substate_d)
+
         return d
         return d
 
 
     async def __aenter__(self) -> BaseState:
     async def __aenter__(self) -> BaseState:

+ 22 - 23
reflex/utils/types.py

@@ -6,7 +6,7 @@ import contextlib
 import inspect
 import inspect
 import sys
 import sys
 import types
 import types
-from functools import wraps
+from functools import cached_property, wraps
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
@@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool:
     return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
     return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
 
 
 
 
-def is_backend_variable(name: str, cls: Type | None = None) -> bool:
+def is_backend_base_variable(name: str, cls: Type) -> bool:
     """Check if this variable name correspond to a backend variable.
     """Check if this variable name correspond to a backend variable.
 
 
     Args:
     Args:
@@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
     if name.startswith("__"):
     if name.startswith("__"):
         return False
         return False
 
 
-    if cls is not None:
-        if name.startswith(f"_{cls.__name__}__"):
+    if name.startswith(f"_{cls.__name__}__"):
+        return False
+
+    hints = get_type_hints(cls)
+    if name in hints:
+        hint = get_origin(hints[name])
+        if hint == ClassVar:
             return False
             return False
-        hints = get_type_hints(cls)
-        if name in hints:
-            hint = get_origin(hints[name])
-            if hint == ClassVar:
-                return False
 
 
-        if name in cls.inherited_backend_vars:
+    if name in cls.inherited_backend_vars:
+        return False
+
+    if name in cls.__dict__:
+        value = cls.__dict__[name]
+        if type(value) == classmethod:
             return False
             return False
+        if callable(value):
+            return False
+        from reflex.vars import ComputedVar
 
 
-        if name in cls.__dict__:
-            value = cls.__dict__[name]
-            if type(value) == classmethod:
-                return False
-            if callable(value):
-                return False
-            if isinstance(value, types.FunctionType):
-                return False
-            # enable after #3573 is merged
-            # from reflex.vars import ComputedVar
-            #
-            # if isinstance(value, ComputedVar):
-            #     return False
+        if isinstance(
+            value, (types.FunctionType, property, cached_property, ComputedVar)
+        ):
+            return False
 
 
     return True
     return True
 
 

+ 13 - 0
reflex/vars.py

@@ -1944,6 +1944,9 @@ class ComputedVar(Var, property):
     # Whether to track dependencies and cache computed values
     # Whether to track dependencies and cache computed values
     _cache: bool = dataclasses.field(default=False)
     _cache: bool = dataclasses.field(default=False)
 
 
+    # Whether the computed var is a backend var
+    _backend: bool = dataclasses.field(default=False)
+
     # The initial value of the computed var
     # The initial value of the computed var
     _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
     _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
 
 
@@ -1964,6 +1967,7 @@ class ComputedVar(Var, property):
         deps: Optional[List[Union[str, Var]]] = None,
         deps: Optional[List[Union[str, Var]]] = None,
         auto_deps: bool = True,
         auto_deps: bool = True,
         interval: Optional[Union[int, datetime.timedelta]] = None,
         interval: Optional[Union[int, datetime.timedelta]] = None,
+        backend: bool | None = None,
         **kwargs,
         **kwargs,
     ):
     ):
         """Initialize a ComputedVar.
         """Initialize a ComputedVar.
@@ -1975,11 +1979,16 @@ class ComputedVar(Var, property):
             deps: Explicit var dependencies to track.
             deps: Explicit var dependencies to track.
             auto_deps: Whether var dependencies should be auto-determined.
             auto_deps: Whether var dependencies should be auto-determined.
             interval: Interval at which the computed var should be updated.
             interval: Interval at which the computed var should be updated.
+            backend: Whether the computed var is a backend var.
             **kwargs: additional attributes to set on the instance
             **kwargs: additional attributes to set on the instance
 
 
         Raises:
         Raises:
             TypeError: If the computed var dependencies are not Var instances or var names.
             TypeError: If the computed var dependencies are not Var instances or var names.
         """
         """
+        if backend is None:
+            backend = fget.__name__.startswith("_")
+        self._backend = backend
+
         self._initial_value = initial_value
         self._initial_value = initial_value
         self._cache = cache
         self._cache = cache
         if isinstance(interval, int):
         if isinstance(interval, int):
@@ -2023,6 +2032,7 @@ class ComputedVar(Var, property):
             deps=kwargs.get("deps", self._static_deps),
             deps=kwargs.get("deps", self._static_deps),
             auto_deps=kwargs.get("auto_deps", self._auto_deps),
             auto_deps=kwargs.get("auto_deps", self._auto_deps),
             interval=kwargs.get("interval", self._update_interval),
             interval=kwargs.get("interval", self._update_interval),
+            backend=kwargs.get("backend", self._backend),
             _var_name=kwargs.get("_var_name", self._var_name),
             _var_name=kwargs.get("_var_name", self._var_name),
             _var_type=kwargs.get("_var_type", self._var_type),
             _var_type=kwargs.get("_var_type", self._var_type),
             _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
             _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@@ -2233,6 +2243,7 @@ def computed_var(
     deps: Optional[List[Union[str, Var]]] = None,
     deps: Optional[List[Union[str, Var]]] = None,
     auto_deps: bool = True,
     auto_deps: bool = True,
     interval: Optional[Union[datetime.timedelta, int]] = None,
     interval: Optional[Union[datetime.timedelta, int]] = None,
+    backend: bool | None = None,
     _deprecated_cached_var: bool = False,
     _deprecated_cached_var: bool = False,
     **kwargs,
     **kwargs,
 ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
 ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
@@ -2245,6 +2256,7 @@ def computed_var(
         deps: Explicit var dependencies to track.
         deps: Explicit var dependencies to track.
         auto_deps: Whether var dependencies should be auto-determined.
         auto_deps: Whether var dependencies should be auto-determined.
         interval: Interval at which the computed var should be updated.
         interval: Interval at which the computed var should be updated.
+        backend: Whether the computed var is a backend var.
         _deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
         _deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
         **kwargs: additional attributes to set on the instance
         **kwargs: additional attributes to set on the instance
 
 
@@ -2280,6 +2292,7 @@ def computed_var(
             deps=deps,
             deps=deps,
             auto_deps=auto_deps,
             auto_deps=auto_deps,
             interval=interval,
             interval=interval,
+            backend=backend,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 24 - 2
tests/test_state.py

@@ -925,6 +925,15 @@ class InterdependentState(BaseState):
         """
         """
         return self._v2 * 2
         return self._v2 * 2
 
 
+    @rx.var(cache=True, backend=True)
+    def v2x2_backend(self) -> int:
+        """Depends on backend var _v2.
+
+        Returns:
+            backend var _v2 multiplied by 2
+        """
+        return self._v2 * 2
+
     @rx.var(cache=True)
     @rx.var(cache=True)
     def v1x2x2(self) -> int:
     def v1x2x2(self) -> int:
         """Depends on ComputedVar v1x2.
         """Depends on ComputedVar v1x2.
@@ -1002,11 +1011,11 @@ def test_dirty_computed_var_from_backend_var(
     Args:
     Args:
         interdependent_state: A state with varying Var dependencies.
         interdependent_state: A state with varying Var dependencies.
     """
     """
+    assert InterdependentState._v3._backend is True
     interdependent_state._v2 = 2
     interdependent_state._v2 = 2
     assert interdependent_state.get_delta() == {
     assert interdependent_state.get_delta() == {
         interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
         interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
     }
     }
-    assert "_v3" in InterdependentState.backend_vars
 
 
 
 
 def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
 def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
@@ -1295,6 +1304,15 @@ def test_computed_var_dependencies():
             """
             """
             return self.v
             return self.v
 
 
+        @rx.var(cache=True, backend=True)
+        def comp_v_backend(self) -> int:
+            """Direct access backend var.
+
+            Returns:
+                The value of self.v.
+            """
+            return self.v
+
         @rx.var(cache=True)
         @rx.var(cache=True)
         def comp_v_via_property(self) -> int:
         def comp_v_via_property(self) -> int:
             """Access v via property.
             """Access v via property.
@@ -1345,7 +1363,11 @@ def test_computed_var_dependencies():
             return [z in self._z for z in range(5)]
             return [z in self._z for z in range(5)]
 
 
     cs = ComputedState()
     cs = ComputedState()
-    assert cs._computed_var_dependencies["v"] == {"comp_v", "comp_v_via_property"}
+    assert cs._computed_var_dependencies["v"] == {
+        "comp_v",
+        "comp_v_backend",
+        "comp_v_via_property",
+    }
     assert cs._computed_var_dependencies["w"] == {"comp_w"}
     assert cs._computed_var_dependencies["w"] == {"comp_w"}
     assert cs._computed_var_dependencies["x"] == {"comp_x"}
     assert cs._computed_var_dependencies["x"] == {"comp_x"}
     assert cs._computed_var_dependencies["y"] == {"comp_y"}
     assert cs._computed_var_dependencies["y"] == {"comp_y"}

+ 16 - 3
tests/utils/test_utils.py

@@ -1,7 +1,8 @@
 import os
 import os
 import typing
 import typing
+from functools import cached_property
 from pathlib import Path
 from pathlib import Path
-from typing import Any, ClassVar, List, Literal, Union
+from typing import Any, ClassVar, List, Literal, Type, Union
 
 
 import pytest
 import pytest
 import typer
 import typer
@@ -161,6 +162,14 @@ def test_backend_variable_cls():
         def _hidden_method(self):
         def _hidden_method(self):
             pass
             pass
 
 
+        @property
+        def _hidden_property(self):
+            pass
+
+        @cached_property
+        def _cached_hidden_property(self):
+            pass
+
     return TestBackendVariable
     return TestBackendVariable
 
 
 
 
@@ -173,10 +182,14 @@ def test_backend_variable_cls():
         ("_hidden", True),
         ("_hidden", True),
         ("not_hidden", False),
         ("not_hidden", False),
         ("__dundermethod__", False),
         ("__dundermethod__", False),
+        ("_hidden_property", False),
+        ("_cached_hidden_property", False),
     ],
     ],
 )
 )
-def test_is_backend_variable(test_backend_variable_cls, input, output):
-    assert types.is_backend_variable(input, test_backend_variable_cls) == output
+def test_is_backend_base_variable(
+    test_backend_variable_cls: Type[BaseState], input: str, output: bool
+):
+    assert types.is_backend_base_variable(input, test_backend_variable_cls) == output
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(