浏览代码

Initial values for computed vars (#2670)

* initial values for computed vars draft

* add tests, add computed_var overloads

* fix darglint

* pass initial to substates when calling dict

* add tests for for child states

* format black

* allow None as initial value

* rename runtime_only to raises_at_runtime

* cleanup unused arguments of ComputedVars

* refactor cached_var to be partial of computed_var
benedikt-bartscher 1 年之前
父节点
当前提交
93f402c773
共有 6 个文件被更改,包括 196 次插入37 次删除
  1. 1 1
      reflex/compiler/utils.py
  2. 20 9
      reflex/state.py
  3. 23 0
      reflex/utils/types.py
  4. 38 17
      reflex/vars.py
  5. 8 3
      reflex/vars.pyi
  6. 106 7
      tests/test_var.py

+ 1 - 1
reflex/compiler/utils.py

@@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
         A dictionary of the compiled state.
         A dictionary of the compiled state.
     """
     """
     try:
     try:
-        initial_state = state().dict()
+        initial_state = state().dict(initial=True)
     except Exception as e:
     except Exception as e:
         console.warn(
         console.warn(
             f"Failed to compile initial state with computed vars, excluding them: {e}"
             f"Failed to compile initial state with computed vars, excluding them: {e}"

+ 20 - 9
reflex/state.py

@@ -46,10 +46,10 @@ from reflex.event import (
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.utils.serializers import SerializedType, serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
 
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
-var = ComputedVar
+var = computed_var
 
 
 
 
 class HeaderData(Base):
 class HeaderData(Base):
@@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             return super().get_value(key.__wrapped__)
             return super().get_value(key.__wrapped__)
         return super().get_value(key)
         return super().get_value(key)
 
 
-    def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
+    def dict(
+        self, include_computed: bool = True, initial: bool = False, **kwargs
+    ) -> dict[str, Any]:
         """Convert the object to a dictionary.
         """Convert the object to a dictionary.
 
 
         Args:
         Args:
             include_computed: Whether to include computed vars.
             include_computed: Whether to include computed vars.
+            initial: Whether to get the initial value of computed vars.
             **kwargs: Kwargs to pass to the pydantic dict method.
             **kwargs: Kwargs to pass to the pydantic dict method.
 
 
         Returns:
         Returns:
@@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             prop_name: self.get_value(getattr(self, prop_name))
             prop_name: self.get_value(getattr(self, prop_name))
             for prop_name in self.base_vars
             for prop_name in self.base_vars
         }
         }
-        computed_vars = (
-            {
+        if initial:
+            computed_vars = {
+                # Include initial computed vars.
+                prop_name: cv._initial_value
+                if isinstance(cv, ComputedVar)
+                and not isinstance(cv._initial_value, types.Unset)
+                else self.get_value(getattr(self, prop_name))
+                for prop_name, cv in self.computed_vars.items()
+            }
+        elif include_computed:
+            computed_vars = {
                 # Include the computed vars.
                 # Include the computed vars.
                 prop_name: self.get_value(getattr(self, prop_name))
                 prop_name: self.get_value(getattr(self, prop_name))
                 for prop_name in self.computed_vars
                 for prop_name in self.computed_vars
             }
             }
-            if include_computed
-            else {}
-        )
+        else:
+            computed_vars = {}
         variables = {**base_vars, **computed_vars}
         variables = {**base_vars, **computed_vars}
         d = {
         d = {
             self.get_full_name(): {k: variables[k] for k in sorted(variables)},
             self.get_full_name(): {k: variables[k] for k in sorted(variables)},
         }
         }
         for substate_d in [
         for substate_d in [
-            v.dict(include_computed=include_computed, **kwargs)
+            v.dict(include_computed=include_computed, initial=initial, **kwargs)
             for v in self.substates.values()
             for v in self.substates.values()
         ]:
         ]:
             d.update(substate_d)
             d.update(substate_d)

+ 23 - 0
reflex/utils/types.py

@@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
 ArgsSpec = Callable
 ArgsSpec = Callable
 
 
 
 
+class Unset:
+    """A class to represent an unset value.
+
+    This is used to differentiate between a value that is not set and a value that is set to None.
+    """
+
+    def __repr__(self) -> str:
+        """Return the string representation of the class.
+
+        Returns:
+            The string representation of the class.
+        """
+        return "Unset"
+
+    def __bool__(self) -> bool:
+        """Return False when the class is used in a boolean context.
+
+        Returns:
+            False
+        """
+        return False
+
+
 def is_generic_alias(cls: GenericType) -> bool:
 def is_generic_alias(cls: GenericType) -> bool:
     """Check whether the class is a generic alias.
     """Check whether the class is a generic alias.
 
 

+ 38 - 17
reflex/vars.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 import contextlib
 import contextlib
 import dataclasses
 import dataclasses
 import dis
 import dis
+import functools
 import inspect
 import inspect
 import json
 import json
 import random
 import random
@@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
     # Whether to track dependencies and cache computed values
     # Whether to track dependencies and cache computed values
     _cache: bool = dataclasses.field(default=False)
     _cache: bool = dataclasses.field(default=False)
 
 
+    _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
+
     def __init__(
     def __init__(
         self,
         self,
         fget: Callable[[BaseState], Any],
         fget: Callable[[BaseState], Any],
-        fset: Callable[[BaseState, Any], None] | None = None,
-        fdel: Callable[[BaseState], Any] | None = None,
-        doc: str | None = None,
+        initial_value: Any | types.Unset = types.Unset(),
+        cache: bool = False,
         **kwargs,
         **kwargs,
     ):
     ):
         """Initialize a ComputedVar.
         """Initialize a ComputedVar.
 
 
         Args:
         Args:
             fget: The getter function.
             fget: The getter function.
-            fset: The setter function.
-            fdel: The deleter function.
-            doc: The docstring.
+            initial_value: The initial value of the computed var.
+            cache: Whether to cache the computed value.
             **kwargs: additional attributes to set on the instance
             **kwargs: additional attributes to set on the instance
         """
         """
-        property.__init__(self, fget, fset, fdel, doc)
+        self._initial_value = initial_value
+        self._cache = cache
+        property.__init__(self, fget)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
         BaseVar.__init__(self, **kwargs)  # type: ignore
         BaseVar.__init__(self, **kwargs)  # type: ignore
@@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
         return Any
         return Any
 
 
 
 
-def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
-    """A field with computed getter that tracks other state dependencies.
-
-    The cached_var will only be recalculated when other state vars that it
-    depends on are modified.
+def computed_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    cache: bool = False,
+    **kwargs,
+) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
+    """A ComputedVar decorator with or without kwargs.
 
 
     Args:
     Args:
-        fget: the function that calculates the variable value.
+        fget: The getter function.
+        initial_value: The initial value of the computed var.
+        cache: Whether to cache the computed value.
+        **kwargs: additional attributes to set on the instance
 
 
     Returns:
     Returns:
-        ComputedVar that is recomputed when dependencies change.
+        A ComputedVar instance.
     """
     """
-    cvar = ComputedVar(fget=fget)
-    cvar._cache = True
-    return cvar
+    if fget is not None:
+        return ComputedVar(fget=fget, cache=cache)
+
+    def wrapper(fget):
+        return ComputedVar(
+            fget=fget,
+            initial_value=initial_value,
+            cache=cache,
+            **kwargs,
+        )
+
+    return wrapper
+
+
+# Partial function of computed_var with cache=True
+cached_var = functools.partial(computed_var, cache=True)
 
 
 
 
 class CallableVar(BaseVar):
 class CallableVar(BaseVar):

+ 8 - 3
reflex/vars.pyi

@@ -144,14 +144,19 @@ class ComputedVar(Var):
     def __init__(
     def __init__(
         self,
         self,
         fget: Callable[[BaseState], Any],
         fget: Callable[[BaseState], Any],
-        fset: Callable[[BaseState, Any], None] | None = None,
-        fdel: Callable[[BaseState], Any] | None = None,
-        doc: str | None = None,
         **kwargs,
         **kwargs,
     ) -> None: ...
     ) -> None: ...
     @overload
     @overload
     def __init__(self, func) -> None: ...
     def __init__(self, func) -> None: ...
 
 
+@overload
+def computed_var(
+    fget: Callable[[BaseState], Any] | None = None,
+    initial_value: Any | None = None,
+    **kwargs,
+) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
+@overload
+def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
 
 class CallableVar(BaseVar):
 class CallableVar(BaseVar):

+ 106 - 7
tests/test_var.py

@@ -9,8 +9,8 @@ from reflex.base import Base
 from reflex.state import BaseState
 from reflex.state import BaseState
 from reflex.vars import (
 from reflex.vars import (
     BaseVar,
     BaseVar,
-    ComputedVar,
     Var,
     Var,
+    computed_var,
 )
 )
 
 
 test_vars = [
 test_vars = [
@@ -46,7 +46,7 @@ def ParentState(TestObj):
         foo: int
         foo: int
         bar: int
         bar: int
 
 
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
         def var_without_annotation(self):
             return TestObj
             return TestObj
 
 
@@ -56,7 +56,7 @@ def ParentState(TestObj):
 @pytest.fixture
 @pytest.fixture
 def ChildState(ParentState, TestObj):
 def ChildState(ParentState, TestObj):
     class ChildState(ParentState):
     class ChildState(ParentState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
         def var_without_annotation(self):
             return TestObj
             return TestObj
 
 
@@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
 @pytest.fixture
 @pytest.fixture
 def GrandChildState(ChildState, TestObj):
 def GrandChildState(ChildState, TestObj):
     class GrandChildState(ChildState):
     class GrandChildState(ChildState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self):
         def var_without_annotation(self):
             return TestObj
             return TestObj
 
 
@@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
 @pytest.fixture
 @pytest.fixture
 def StateWithAnyVar(TestObj):
 def StateWithAnyVar(TestObj):
     class StateWithAnyVar(BaseState):
     class StateWithAnyVar(BaseState):
-        @ComputedVar
+        @computed_var
         def var_without_annotation(self) -> typing.Any:
         def var_without_annotation(self) -> typing.Any:
             return TestObj
             return TestObj
 
 
@@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
 @pytest.fixture
 @pytest.fixture
 def StateWithCorrectVarAnnotation():
 def StateWithCorrectVarAnnotation():
     class StateWithCorrectVarAnnotation(BaseState):
     class StateWithCorrectVarAnnotation(BaseState):
-        @ComputedVar
+        @computed_var
         def var_with_annotation(self) -> str:
         def var_with_annotation(self) -> str:
             return "Correct annotation"
             return "Correct annotation"
 
 
@@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
 @pytest.fixture
 @pytest.fixture
 def StateWithWrongVarAnnotation(TestObj):
 def StateWithWrongVarAnnotation(TestObj):
     class StateWithWrongVarAnnotation(BaseState):
     class StateWithWrongVarAnnotation(BaseState):
-        @ComputedVar
+        @computed_var
         def var_with_annotation(self) -> str:
         def var_with_annotation(self) -> str:
             return TestObj
             return TestObj
 
 
     return StateWithWrongVarAnnotation
     return StateWithWrongVarAnnotation
 
 
 
 
+@pytest.fixture
+def StateWithInitialComputedVar():
+    class StateWithInitialComputedVar(BaseState):
+        @computed_var(initial_value="Initial value")
+        def var_with_initial_value(self) -> str:
+            return "Runtime value"
+
+    return StateWithInitialComputedVar
+
+
+@pytest.fixture
+def ChildWithInitialComputedVar(StateWithInitialComputedVar):
+    class ChildWithInitialComputedVar(StateWithInitialComputedVar):
+        @computed_var(initial_value="Initial value")
+        def var_with_initial_value_child(self) -> str:
+            return "Runtime value"
+
+    return ChildWithInitialComputedVar
+
+
+@pytest.fixture
+def StateWithRuntimeOnlyVar():
+    class StateWithRuntimeOnlyVar(BaseState):
+        @computed_var(initial_value=None)
+        def var_raises_at_runtime(self) -> str:
+            raise ValueError("So nicht, mein Freund")
+
+    return StateWithRuntimeOnlyVar
+
+
+@pytest.fixture
+def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
+    class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
+        @computed_var(initial_value="Initial value")
+        def var_raises_at_runtime_child(self) -> str:
+            raise ValueError("So nicht, mein Freund")
+
+    return ChildWithRuntimeOnlyVar
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "prop,expected",
     "prop,expected",
     zip(
     zip(
@@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
     )
     )
 
 
 
 
+@pytest.mark.parametrize(
+    "fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
+    [
+        (
+            "StateWithInitialComputedVar",
+            "var_with_initial_value",
+            "Initial value",
+            "Runtime value",
+            False,
+        ),
+        (
+            "ChildWithInitialComputedVar",
+            "var_with_initial_value_child",
+            "Initial value",
+            "Runtime value",
+            False,
+        ),
+        (
+            "StateWithRuntimeOnlyVar",
+            "var_raises_at_runtime",
+            None,
+            None,
+            True,
+        ),
+        (
+            "ChildWithRuntimeOnlyVar",
+            "var_raises_at_runtime_child",
+            "Initial value",
+            None,
+            True,
+        ),
+    ],
+)
+def test_state_with_initial_computed_var(
+    request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
+):
+    """Test that the initial and runtime values of a computed var are correct.
+
+    Args:
+        request: Fixture Request.
+        fixture: The state fixture.
+        var_name: The name of the computed var.
+        expected_initial: The expected initial value of the computed var.
+        expected_runtime: The expected runtime value of the computed var.
+        raises_at_runtime: Whether the computed var is runtime only.
+    """
+    state = request.getfixturevalue(fixture)()
+    state_name = state.get_full_name()
+    initial_dict = state.dict(initial=True)[state_name]
+    assert initial_dict[var_name] == expected_initial
+
+    if raises_at_runtime:
+        with pytest.raises(ValueError):
+            state.dict()[state_name][var_name]
+    else:
+        runtime_dict = state.dict()[state_name]
+        assert runtime_dict[var_name] == expected_runtime
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "out, expected",
     "out, expected",
     [
     [