Quellcode durchsuchen

[REF-1742] Radio group prop types fix (#2452)

Elijah Ahianyo vor 1 Jahr
Ursprung
Commit
d2fd0d3b92

+ 34 - 10
reflex/components/radix/themes/components/radiogroup.py

@@ -1,5 +1,5 @@
 """Interactive components provided by @radix-ui/themes."""
 """Interactive components provided by @radix-ui/themes."""
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 
 
 import reflex as rx
 import reflex as rx
 from reflex.components.component import Component
 from reflex.components.component import Component
@@ -97,7 +97,11 @@ class HighLevelRadioGroup(RadioGroupRoot):
     size: Var[Literal["1", "2", "3"]] = Var.create_safe("2")
     size: Var[Literal["1", "2", "3"]] = Var.create_safe("2")
 
 
     @classmethod
     @classmethod
-    def create(cls, items: Var[List[str]], **props) -> Component:
+    def create(
+        cls,
+        items: Var[List[Optional[Union[str, int, float, list, dict, bool]]]],
+        **props
+    ) -> Component:
         """Create a radio group component.
         """Create a radio group component.
 
 
         Args:
         Args:
@@ -110,29 +114,49 @@ class HighLevelRadioGroup(RadioGroupRoot):
         direction = props.pop("direction", "column")
         direction = props.pop("direction", "column")
         gap = props.pop("gap", "2")
         gap = props.pop("gap", "2")
         size = props.pop("size", "2")
         size = props.pop("size", "2")
+        default_value = props.pop("default_value", "")
+
+        # convert only non-strings to json(JSON.stringify) so quotes are not rendered
+        # for string literal types.
+        if (
+            type(default_value) is str
+            or isinstance(default_value, Var)
+            and default_value._var_type is str
+        ):
+            default_value = Var.create(default_value, _var_is_string=True)  # type: ignore
+        else:
+            default_value = (
+                Var.create(default_value).to_string()._replace(_var_is_local=False)  # type: ignore
+            )
+
+        def radio_group_item(value: str | Var) -> Component:
+            item_value = Var.create(value)  # type: ignore
+            item_value = rx.cond(
+                item_value._type() == str,  # type: ignore
+                item_value,
+                item_value.to_string()._replace(_var_is_local=False),  # type: ignore
+            )._replace(_var_type=str)
 
 
-        def radio_group_item(value: str) -> Component:
             return Text.create(
             return Text.create(
                 Flex.create(
                 Flex.create(
-                    RadioGroupItem.create(value=value),
-                    value,
+                    RadioGroupItem.create(value=item_value),
+                    item_value,
                     gap="2",
                     gap="2",
                 ),
                 ),
                 size=size,
                 size=size,
                 as_="label",
                 as_="label",
             )
             )
 
 
-        if isinstance(items, Var):
-            child = [rx.foreach(items, radio_group_item)]
-        else:
-            child = [radio_group_item(value) for value in items]  #  type: ignore
+        items = Var.create(items)  # type: ignore
+        children = [rx.foreach(items, radio_group_item)]
 
 
         return RadioGroupRoot.create(
         return RadioGroupRoot.create(
             Flex.create(
             Flex.create(
-                *child,
+                *children,
                 direction=direction,
                 direction=direction,
                 gap=gap,
                 gap=gap,
             ),
             ),
             size=size,
             size=size,
+            default_value=default_value,
             **props,
             **props,
         )
         )

+ 1 - 1
reflex/components/radix/themes/components/radiogroup.pyi

@@ -7,7 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 import reflex as rx
 import reflex as rx
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.radix.themes.layout.flex import Flex
 from reflex.components.radix.themes.layout.flex import Flex

+ 32 - 4
reflex/vars.py

@@ -87,6 +87,15 @@ REPLACED_NAMES = {
     "deps": "_deps",
     "deps": "_deps",
 }
 }
 
 
+PYTHON_JS_TYPE_MAP = {
+    (int, float): "number",
+    (str,): "string",
+    (bool,): "boolean",
+    (list, tuple): "Array",
+    (dict,): "Object",
+    (None,): "null",
+}
+
 
 
 def get_unique_variable_name() -> str:
 def get_unique_variable_name() -> str:
     """Get a unique variable name.
     """Get a unique variable name.
@@ -739,13 +748,13 @@ class Var:
                 operation_name = format.wrap(operation_name, "(")
                 operation_name = format.wrap(operation_name, "(")
         else:
         else:
             # apply operator to left operand (<operator> left_operand)
             # apply operator to left operand (<operator> left_operand)
-            operation_name = f"{op}{self._var_full_name}"
+            operation_name = f"{op}{get_operand_full_name(self)}"
             # apply function to operands
             # apply function to operands
             if fn is not None:
             if fn is not None:
                 operation_name = (
                 operation_name = (
                     f"{fn}({operation_name})"
                     f"{fn}({operation_name})"
                     if not invoke_fn
                     if not invoke_fn
-                    else f"{self._var_full_name}.{fn}()"
+                    else f"{get_operand_full_name(self)}.{fn}()"
                 )
                 )
 
 
         return self._replace(
         return self._replace(
@@ -839,7 +848,20 @@ class Var:
             _var_is_string=False,
             _var_is_string=False,
         )
         )
 
 
-    def __eq__(self, other: Var) -> Var:
+    def _type(self) -> Var:
+        """Get the type of the Var in Javascript.
+
+        Returns:
+            A var representing the type check.
+        """
+        return self._replace(
+            _var_name=f"typeof {self._var_full_name}",
+            _var_type=str,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
+        )
+
+    def __eq__(self, other: Union[Var, Type]) -> Var:
         """Perform an equality comparison.
         """Perform an equality comparison.
 
 
         Args:
         Args:
@@ -848,9 +870,12 @@ class Var:
         Returns:
         Returns:
             A var representing the equality comparison.
             A var representing the equality comparison.
         """
         """
+        for python_types, js_type in PYTHON_JS_TYPE_MAP.items():
+            if not isinstance(other, Var) and other in python_types:
+                return self.compare("===", Var.create(js_type, _var_is_string=True))  # type: ignore
         return self.compare("===", other)
         return self.compare("===", other)
 
 
-    def __ne__(self, other: Var) -> Var:
+    def __ne__(self, other: Union[Var, Type]) -> Var:
         """Perform an inequality comparison.
         """Perform an inequality comparison.
 
 
         Args:
         Args:
@@ -859,6 +884,9 @@ class Var:
         Returns:
         Returns:
             A var representing the inequality comparison.
             A var representing the inequality comparison.
         """
         """
+        for python_types, js_type in PYTHON_JS_TYPE_MAP.items():
+            if not isinstance(other, Var) and other in python_types:
+                return self.compare("!==", Var.create(js_type, _var_is_string=True))  # type: ignore
         return self.compare("!==", other)
         return self.compare("!==", other)
 
 
     def __gt__(self, other: Var) -> Var:
     def __gt__(self, other: Var) -> Var:

+ 53 - 0
tests/test_var.py

@@ -316,6 +316,59 @@ def test_basic_operations(TestObj):
         str(BaseVar(_var_name="foo", _var_type=list).reverse())
         str(BaseVar(_var_name="foo", _var_type=list).reverse())
         == "{[...foo].reverse()}"
         == "{[...foo].reverse()}"
     )
     )
+    assert str(BaseVar(_var_name="foo", _var_type=str)._type()) == "{typeof foo}"  # type: ignore
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == str)  # type: ignore
+        == "{(typeof foo === `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == str)  # type: ignore
+        == "{(typeof foo === `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == int)  # type: ignore
+        == "{(typeof foo === `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == list)  # type: ignore
+        == "{(typeof foo === `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == float)  # type: ignore
+        == "{(typeof foo === `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == tuple)  # type: ignore
+        == "{(typeof foo === `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() == dict)  # type: ignore
+        == "{(typeof foo === `Object`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != str)  # type: ignore
+        == "{(typeof foo !== `string`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != int)  # type: ignore
+        == "{(typeof foo !== `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != list)  # type: ignore
+        == "{(typeof foo !== `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != float)  # type: ignore
+        == "{(typeof foo !== `number`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != tuple)  # type: ignore
+        == "{(typeof foo !== `Array`)}"
+    )
+    assert (
+        str(BaseVar(_var_name="foo", _var_type=str)._type() != dict)  # type: ignore
+        == "{(typeof foo !== `Object`)}"
+    )
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(