Browse Source

[ENG-5929] Improve type checking for Async rx.var (#5304)

Ensure that async computed vars can be properly type checked in frontend code.

Fix #5303
Masen Furer 1 day ago
parent
commit
25be166cd9
2 changed files with 62 additions and 1 deletions
  1. 35 1
      reflex/vars/base.py
  2. 27 0
      tests/units/test_var.py

+ 35 - 1
reflex/vars/base.py

@@ -2694,6 +2694,27 @@ if TYPE_CHECKING:
     BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
 
 
+class _ComputedVarDecorator(Protocol):
+    """A protocol for the ComputedVar decorator."""
+
+    @overload
+    def __call__(
+        self,
+        fget: Callable[[BASE_STATE], Coroutine[Any, Any, RETURN_TYPE]],
+    ) -> AsyncComputedVar[RETURN_TYPE]: ...
+
+    @overload
+    def __call__(
+        self,
+        fget: Callable[[BASE_STATE], RETURN_TYPE],
+    ) -> ComputedVar[RETURN_TYPE]: ...
+
+    def __call__(
+        self,
+        fget: Callable[[BASE_STATE], Any],
+    ) -> ComputedVar[Any]: ...
+
+
 @overload
 def computed_var(
     fget: None = None,
@@ -2704,7 +2725,20 @@ def computed_var(
     interval: datetime.timedelta | int | None = None,
     backend: bool | None = None,
     **kwargs,
-) -> Callable[[Callable[[BASE_STATE], RETURN_TYPE]], ComputedVar[RETURN_TYPE]]: ...  # pyright: ignore [reportInvalidTypeVarUse]
+) -> _ComputedVarDecorator: ...
+
+
+@overload
+def computed_var(
+    fget: Callable[[BASE_STATE], Coroutine[Any, Any, RETURN_TYPE]],
+    initial_value: RETURN_TYPE | types.Unset = types.Unset(),
+    cache: bool = True,
+    deps: list[str | Var] | None = None,
+    auto_deps: bool = True,
+    interval: datetime.timedelta | int | None = None,
+    backend: bool | None = None,
+    **kwargs,
+) -> AsyncComputedVar[RETURN_TYPE]: ...
 
 
 @overload

+ 27 - 0
tests/units/test_var.py

@@ -1961,3 +1961,30 @@ def test_decimal_var_type_compatibility():
 
     result = (dec_num + int_num) / float_num
     assert str(result) == "((123.456 + 42) / 3.14)"
+
+
+def test_computed_var_type_compatibility():
+    """Test that different ComputedVar are compatible with Var annotations of their returned type."""
+
+    class ComputedVarTypeState(BaseState):
+        @computed_var
+        def sync_plain(self) -> str:
+            return "Hello"
+
+        @computed_var(initial_value="Test")
+        def sync_wrapper(self) -> str:
+            return "World"
+
+        @computed_var
+        async def async_plain(self) -> str:
+            return "Hello"
+
+        @computed_var(initial_value="Test")
+        async def async_wrapper(self) -> str:
+            return "World"
+
+    # All of these vars should be assignable to a str field statically.
+    rx.input(placeholder=ComputedVarTypeState.sync_plain)
+    rx.input(placeholder=ComputedVarTypeState.sync_wrapper)
+    rx.input(placeholder=ComputedVarTypeState.async_plain)
+    rx.input(placeholder=ComputedVarTypeState.async_wrapper)