Browse Source

Explicit deps and interval for computed vars (#3231)

benedikt-bartscher 1 năm trước cách đây
mục cha
commit
93de407007
4 tập tin đã thay đổi với 321 bổ sung5 xóa
  1. 210 0
      integration/test_computed_vars.py
  2. 16 0
      reflex/state.py
  3. 79 5
      reflex/vars.py
  4. 16 0
      reflex/vars.pyi

+ 210 - 0
integration/test_computed_vars.py

@@ -0,0 +1,210 @@
+"""Test computed vars."""
+
+from __future__ import annotations
+
+import time
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
+
+
+def ComputedVars():
+    """Test app for computed vars."""
+    import reflex as rx
+
+    class State(rx.State):
+        count: int = 0
+
+        # cached var with dep on count
+        @rx.cached_var(interval=15)
+        def count1(self) -> int:
+            return self.count
+
+        # same as above, different notation
+        @rx.var(interval=15, cache=True)
+        def count2(self) -> int:
+            return self.count
+
+        # explicit disabled auto_deps
+        @rx.var(interval=15, cache=True, auto_deps=False)
+        def count3(self) -> int:
+            # this will not add deps, because auto_deps is False
+            print(self.count1)
+            print(self.count2)
+
+            return self.count
+
+        # explicit dependency on count1 var
+        @rx.cached_var(deps=[count1], auto_deps=False)
+        def depends_on_count1(self) -> int:
+            return self.count
+
+        @rx.var(deps=[count3], auto_deps=False, cache=True)
+        def depends_on_count3(self) -> int:
+            return self.count
+
+        def increment(self):
+            self.count += 1
+
+        def mark_dirty(self):
+            self._mark_dirty()
+
+    def index() -> rx.Component:
+        return rx.center(
+            rx.vstack(
+                rx.input(
+                    id="token",
+                    value=State.router.session.client_token,
+                    is_read_only=True,
+                ),
+                rx.button("Increment", on_click=State.increment, id="increment"),
+                rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
+                rx.text("count:"),
+                rx.text(State.count, id="count"),
+                rx.text("count1:"),
+                rx.text(State.count1, id="count1"),
+                rx.text("count2:"),
+                rx.text(State.count2, id="count2"),
+                rx.text("count3:"),
+                rx.text(State.count3, id="count3"),
+                rx.text("depends_on_count1:"),
+                rx.text(
+                    State.depends_on_count1,
+                    id="depends_on_count1",
+                ),
+                rx.text("depends_on_count3:"),
+                rx.text(
+                    State.depends_on_count3,
+                    id="depends_on_count3",
+                ),
+            ),
+        )
+
+    # raise Exception(State.count3._deps(objclass=State))
+    app = rx.App()
+    app.add_page(index)
+
+
+@pytest.fixture(scope="module")
+def computed_vars(
+    tmp_path_factory,
+) -> Generator[AppHarness, None, None]:
+    """Start ComputedVars app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp(f"computed_vars"),
+        app_source=ComputedVars,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the computed_vars app.
+
+    Args:
+        computed_vars: harness for ComputedVars app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert computed_vars.app_instance is not None, "app is not running"
+    driver = computed_vars.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+@pytest.fixture()
+def token(computed_vars: AppHarness, driver: WebDriver) -> str:
+    """Get a function that returns the active token.
+
+    Args:
+        computed_vars: harness for ComputedVars app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token for the connected client
+    """
+    assert computed_vars.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
+    assert token is not None
+
+    return token
+
+
+def test_computed_vars(
+    computed_vars: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that computed vars are working as expected.
+
+    Args:
+        computed_vars: harness for ComputedVars app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert computed_vars.app_instance is not None
+
+    count = driver.find_element(By.ID, "count")
+    assert count
+    assert count.text == "0"
+
+    count1 = driver.find_element(By.ID, "count1")
+    assert count1
+    assert count1.text == "0"
+
+    count2 = driver.find_element(By.ID, "count2")
+    assert count2
+    assert count2.text == "0"
+
+    count3 = driver.find_element(By.ID, "count3")
+    assert count3
+    assert count3.text == "0"
+
+    depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
+    assert depends_on_count1
+    assert depends_on_count1.text == "0"
+
+    depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
+    assert depends_on_count3
+    assert depends_on_count3.text == "0"
+
+    increment = driver.find_element(By.ID, "increment")
+    assert increment.is_enabled()
+
+    mark_dirty = driver.find_element(By.ID, "mark_dirty")
+    assert mark_dirty.is_enabled()
+
+    mark_dirty.click()
+
+    increment.click()
+    assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
+    assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
+    assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
+
+    mark_dirty.click()
+    with pytest.raises(TimeoutError):
+        computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
+
+    time.sleep(10)
+    assert count3.text == "0"
+    assert depends_on_count3.text == "0"
+    mark_dirty.click()
+    assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
+    assert depends_on_count3.text == "1"

+ 16 - 0
reflex/state.py

@@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 if actual_var is not None:
                 if actual_var is not None:
                     actual_var.mark_dirty(instance=self)
                     actual_var.mark_dirty(instance=self)
 
 
+    def _expired_computed_vars(self) -> set[str]:
+        """Determine ComputedVars that need to be recalculated based on the expiration time.
+
+        Returns:
+            Set of computed vars to include in the delta.
+        """
+        return set(
+            cvar
+            for cvar in self.computed_vars
+            if self.computed_vars[cvar].needs_update(instance=self)
+        )
+
     def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
     def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
         """Determine ComputedVars that need to be recalculated based on the given vars.
 
 
@@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # and always dirty computed vars (cache=False)
         # and always dirty computed vars (cache=False)
         delta_vars = (
         delta_vars = (
             self.dirty_vars.intersection(self.base_vars)
             self.dirty_vars.intersection(self.base_vars)
+            .union(self.dirty_vars.intersection(self.computed_vars))
             .union(self._dirty_computed_vars())
             .union(self._dirty_computed_vars())
             .union(self._always_dirty_computed_vars)
             .union(self._always_dirty_computed_vars)
         )
         )
@@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state._mark_dirty()
             self.parent_state._mark_dirty()
 
 
+        # Append expired computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._expired_computed_vars())
+
         # have to mark computed vars dirty to allow access to newly computed
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()
         self._mark_dirty_computed_vars()

+ 79 - 5
reflex/vars.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 
 import contextlib
 import contextlib
 import dataclasses
 import dataclasses
+import datetime
 import dis
 import dis
 import functools
 import functools
 import inspect
 import inspect
@@ -1873,13 +1874,26 @@ class ComputedVar(Var, property):
     # Whether to track dependencies and cache computed values
     # Whether to track dependencies and cache computed values
     _cache: bool = dataclasses.field(default=False)
     _cache: bool = dataclasses.field(default=False)
 
 
-    _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
+    # The initial value of the computed var
+    _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
+
+    # Explicit var dependencies to track
+    _static_deps: set[str] = dataclasses.field(default_factory=set)
+
+    # Whether var dependencies should be auto-determined
+    _auto_deps: bool = dataclasses.field(default=True)
+
+    # Interval at which the computed var should be updated
+    _update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
 
 
     def __init__(
     def __init__(
         self,
         self,
         fget: Callable[[BaseState], Any],
         fget: Callable[[BaseState], Any],
         initial_value: Any | types.Unset = types.Unset(),
         initial_value: Any | types.Unset = types.Unset(),
         cache: bool = False,
         cache: bool = False,
+        deps: Optional[List[Union[str, Var]]] = None,
+        auto_deps: bool = True,
+        interval: Optional[Union[int, datetime.timedelta]] = None,
         **kwargs,
         **kwargs,
     ):
     ):
         """Initialize a ComputedVar.
         """Initialize a ComputedVar.
@@ -1888,10 +1902,22 @@ class ComputedVar(Var, property):
             fget: The getter function.
             fget: The getter function.
             initial_value: The initial value of the computed var.
             initial_value: The initial value of the computed var.
             cache: Whether to cache the computed value.
             cache: Whether to cache the computed value.
+            deps: Explicit var dependencies to track.
+            auto_deps: Whether var dependencies should be auto-determined.
+            interval: Interval at which the computed var should be updated.
             **kwargs: additional attributes to set on the instance
             **kwargs: additional attributes to set on the instance
         """
         """
         self._initial_value = initial_value
         self._initial_value = initial_value
         self._cache = cache
         self._cache = cache
+        if isinstance(interval, int):
+            interval = datetime.timedelta(seconds=interval)
+        self._update_interval = interval
+        if deps is None:
+            deps = []
+        self._static_deps = {
+            dep._var_name if isinstance(dep, Var) else dep for dep in deps
+        }
+        self._auto_deps = auto_deps
         property.__init__(self, fget)
         property.__init__(self, fget)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
@@ -1912,6 +1938,9 @@ class ComputedVar(Var, property):
             fget=kwargs.get("fget", self.fget),
             fget=kwargs.get("fget", self.fget),
             initial_value=kwargs.get("initial_value", self._initial_value),
             initial_value=kwargs.get("initial_value", self._initial_value),
             cache=kwargs.get("cache", self._cache),
             cache=kwargs.get("cache", self._cache),
+            deps=kwargs.get("deps", self._static_deps),
+            auto_deps=kwargs.get("auto_deps", self._auto_deps),
+            interval=kwargs.get("interval", self._update_interval),
             _var_name=kwargs.get("_var_name", self._var_name),
             _var_name=kwargs.get("_var_name", self._var_name),
             _var_type=kwargs.get("_var_type", self._var_type),
             _var_type=kwargs.get("_var_type", self._var_type),
             _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
             _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@@ -1932,7 +1961,32 @@ class ComputedVar(Var, property):
         """
         """
         return f"__cached_{self._var_name}"
         return f"__cached_{self._var_name}"
 
 
-    def __get__(self, instance, owner):
+    @property
+    def _last_updated_attr(self) -> str:
+        """Get the attribute used to store the last updated timestamp.
+
+        Returns:
+            An attribute name.
+        """
+        return f"__last_updated_{self._var_name}"
+
+    def needs_update(self, instance: BaseState) -> bool:
+        """Check if the computed var needs to be updated.
+
+        Args:
+            instance: The state instance that the computed var is attached to.
+
+        Returns:
+            True if the computed var needs to be updated, False otherwise.
+        """
+        if self._update_interval is None:
+            return False
+        last_updated = getattr(instance, self._last_updated_attr, None)
+        if last_updated is None:
+            return True
+        return datetime.datetime.now() - last_updated > self._update_interval
+
+    def __get__(self, instance: BaseState | None, owner):
         """Get the ComputedVar value.
         """Get the ComputedVar value.
 
 
         If the value is already cached on the instance, return the cached value.
         If the value is already cached on the instance, return the cached value.
@@ -1948,10 +2002,13 @@ class ComputedVar(Var, property):
             return super().__get__(instance, owner)
             return super().__get__(instance, owner)
 
 
         # handle caching
         # handle caching
-        if not hasattr(instance, self._cache_attr):
+        if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
+            # Set cache attr on state instance.
             setattr(instance, self._cache_attr, super().__get__(instance, owner))
             setattr(instance, self._cache_attr, super().__get__(instance, owner))
             # Ensure the computed var gets serialized to redis.
             # Ensure the computed var gets serialized to redis.
             instance._was_touched = True
             instance._was_touched = True
+            # Set the last updated timestamp on the state instance.
+            setattr(instance, self._last_updated_attr, datetime.datetime.now())
         return getattr(instance, self._cache_attr)
         return getattr(instance, self._cache_attr)
 
 
     def _deps(
     def _deps(
@@ -1978,7 +2035,9 @@ class ComputedVar(Var, property):
             VarValueError: if the function references the get_state, parent_state, or substates attributes
             VarValueError: if the function references the get_state, parent_state, or substates attributes
                 (cannot track deps in a related state, only implicitly via parent state).
                 (cannot track deps in a related state, only implicitly via parent state).
         """
         """
-        d = set()
+        if not self._auto_deps:
+            return self._static_deps
+        d = self._static_deps.copy()
         if obj is None:
         if obj is None:
             fget = property.__getattribute__(self, "fget")
             fget = property.__getattribute__(self, "fget")
             if fget is not None:
             if fget is not None:
@@ -2076,6 +2135,9 @@ def computed_var(
     fget: Callable[[BaseState], Any] | None = None,
     fget: Callable[[BaseState], Any] | None = None,
     initial_value: Any | None = None,
     initial_value: Any | None = None,
     cache: bool = False,
     cache: bool = False,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
     **kwargs,
     **kwargs,
 ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
 ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
     """A ComputedVar decorator with or without kwargs.
     """A ComputedVar decorator with or without kwargs.
@@ -2084,19 +2146,31 @@ def computed_var(
         fget: The getter function.
         fget: The getter function.
         initial_value: The initial value of the computed var.
         initial_value: The initial value of the computed var.
         cache: Whether to cache the computed value.
         cache: Whether to cache the computed value.
+        deps: Explicit var dependencies to track.
+        auto_deps: Whether var dependencies should be auto-determined.
+        interval: Interval at which the computed var should be updated.
         **kwargs: additional attributes to set on the instance
         **kwargs: additional attributes to set on the instance
 
 
     Returns:
     Returns:
         A ComputedVar instance.
         A ComputedVar instance.
+
+    Raises:
+        ValueError: If caching is disabled and an update interval is set.
     """
     """
+    if cache is False and interval is not None:
+        raise ValueError("Cannot set update interval without caching.")
+
     if fget is not None:
     if fget is not None:
         return ComputedVar(fget=fget, cache=cache)
         return ComputedVar(fget=fget, cache=cache)
 
 
-    def wrapper(fget):
+    def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
         return ComputedVar(
         return ComputedVar(
             fget=fget,
             fget=fget,
             initial_value=initial_value,
             initial_value=initial_value,
             cache=cache,
             cache=cache,
+            deps=deps,
+            auto_deps=auto_deps,
+            interval=interval,
             **kwargs,
             **kwargs,
         )
         )
 
 

+ 16 - 0
reflex/vars.pyi

@@ -2,6 +2,7 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import datetime
 from dataclasses import dataclass
 from dataclasses import dataclass
 from _typeshed import Incomplete
 from _typeshed import Incomplete
 from reflex import constants as constants
 from reflex import constants as constants
@@ -141,6 +142,7 @@ class ComputedVar(Var):
     def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
     def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
     def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
     def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
     def mark_dirty(self, instance) -> None: ...
     def mark_dirty(self, instance) -> None: ...
+    def needs_update(self, instance) -> bool: ...
     def _determine_var_type(self) -> Type: ...
     def _determine_var_type(self) -> Type: ...
     @overload
     @overload
     def __init__(
     def __init__(
@@ -155,10 +157,24 @@ class ComputedVar(Var):
 def computed_var(
 def computed_var(
     fget: Callable[[BaseState], Any] | None = None,
     fget: Callable[[BaseState], Any] | None = None,
     initial_value: Any | None = None,
     initial_value: Any | None = None,
+    cache: bool = False,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
     **kwargs,
     **kwargs,
 ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
 ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
 @overload
 @overload
 def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
+@overload
+def cached_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
+    **kwargs,
+) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
+@overload
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
 
 class CallableVar(BaseVar):
 class CallableVar(BaseVar):