Sfoglia il codice sorgente

Explicit deps and interval for computed vars (#3231)

benedikt-bartscher 1 anno fa
parent
commit
93de407007
4 ha cambiato i file con 321 aggiunte e 5 eliminazioni
  1. 210 0
      integration/test_computed_vars.py
  2. 16 0
      reflex/state.py
  3. 79 5
      reflex/vars.py
  4. 16 0
      reflex/vars.pyi

+ 210 - 0
integration/test_computed_vars.py

@@ -0,0 +1,210 @@
+"""Test computed vars."""
+
+from __future__ import annotations
+
+import time
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
+
+
+def ComputedVars():
+    """Test app for computed vars."""
+    import reflex as rx
+
+    class State(rx.State):
+        count: int = 0
+
+        # cached var with dep on count
+        @rx.cached_var(interval=15)
+        def count1(self) -> int:
+            return self.count
+
+        # same as above, different notation
+        @rx.var(interval=15, cache=True)
+        def count2(self) -> int:
+            return self.count
+
+        # explicit disabled auto_deps
+        @rx.var(interval=15, cache=True, auto_deps=False)
+        def count3(self) -> int:
+            # this will not add deps, because auto_deps is False
+            print(self.count1)
+            print(self.count2)
+
+            return self.count
+
+        # explicit dependency on count1 var
+        @rx.cached_var(deps=[count1], auto_deps=False)
+        def depends_on_count1(self) -> int:
+            return self.count
+
+        @rx.var(deps=[count3], auto_deps=False, cache=True)
+        def depends_on_count3(self) -> int:
+            return self.count
+
+        def increment(self):
+            self.count += 1
+
+        def mark_dirty(self):
+            self._mark_dirty()
+
+    def index() -> rx.Component:
+        return rx.center(
+            rx.vstack(
+                rx.input(
+                    id="token",
+                    value=State.router.session.client_token,
+                    is_read_only=True,
+                ),
+                rx.button("Increment", on_click=State.increment, id="increment"),
+                rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
+                rx.text("count:"),
+                rx.text(State.count, id="count"),
+                rx.text("count1:"),
+                rx.text(State.count1, id="count1"),
+                rx.text("count2:"),
+                rx.text(State.count2, id="count2"),
+                rx.text("count3:"),
+                rx.text(State.count3, id="count3"),
+                rx.text("depends_on_count1:"),
+                rx.text(
+                    State.depends_on_count1,
+                    id="depends_on_count1",
+                ),
+                rx.text("depends_on_count3:"),
+                rx.text(
+                    State.depends_on_count3,
+                    id="depends_on_count3",
+                ),
+            ),
+        )
+
+    # raise Exception(State.count3._deps(objclass=State))
+    app = rx.App()
+    app.add_page(index)
+
+
+@pytest.fixture(scope="module")
+def computed_vars(
+    tmp_path_factory,
+) -> Generator[AppHarness, None, None]:
+    """Start ComputedVars app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp(f"computed_vars"),
+        app_source=ComputedVars,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the computed_vars app.
+
+    Args:
+        computed_vars: harness for ComputedVars app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert computed_vars.app_instance is not None, "app is not running"
+    driver = computed_vars.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+@pytest.fixture()
+def token(computed_vars: AppHarness, driver: WebDriver) -> str:
+    """Get a function that returns the active token.
+
+    Args:
+        computed_vars: harness for ComputedVars app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token for the connected client
+    """
+    assert computed_vars.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
+    assert token is not None
+
+    return token
+
+
+def test_computed_vars(
+    computed_vars: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that computed vars are working as expected.
+
+    Args:
+        computed_vars: harness for ComputedVars app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert computed_vars.app_instance is not None
+
+    count = driver.find_element(By.ID, "count")
+    assert count
+    assert count.text == "0"
+
+    count1 = driver.find_element(By.ID, "count1")
+    assert count1
+    assert count1.text == "0"
+
+    count2 = driver.find_element(By.ID, "count2")
+    assert count2
+    assert count2.text == "0"
+
+    count3 = driver.find_element(By.ID, "count3")
+    assert count3
+    assert count3.text == "0"
+
+    depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
+    assert depends_on_count1
+    assert depends_on_count1.text == "0"
+
+    depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
+    assert depends_on_count3
+    assert depends_on_count3.text == "0"
+
+    increment = driver.find_element(By.ID, "increment")
+    assert increment.is_enabled()
+
+    mark_dirty = driver.find_element(By.ID, "mark_dirty")
+    assert mark_dirty.is_enabled()
+
+    mark_dirty.click()
+
+    increment.click()
+    assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
+    assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
+    assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
+
+    mark_dirty.click()
+    with pytest.raises(TimeoutError):
+        computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
+
+    time.sleep(10)
+    assert count3.text == "0"
+    assert depends_on_count3.text == "0"
+    mark_dirty.click()
+    assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
+    assert depends_on_count3.text == "1"

+ 16 - 0
reflex/state.py

@@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 if actual_var is not None:
                     actual_var.mark_dirty(instance=self)
 
+    def _expired_computed_vars(self) -> set[str]:
+        """Determine ComputedVars that need to be recalculated based on the expiration time.
+
+        Returns:
+            Set of computed vars to include in the delta.
+        """
+        return set(
+            cvar
+            for cvar in self.computed_vars
+            if self.computed_vars[cvar].needs_update(instance=self)
+        )
+
     def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
 
@@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # and always dirty computed vars (cache=False)
         delta_vars = (
             self.dirty_vars.intersection(self.base_vars)
+            .union(self.dirty_vars.intersection(self.computed_vars))
             .union(self._dirty_computed_vars())
             .union(self._always_dirty_computed_vars)
         )
@@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state._mark_dirty()
 
+        # Append expired computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._expired_computed_vars())
+
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()

+ 79 - 5
reflex/vars.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import datetime
 import dis
 import functools
 import inspect
@@ -1873,13 +1874,26 @@ class ComputedVar(Var, property):
     # Whether to track dependencies and cache computed values
     _cache: bool = dataclasses.field(default=False)
 
-    _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
+    # The initial value of the computed var
+    _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
+
+    # Explicit var dependencies to track
+    _static_deps: set[str] = dataclasses.field(default_factory=set)
+
+    # Whether var dependencies should be auto-determined
+    _auto_deps: bool = dataclasses.field(default=True)
+
+    # Interval at which the computed var should be updated
+    _update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
 
     def __init__(
         self,
         fget: Callable[[BaseState], Any],
         initial_value: Any | types.Unset = types.Unset(),
         cache: bool = False,
+        deps: Optional[List[Union[str, Var]]] = None,
+        auto_deps: bool = True,
+        interval: Optional[Union[int, datetime.timedelta]] = None,
         **kwargs,
     ):
         """Initialize a ComputedVar.
@@ -1888,10 +1902,22 @@ class ComputedVar(Var, property):
             fget: The getter function.
             initial_value: The initial value of the computed var.
             cache: Whether to cache the computed value.
+            deps: Explicit var dependencies to track.
+            auto_deps: Whether var dependencies should be auto-determined.
+            interval: Interval at which the computed var should be updated.
             **kwargs: additional attributes to set on the instance
         """
         self._initial_value = initial_value
         self._cache = cache
+        if isinstance(interval, int):
+            interval = datetime.timedelta(seconds=interval)
+        self._update_interval = interval
+        if deps is None:
+            deps = []
+        self._static_deps = {
+            dep._var_name if isinstance(dep, Var) else dep for dep in deps
+        }
+        self._auto_deps = auto_deps
         property.__init__(self, fget)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
@@ -1912,6 +1938,9 @@ class ComputedVar(Var, property):
             fget=kwargs.get("fget", self.fget),
             initial_value=kwargs.get("initial_value", self._initial_value),
             cache=kwargs.get("cache", self._cache),
+            deps=kwargs.get("deps", self._static_deps),
+            auto_deps=kwargs.get("auto_deps", self._auto_deps),
+            interval=kwargs.get("interval", self._update_interval),
             _var_name=kwargs.get("_var_name", self._var_name),
             _var_type=kwargs.get("_var_type", self._var_type),
             _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@@ -1932,7 +1961,32 @@ class ComputedVar(Var, property):
         """
         return f"__cached_{self._var_name}"
 
-    def __get__(self, instance, owner):
+    @property
+    def _last_updated_attr(self) -> str:
+        """Get the attribute used to store the last updated timestamp.
+
+        Returns:
+            An attribute name.
+        """
+        return f"__last_updated_{self._var_name}"
+
+    def needs_update(self, instance: BaseState) -> bool:
+        """Check if the computed var needs to be updated.
+
+        Args:
+            instance: The state instance that the computed var is attached to.
+
+        Returns:
+            True if the computed var needs to be updated, False otherwise.
+        """
+        if self._update_interval is None:
+            return False
+        last_updated = getattr(instance, self._last_updated_attr, None)
+        if last_updated is None:
+            return True
+        return datetime.datetime.now() - last_updated > self._update_interval
+
+    def __get__(self, instance: BaseState | None, owner):
         """Get the ComputedVar value.
 
         If the value is already cached on the instance, return the cached value.
@@ -1948,10 +2002,13 @@ class ComputedVar(Var, property):
             return super().__get__(instance, owner)
 
         # handle caching
-        if not hasattr(instance, self._cache_attr):
+        if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
+            # Set cache attr on state instance.
             setattr(instance, self._cache_attr, super().__get__(instance, owner))
             # Ensure the computed var gets serialized to redis.
             instance._was_touched = True
+            # Set the last updated timestamp on the state instance.
+            setattr(instance, self._last_updated_attr, datetime.datetime.now())
         return getattr(instance, self._cache_attr)
 
     def _deps(
@@ -1978,7 +2035,9 @@ class ComputedVar(Var, property):
             VarValueError: if the function references the get_state, parent_state, or substates attributes
                 (cannot track deps in a related state, only implicitly via parent state).
         """
-        d = set()
+        if not self._auto_deps:
+            return self._static_deps
+        d = self._static_deps.copy()
         if obj is None:
             fget = property.__getattribute__(self, "fget")
             if fget is not None:
@@ -2076,6 +2135,9 @@ def computed_var(
     fget: Callable[[BaseState], Any] | None = None,
     initial_value: Any | None = None,
     cache: bool = False,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
     **kwargs,
 ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
     """A ComputedVar decorator with or without kwargs.
@@ -2084,19 +2146,31 @@ def computed_var(
         fget: The getter function.
         initial_value: The initial value of the computed var.
         cache: Whether to cache the computed value.
+        deps: Explicit var dependencies to track.
+        auto_deps: Whether var dependencies should be auto-determined.
+        interval: Interval at which the computed var should be updated.
         **kwargs: additional attributes to set on the instance
 
     Returns:
         A ComputedVar instance.
+
+    Raises:
+        ValueError: If caching is disabled and an update interval is set.
     """
+    if cache is False and interval is not None:
+        raise ValueError("Cannot set update interval without caching.")
+
     if fget is not None:
         return ComputedVar(fget=fget, cache=cache)
 
-    def wrapper(fget):
+    def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
         return ComputedVar(
             fget=fget,
             initial_value=initial_value,
             cache=cache,
+            deps=deps,
+            auto_deps=auto_deps,
+            interval=interval,
             **kwargs,
         )
 

+ 16 - 0
reflex/vars.pyi

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import datetime
 from dataclasses import dataclass
 from _typeshed import Incomplete
 from reflex import constants as constants
@@ -141,6 +142,7 @@ class ComputedVar(Var):
     def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
     def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
     def mark_dirty(self, instance) -> None: ...
+    def needs_update(self, instance) -> bool: ...
     def _determine_var_type(self) -> Type: ...
     @overload
     def __init__(
@@ -155,10 +157,24 @@ class ComputedVar(Var):
 def computed_var(
     fget: Callable[[BaseState], Any] | None = None,
     initial_value: Any | None = None,
+    cache: bool = False,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
     **kwargs,
 ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
 @overload
 def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
+@overload
+def cached_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    deps: Optional[List[Union[str, Var]]] = None,
+    auto_deps: bool = True,
+    interval: Optional[Union[datetime.timedelta, int]] = None,
+    **kwargs,
+) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
+@overload
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
 class CallableVar(BaseVar):