Răsfoiți Sursa

Validate ComputedVar dependencies, add tests (#3527)

benedikt-bartscher 10 luni în urmă
părinte
comite
41efe12e9a
6 a modificat fișierele cu 145 adăugiri și 9 ștergeri
  1. 21 3
      integration/test_computed_vars.py
  2. 29 0
      reflex/app.py
  3. 4 0
      reflex/utils/exceptions.py
  4. 12 0
      reflex/vars.py
  5. 37 5
      tests/test_app.py
  6. 42 1
      tests/test_var.py

+ 21 - 3
integration/test_computed_vars.py

@@ -37,8 +37,13 @@ def ComputedVars():
 
             return self.count
 
+        # explicit dependency on count var
+        @rx.var(cache=True, deps=["count"], auto_deps=False)
+        def depends_on_count(self) -> int:
+            return self.count
+
         # explicit dependency on count1 var
-        @rx.cached_var(deps=[count1], auto_deps=False)
+        @rx.var(cache=True, deps=[count1], auto_deps=False)
         def depends_on_count1(self) -> int:
             return self.count
 
@@ -70,6 +75,11 @@ def ComputedVars():
                 rx.text(State.count2, id="count2"),
                 rx.text("count3:"),
                 rx.text(State.count3, id="count3"),
+                rx.text("depends_on_count:"),
+                rx.text(
+                    State.depends_on_count,
+                    id="depends_on_count",
+                ),
                 rx.text("depends_on_count1:"),
                 rx.text(
                     State.depends_on_count1,
@@ -90,7 +100,7 @@ def ComputedVars():
 
 @pytest.fixture(scope="module")
 def computed_vars(
-    tmp_path_factory,
+    tmp_path_factory: pytest.TempPathFactory,
 ) -> Generator[AppHarness, None, None]:
     """Start ComputedVars app at tmp_path via AppHarness.
 
@@ -177,6 +187,10 @@ def test_computed_vars(
     assert count3
     assert count3.text == "0"
 
+    depends_on_count = driver.find_element(By.ID, "depends_on_count")
+    assert depends_on_count
+    assert depends_on_count.text == "0"
+
     depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
     assert depends_on_count1
     assert depends_on_count1.text == "0"
@@ -197,10 +211,14 @@ def test_computed_vars(
     assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
     assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
     assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
+    assert (
+        computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
+        == "1"
+    )
 
     mark_dirty.click()
     with pytest.raises(TimeoutError):
-        computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
+        _ = computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
 
     time.sleep(10)
     assert count3.text == "0"

+ 29 - 0
reflex/app.py

@@ -797,6 +797,34 @@ class App(LifespanMixin, Base):
         for render, kwargs in DECORATED_PAGES[get_config().app_name]:
             self.add_page(render, **kwargs)
 
+    def _validate_var_dependencies(
+        self, state: Optional[Type[BaseState]] = None
+    ) -> None:
+        """Validate the dependencies of the vars in the app.
+
+        Args:
+            state: The state to validate the dependencies for.
+
+        Raises:
+            VarDependencyError: When a computed var has an invalid dependency.
+        """
+        if not self.state:
+            return
+
+        if not state:
+            state = self.state
+
+        for var in state.computed_vars.values():
+            deps = var._deps(objclass=state)
+            for dep in deps:
+                if dep not in state.vars and dep not in state.backend_vars:
+                    raise exceptions.VarDependencyError(
+                        f"ComputedVar {var._var_name} on state {state.__name__} has an invalid dependency {dep}"
+                    )
+
+        for substate in state.class_subclasses:
+            self._validate_var_dependencies(substate)
+
     def _compile(self, export: bool = False):
         """Compile the app and output it to the pages folder.
 
@@ -818,6 +846,7 @@ class App(LifespanMixin, Base):
         if not self._should_compile():
             return
 
+        self._validate_var_dependencies()
         self._setup_overlay_component()
 
         # Create a progress bar.

+ 4 - 0
reflex/utils/exceptions.py

@@ -61,6 +61,10 @@ class VarOperationTypeError(ReflexError, TypeError):
     """Custom TypeError for when unsupported operations are performed on vars."""
 
 
+class VarDependencyError(ReflexError, ValueError):
+    """Custom ValueError for when a var depends on a non-existent var."""
+
+
 class InvalidStylePropError(ReflexError, TypeError):
     """Custom Type Error when style props have invalid values."""
 

+ 12 - 0
reflex/vars.py

@@ -1971,6 +1971,9 @@ class ComputedVar(Var, property):
             auto_deps: Whether var dependencies should be auto-determined.
             interval: Interval at which the computed var should be updated.
             **kwargs: additional attributes to set on the instance
+
+        Raises:
+            TypeError: If the computed var dependencies are not Var instances or var names.
         """
         self._initial_value = initial_value
         self._cache = cache
@@ -1979,6 +1982,15 @@ class ComputedVar(Var, property):
         self._update_interval = interval
         if deps is None:
             deps = []
+        else:
+            for dep in deps:
+                if isinstance(dep, Var):
+                    continue
+                if isinstance(dep, str) and dep != "":
+                    continue
+                raise TypeError(
+                    "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
+                )
         self._static_deps = {
             dep._var_name if isinstance(dep, Var) else dep for dep in deps
         }

+ 37 - 5
tests/test_app.py

@@ -46,7 +46,7 @@ from reflex.state import (
 )
 from reflex.style import Style
 from reflex.utils import exceptions, format
-from reflex.vars import ComputedVar
+from reflex.vars import computed_var
 
 from .conftest import chdir
 from .states import (
@@ -951,7 +951,7 @@ class DynamicState(BaseState):
         """Increment the counter var."""
         self.counter = self.counter + 1
 
-    @ComputedVar
+    @computed_var
     def comp_dynamic(self) -> str:
         """A computed var that depends on the dynamic var.
 
@@ -1278,7 +1278,7 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
         yield app, web_dir
 
 
-def test_app_wrap_compile_theme(compilable_app):
+def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
     """Test that the radix theme component wraps the app.
 
     Args:
@@ -1306,7 +1306,7 @@ def test_app_wrap_compile_theme(compilable_app):
     ) in "".join(app_js_lines)
 
 
-def test_app_wrap_priority(compilable_app):
+def test_app_wrap_priority(compilable_app: tuple[App, Path]):
     """Test that the app wrap components are wrapped in the correct order.
 
     Args:
@@ -1490,7 +1490,7 @@ def test_add_page_component_returning_tuple():
 
 
 @pytest.mark.parametrize("export", (True, False))
-def test_app_with_transpile_packages(compilable_app, export):
+def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool):
     class C1(rx.Component):
         library = "foo@1.2.3"
         tag = "Foo"
@@ -1539,3 +1539,35 @@ def test_app_with_transpile_packages(compilable_app, export):
     else:
         assert 'output: "export"' not in next_config
         assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config
+
+
+def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
+    app, _ = compilable_app
+
+    class ValidDepState(BaseState):
+        base: int = 0
+        _backend: int = 0
+
+        @computed_var
+        def foo(self) -> str:
+            return "foo"
+
+        @computed_var(deps=["_backend", "base", foo])
+        def bar(self) -> str:
+            return "bar"
+
+    app.state = ValidDepState
+    app._compile()
+
+
+def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
+    app, _ = compilable_app
+
+    class InvalidDepState(BaseState):
+        @computed_var(deps=["foolksjdf"])
+        def bar(self) -> str:
+            return "bar"
+
+    app.state = InvalidDepState
+    with pytest.raises(exceptions.VarDependencyError):
+        app._compile()

+ 42 - 1
tests/test_var.py

@@ -1,6 +1,6 @@
 import json
 import typing
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, Union
 
 import pytest
 from pandas import DataFrame
@@ -9,6 +9,7 @@ from reflex.base import Base
 from reflex.state import BaseState
 from reflex.vars import (
     BaseVar,
+    ComputedVar,
     Var,
     computed_var,
 )
@@ -1388,3 +1389,43 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List
 )
 def test_var_name_unwrapped(var, expected):
     assert var._var_name_unwrapped == expected
+
+
+def cv_fget(state: BaseState) -> int:
+    return 1
+
+
+@pytest.mark.parametrize(
+    "deps,expected",
+    [
+        (["a"], {"a"}),
+        (["b"], {"b"}),
+        ([ComputedVar(fget=cv_fget)], {"cv_fget"}),
+    ],
+)
+def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
+    @computed_var(
+        deps=deps,
+    )
+    def test_var(state) -> int:
+        return 1
+
+    assert test_var._static_deps == expected
+
+
+@pytest.mark.parametrize(
+    "deps",
+    [
+        [""],
+        [1],
+        ["", "abc"],
+    ],
+)
+def test_invalid_computed_var_deps(deps: List):
+    with pytest.raises(TypeError):
+
+        @computed_var(
+            deps=deps,
+        )
+        def test_var(state) -> int:
+            return 1