Browse Source

Initial values for computed vars (#2670)

* initial values for computed vars draft

* add tests, add computed_var overloads

* fix darglint

* pass initial to substates when calling dict

* add tests for for child states

* format black

* allow None as initial value

* rename runtime_only to raises_at_runtime

* cleanup unused arguments of ComputedVars

* refactor cached_var to be partial of computed_var
benedikt-bartscher 1 year ago
parent
commit
93f402c773
6 changed files with 196 additions and 37 deletions
  1. 1 1
      reflex/compiler/utils.py
  2. 20 9
      reflex/state.py
  3. 23 0
      reflex/utils/types.py
  4. 38 17
      reflex/vars.py
  5. 8 3
      reflex/vars.pyi
  6. 106 7
      tests/test_var.py

+ 1 - 1
reflex/compiler/utils.py

@@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
         A dictionary of the compiled state.
     """
     try:
-        initial_state = state().dict()
+        initial_state = state().dict(initial=True)
     except Exception as e:
         console.warn(
             f"Failed to compile initial state with computed vars, excluding them: {e}"

+ 20 - 9
reflex/state.py

@@ -46,10 +46,10 @@ from reflex.event import (
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.serializers import SerializedType, serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
 Delta = Dict[str, Any]
-var = ComputedVar
+var = computed_var
 
 
 class HeaderData(Base):
@@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             return super().get_value(key.__wrapped__)
         return super().get_value(key)
 
-    def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
+    def dict(
+        self, include_computed: bool = True, initial: bool = False, **kwargs
+    ) -> dict[str, Any]:
         """Convert the object to a dictionary.
 
         Args:
             include_computed: Whether to include computed vars.
+            initial: Whether to get the initial value of computed vars.
             **kwargs: Kwargs to pass to the pydantic dict method.
 
         Returns:
@@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             prop_name: self.get_value(getattr(self, prop_name))
             for prop_name in self.base_vars
         }
-        computed_vars = (
-            {
+        if initial:
+            computed_vars = {
+                # Include initial computed vars.
+                prop_name: cv._initial_value
+                if isinstance(cv, ComputedVar)
+                and not isinstance(cv._initial_value, types.Unset)
+                else self.get_value(getattr(self, prop_name))
+                for prop_name, cv in self.computed_vars.items()
+            }
+        elif include_computed:
+            computed_vars = {
                 # Include the computed vars.
                 prop_name: self.get_value(getattr(self, prop_name))
                 for prop_name in self.computed_vars
             }
-            if include_computed
-            else {}
-        )
+        else:
+            computed_vars = {}
         variables = {**base_vars, **computed_vars}
         d = {
             self.get_full_name(): {k: variables[k] for k in sorted(variables)},
         }
         for substate_d in [
-            v.dict(include_computed=include_computed, **kwargs)
+            v.dict(include_computed=include_computed, initial=initial, **kwargs)
             for v in self.substates.values()
         ]:
             d.update(substate_d)

+ 23 - 0
reflex/utils/types.py

@@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
 ArgsSpec = Callable
 
 
+class Unset:
+    """A class to represent an unset value.
+
+    This is used to differentiate between a value that is not set and a value that is set to None.
+    """
+
+    def __repr__(self) -> str:
+        """Return the string representation of the class.
+
+        Returns:
+            The string representation of the class.
+        """
+        return "Unset"
+
+    def __bool__(self) -> bool:
+        """Return False when the class is used in a boolean context.
+
+        Returns:
+            False
+        """
+        return False
+
+
 def is_generic_alias(cls: GenericType) -> bool:
     """Check whether the class is a generic alias.
 

+ 38 - 17
reflex/vars.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 import contextlib
 import dataclasses
 import dis
+import functools
 import inspect
 import json
 import random
@@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
     # Whether to track dependencies and cache computed values
     _cache: bool = dataclasses.field(default=False)
 
+    _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
+
     def __init__(
         self,
         fget: Callable[[BaseState], Any],
-        fset: Callable[[BaseState, Any], None] | None = None,
-        fdel: Callable[[BaseState], Any] | None = None,
-        doc: str | None = None,
+        initial_value: Any | types.Unset = types.Unset(),
+        cache: bool = False,
         **kwargs,
     ):
         """Initialize a ComputedVar.
 
         Args:
             fget: The getter function.
-            fset: The setter function.
-            fdel: The deleter function.
-            doc: The docstring.
+            initial_value: The initial value of the computed var.
+            cache: Whether to cache the computed value.
             **kwargs: additional attributes to set on the instance
         """
-        property.__init__(self, fget, fset, fdel, doc)
+        self._initial_value = initial_value
+        self._cache = cache
+        property.__init__(self, fget)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
         BaseVar.__init__(self, **kwargs)  # type: ignore
@@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
         return Any
 
 
-def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
-    """A field with computed getter that tracks other state dependencies.
-
-    The cached_var will only be recalculated when other state vars that it
-    depends on are modified.
+def computed_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    cache: bool = False,
+    **kwargs,
+) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
+    """A ComputedVar decorator with or without kwargs.
 
     Args:
-        fget: the function that calculates the variable value.
+        fget: The getter function.
+        initial_value: The initial value of the computed var.
+        cache: Whether to cache the computed value.
+        **kwargs: additional attributes to set on the instance
 
     Returns:
-        ComputedVar that is recomputed when dependencies change.
+        A ComputedVar instance.
     """
-    cvar = ComputedVar(fget=fget)
-    cvar._cache = True
-    return cvar
+    if fget is not None:
+        return ComputedVar(fget=fget, cache=cache)
+
+    def wrapper(fget):
+        return ComputedVar(
+            fget=fget,
+            initial_value=initial_value,
+            cache=cache,
+            **kwargs,
+        )
+
+    return wrapper
+
+
+# Partial function of computed_var with cache=True
+cached_var = functools.partial(computed_var, cache=True)
 
 
 class CallableVar(BaseVar):

+ 8 - 3
reflex/vars.pyi

@@ -144,14 +144,19 @@ class ComputedVar(Var):
     def __init__(
         self,
         fget: Callable[[BaseState], Any],
-        fset: Callable[[BaseState, Any], None] | None = None,
-        fdel: Callable[[BaseState], Any] | None = None,
-        doc: str | None = None,
         **kwargs,
     ) -> None: ...
     @overload
     def __init__(self, func) -> None: ...
 
+@overload
+def computed_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    **kwargs,
+) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
+@overload
+def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
 class CallableVar(BaseVar):

+ 106 - 7
tests/test_var.py

@@ -9,8 +9,8 @@ from reflex.base import Base
 from reflex.state import BaseState
 from reflex.vars import (
     BaseVar,
-    ComputedVar,
     Var,
+    computed_var,
 )
 
 test_vars = [
@@ -46,7 +46,7 @@ def ParentState(TestObj):
         foo: int
         bar: int
 
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
             return TestObj
 
@@ -56,7 +56,7 @@ def ParentState(TestObj):
 @pytest.fixture
 def ChildState(ParentState, TestObj):
     class ChildState(ParentState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
             return TestObj
 
@@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
 @pytest.fixture
 def GrandChildState(ChildState, TestObj):
     class GrandChildState(ChildState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
             return TestObj
 
@@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
 @pytest.fixture
 def StateWithAnyVar(TestObj):
     class StateWithAnyVar(BaseState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self) -> typing.Any:
             return TestObj
 
@@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
 @pytest.fixture
 def StateWithCorrectVarAnnotation():
     class StateWithCorrectVarAnnotation(BaseState):
-        @ComputedVar
+        @computed_var
         def var_with_annotation(self) -> str:
             return "Correct annotation"
 
@@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
 @pytest.fixture
 def StateWithWrongVarAnnotation(TestObj):
     class StateWithWrongVarAnnotation(BaseState):
-        @ComputedVar
+        @computed_var
         def var_with_annotation(self) -> str:
             return TestObj
 
     return StateWithWrongVarAnnotation
 
 
+@pytest.fixture
+def StateWithInitialComputedVar():
+    class StateWithInitialComputedVar(BaseState):
+        @computed_var(initial_value="Initial value")
+        def var_with_initial_value(self) -> str:
+            return "Runtime value"
+
+    return StateWithInitialComputedVar
+
+
+@pytest.fixture
+def ChildWithInitialComputedVar(StateWithInitialComputedVar):
+    class ChildWithInitialComputedVar(StateWithInitialComputedVar):
+        @computed_var(initial_value="Initial value")
+        def var_with_initial_value_child(self) -> str:
+            return "Runtime value"
+
+    return ChildWithInitialComputedVar
+
+
+@pytest.fixture
+def StateWithRuntimeOnlyVar():
+    class StateWithRuntimeOnlyVar(BaseState):
+        @computed_var(initial_value=None)
+        def var_raises_at_runtime(self) -> str:
+            raise ValueError("So nicht, mein Freund")
+
+    return StateWithRuntimeOnlyVar
+
+
+@pytest.fixture
+def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
+    class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
+        @computed_var(initial_value="Initial value")
+        def var_raises_at_runtime_child(self) -> str:
+            raise ValueError("So nicht, mein Freund")
+
+    return ChildWithRuntimeOnlyVar
+
+
 @pytest.mark.parametrize(
     "prop,expected",
     zip(
@@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
     )
 
 
+@pytest.mark.parametrize(
+    "fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
+    [
+        (
+            "StateWithInitialComputedVar",
+            "var_with_initial_value",
+            "Initial value",
+            "Runtime value",
+            False,
+        ),
+        (
+            "ChildWithInitialComputedVar",
+            "var_with_initial_value_child",
+            "Initial value",
+            "Runtime value",
+            False,
+        ),
+        (
+            "StateWithRuntimeOnlyVar",
+            "var_raises_at_runtime",
+            None,
+            None,
+            True,
+        ),
+        (
+            "ChildWithRuntimeOnlyVar",
+            "var_raises_at_runtime_child",
+            "Initial value",
+            None,
+            True,
+        ),
+    ],
+)
+def test_state_with_initial_computed_var(
+    request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
+):
+    """Test that the initial and runtime values of a computed var are correct.
+
+    Args:
+        request: Fixture Request.
+        fixture: The state fixture.
+        var_name: The name of the computed var.
+        expected_initial: The expected initial value of the computed var.
+        expected_runtime: The expected runtime value of the computed var.
+        raises_at_runtime: Whether the computed var is runtime only.
+    """
+    state = request.getfixturevalue(fixture)()
+    state_name = state.get_full_name()
+    initial_dict = state.dict(initial=True)[state_name]
+    assert initial_dict[var_name] == expected_initial
+
+    if raises_at_runtime:
+        with pytest.raises(ValueError):
+            state.dict()[state_name][var_name]
+    else:
+        runtime_dict = state.dict()[state_name]
+        assert runtime_dict[var_name] == expected_runtime
+
+
 @pytest.mark.parametrize(
     "out, expected",
     [