فهرست منبع

fix rx.cond with ComputedVars and use union type (#3336)

benedikt-bartscher 1 سال پیش
والد
کامیت
656914edef
2فایلهای تغییر یافته به همراه39 افزوده شده و 6 حذف شده
  1. 9 3
      reflex/components/core/cond.py
  2. 30 3
      tests/components/core/test_cond.py

+ 9 - 3
reflex/components/core/cond.py

@@ -1,7 +1,8 @@
 """Create a list of components from an iterable."""
 """Create a list of components from an iterable."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any, Dict, Optional, overload
+from typing import Any, Dict, Optional, Union, overload
 
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
@@ -10,7 +11,7 @@ from reflex.constants import Dirs
 from reflex.constants.colors import Color
 from reflex.constants.colors import Color
 from reflex.style import LIGHT_COLOR_MODE, color_mode
 from reflex.style import LIGHT_COLOR_MODE, color_mode
 from reflex.utils import format, imports
 from reflex.utils import format, imports
-from reflex.vars import BaseVar, Var, VarData
+from reflex.vars import Var, VarData
 
 
 _IS_TRUE_IMPORT = {
 _IS_TRUE_IMPORT = {
     f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
     f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
@@ -171,6 +172,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
     c2 = create_var(c2)
     c2 = create_var(c2)
     var_datas.extend([c1._var_data, c2._var_data])
     var_datas.extend([c1._var_data, c2._var_data])
 
 
+    c1_type = c1._var_type if isinstance(c1, Var) else type(c1)
+    c2_type = c2._var_type if isinstance(c2, Var) else type(c2)
+
+    var_type = c1_type if c1_type == c2_type else Union[c1_type, c2_type]
+
     # Create the conditional var.
     # Create the conditional var.
     return cond_var._replace(
     return cond_var._replace(
         _var_name=format.format_cond(
         _var_name=format.format_cond(
@@ -179,7 +185,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
             false_value=c2,
             false_value=c2,
             is_prop=True,
             is_prop=True,
         ),
         ),
-        _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
+        _var_type=var_type,
         _var_is_local=False,
         _var_is_local=False,
         _var_full_name_needs_state_prefix=False,
         _var_full_name_needs_state_prefix=False,
         merge_var_data=VarData.merge(*var_datas),
         merge_var_data=VarData.merge(*var_datas),

+ 30 - 3
tests/components/core/test_cond.py

@@ -1,13 +1,14 @@
 import json
 import json
-from typing import Any
+from typing import Any, Union
 
 
 import pytest
 import pytest
 
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.core.cond import Cond, cond
 from reflex.components.core.cond import Cond, cond
 from reflex.components.radix.themes.typography.text import Text
 from reflex.components.radix.themes.typography.text import Text
-from reflex.state import BaseState
-from reflex.vars import Var
+from reflex.state import BaseState, State
+from reflex.utils.format import format_state_name
+from reflex.vars import BaseVar, Var, computed_var
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -118,3 +119,29 @@ def test_cond_no_else():
     # Props do not support the use of cond without else
     # Props do not support the use of cond without else
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
         cond(True, "hello")  # type: ignore
         cond(True, "hello")  # type: ignore
+
+
+def test_cond_computed_var():
+    """Test if cond works with computed vars."""
+
+    class CondStateComputed(State):
+        @computed_var
+        def computed_int(self) -> int:
+            return 0
+
+        @computed_var
+        def computed_str(self) -> str:
+            return "a string"
+
+    comp = cond(True, CondStateComputed.computed_int, CondStateComputed.computed_str)
+
+    # TODO: shouln't this be a ComputedVar?
+    assert isinstance(comp, BaseVar)
+
+    state_name = format_state_name(CondStateComputed.get_full_name())
+    assert (
+        str(comp)
+        == f"{{isTrue(true) ? {state_name}.computed_int : {state_name}.computed_str}}"
+    )
+
+    assert comp._var_type == Union[int, str]