1
0
Эх сурвалжийг харах

[REF-1742] Radio group prop types fix (#2452)

Elijah Ahianyo 1 жил өмнө
parent
commit
d2fd0d3b92

+ 34 - 10
reflex/components/radix/themes/components/radiogroup.py

@@ -1,5 +1,5 @@
 """Interactive components provided by @radix-ui/themes."""
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import reflex as rx
 from reflex.components.component import Component
@@ -97,7 +97,11 @@ class HighLevelRadioGroup(RadioGroupRoot):
     size: Var[Literal["1", "2", "3"]] = Var.create_safe("2")
 
     @classmethod
-    def create(cls, items: Var[List[str]], **props) -> Component:
+    def create(
+        cls,
+        items: Var[List[Optional[Union[str, int, float, list, dict, bool]]]],
+        **props
+    ) -> Component:
         """Create a radio group component.
 
         Args:
@@ -110,29 +114,49 @@ class HighLevelRadioGroup(RadioGroupRoot):
         direction = props.pop("direction", "column")
         gap = props.pop("gap", "2")
         size = props.pop("size", "2")
+        default_value = props.pop("default_value", "")
+
+        # convert only non-strings to json(JSON.stringify) so quotes are not rendered
+        # for string literal types.
+        if (
+            type(default_value) is str
+            or isinstance(default_value, Var)
+            and default_value._var_type is str
+        ):
+            default_value = Var.create(default_value, _var_is_string=True)  # type: ignore
+        else:
+            default_value = (
+                Var.create(default_value).to_string()._replace(_var_is_local=False)  # type: ignore
+            )
+
+        def radio_group_item(value: str | Var) -> Component:
+            item_value = Var.create(value)  # type: ignore
+            item_value = rx.cond(
+                item_value._type() == str,  # type: ignore
+                item_value,
+                item_value.to_string()._replace(_var_is_local=False),  # type: ignore
+            )._replace(_var_type=str)
 
-        def radio_group_item(value: str) -> Component:
             return Text.create(
                 Flex.create(
-                    RadioGroupItem.create(value=value),
-                    value,
+                    RadioGroupItem.create(value=item_value),
+                    item_value,
                     gap="2",
                 ),
                 size=size,
                 as_="label",
             )
 
-        if isinstance(items, Var):
-            child = [rx.foreach(items, radio_group_item)]
-        else:
-            child = [radio_group_item(value) for value in items]  #  type: ignore
+        items = Var.create(items)  # type: ignore
+        children = [rx.foreach(items, radio_group_item)]
 
         return RadioGroupRoot.create(
             Flex.create(
-                *child,
+                *children,
                 direction=direction,
                 gap=gap,
             ),
             size=size,
+            default_value=default_value,
             **props,
         )

+ 1 - 1
reflex/components/radix/themes/components/radiogroup.pyi

@@ -7,7 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 import reflex as rx
 from reflex.components.component import Component
 from reflex.components.radix.themes.layout.flex import Flex

+ 32 - 4
reflex/vars.py

@@ -87,6 +87,15 @@ REPLACED_NAMES = {
     "deps": "_deps",
 }
 
+PYTHON_JS_TYPE_MAP = {
+    (int, float): "number",
+    (str,): "string",
+    (bool,): "boolean",
+    (list, tuple): "Array",
+    (dict,): "Object",
+    (None,): "null",
+}
+
 
 def get_unique_variable_name() -> str:
     """Get a unique variable name.
@@ -739,13 +748,13 @@ class Var:
                 operation_name = format.wrap(operation_name, "(")
         else:
             # apply operator to left operand (<operator> left_operand)
-            operation_name = f"{op}{self._var_full_name}"
+            operation_name = f"{op}{get_operand_full_name(self)}"
             # apply function to operands
             if fn is not None:
                 operation_name = (
                     f"{fn}({operation_name})"
                     if not invoke_fn
-                    else f"{self._var_full_name}.{fn}()"
+                    else f"{get_operand_full_name(self)}.{fn}()"
                 )
 
         return self._replace(
@@ -839,7 +848,20 @@ class Var:
             _var_is_string=False,
         )
 
-    def __eq__(self, other: Var) -> Var:
+    def _type(self) -> Var:
+        """Get the type of the Var in Javascript.
+
+        Returns:
+            A var representing the type check.
+        """
+        return self._replace(
+            _var_name=f"typeof {self._var_full_name}",
+            _var_type=str,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
+        )
+
+    def __eq__(self, other: Union[Var, Type]) -> Var:
         """Perform an equality comparison.
 
         Args:
@@ -848,9 +870,12 @@ class Var:
         Returns:
             A var representing the equality comparison.
         """
+        for python_types, js_type in PYTHON_JS_TYPE_MAP.items():
+            if not isinstance(other, Var) and other in python_types:
+                return self.compare("===", Var.create(js_type, _var_is_string=True))  # type: ignore
         return self.compare("===", other)
 
-    def __ne__(self, other: Var) -> Var:
+    def __ne__(self, other: Union[Var, Type]) -> Var:
         """Perform an inequality comparison.
 
         Args:
@@ -859,6 +884,9 @@ class Var:
         Returns:
             A var representing the inequality comparison.
         """
+        for python_types, js_type in PYTHON_JS_TYPE_MAP.items():
+            if not isinstance(other, Var) and other in python_types:
+                return self.compare("!==", Var.create(js_type, _var_is_string=True))  # type: ignore
         return self.compare("!==", other)
 
     def __gt__(self, other: Var) -> Var:

+ 53 - 0
tests/test_var.py

@@ -316,6 +316,59 @@ def test_basic_operations(TestObj):
         str(BaseVar(_var_name="foo", _var_type=list).reverse())
         == "{[...foo].reverse()}"
     )
+    assert str(BaseVar(_var_name="foo", _var_type=str)._type()) == "{typeof foo}"  # type: ignore
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == str)  # type: ignore
+        == "{(typeof foo === `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == str)  # type: ignore
+        == "{(typeof foo === `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == int)  # type: ignore
+        == "{(typeof foo === `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == list)  # type: ignore
+        == "{(typeof foo === `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == float)  # type: ignore
+        == "{(typeof foo === `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == tuple)  # type: ignore
+        == "{(typeof foo === `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == dict)  # type: ignore
+        == "{(typeof foo === `Object`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != str)  # type: ignore
+        == "{(typeof foo !== `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != int)  # type: ignore
+        == "{(typeof foo !== `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != list)  # type: ignore
+        == "{(typeof foo !== `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != float)  # type: ignore
+        == "{(typeof foo !== `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != tuple)  # type: ignore
+        == "{(typeof foo !== `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != dict)  # type: ignore
+        == "{(typeof foo !== `Object`)}"
+    )
 
 
 @pytest.mark.parametrize(