瀏覽代碼

Added support for `RelexSet` wrapper (#1535)

Smit Parmar 1 年之前
父節點
當前提交
ef78465f16
共有 5 個文件被更改,包括 121 次插入6 次删除
  1. 14 4
      reflex/state.py
  2. 85 0
      reflex/vars.py
  3. 3 1
      tests/conftest.py
  4. 8 1
      tests/test_state.py
  5. 11 0
      tests/test_var.py

+ 14 - 4
reflex/state.py

@@ -34,7 +34,7 @@ from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
 from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
-from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, Var
+from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet, Var
 
 
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
 
 
@@ -601,8 +601,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             self.mark_dirty()
             self.mark_dirty()
             return
             return
 
 
-        # Make sure lists and dicts are converted to ReflexList and ReflexDict.
-        if name in self.vars and types._isinstance(value, Union[List, Dict]):
+        # Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
+        if name in self.vars and types._isinstance(value, Union[List, Dict, Set]):
             value = _convert_mutable_datatypes(value, self._reassign_field, name)
             value = _convert_mutable_datatypes(value, self._reassign_field, name)
 
 
         # Set the attribute.
         # Set the attribute.
@@ -985,7 +985,7 @@ def _convert_mutable_datatypes(
 ) -> Any:
 ) -> Any:
     """Recursively convert mutable data to the Rx data types.
     """Recursively convert mutable data to the Rx data types.
 
 
-    Note: right now only list & dict would be handled recursively.
+    Note: right now only list, dict and set would be handled recursively.
 
 
     Args:
     Args:
         field_value: The target field_value.
         field_value: The target field_value.
@@ -1015,4 +1015,14 @@ def _convert_mutable_datatypes(
             field_value, reassign_field=reassign_field, field_name=field_name
             field_value, reassign_field=reassign_field, field_name=field_name
         )
         )
 
 
+    if isinstance(field_value, set):
+        field_value = [
+            _convert_mutable_datatypes(value, reassign_field, field_name)
+            for value in field_value
+        ]
+
+        field_value = ReflexSet(
+            field_value, reassign_field=reassign_field, field_name=field_name
+        )
+
     return field_value
     return field_value

+ 85 - 0
reflex/vars.py

@@ -1130,6 +1130,91 @@ class ReflexDict(dict):
         self._reassign_field()
         self._reassign_field()
 
 
 
 
+class ReflexSet(set):
+    """A custom set that reflex can detect its mutation."""
+
+    def __init__(
+        self,
+        original_set: Set,
+        reassign_field: Callable = lambda _field_name: None,
+        field_name: str = "",
+    ):
+        """Initialize ReflexSet.
+
+        Args:
+            original_set (Set): The original set
+            reassign_field (Callable):
+                The method in the parent state to reassign the field.
+                Default to be a no-op function
+            field_name (str): the name of field in the parent state
+        """
+        self._reassign_field = lambda: reassign_field(field_name)
+
+        super().__init__(original_set)
+
+    def add(self, *args, **kwargs):
+        """Add an element to set.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().add(*args, **kwargs)
+        self._reassign_field()
+
+    def remove(self, *args, **kwargs):
+        """Remove an element.
+        Raise key error if element not found.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().remove(*args, **kwargs)
+        self._reassign_field()
+
+    def discard(self, *args, **kwargs):
+        """Remove an element.
+        Does not raise key error if element not found.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().discard(*args, **kwargs)
+        self._reassign_field()
+
+    def pop(self, *args, **kwargs):
+        """Remove an element.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().pop(*args, **kwargs)
+        self._reassign_field()
+
+    def clear(self, *args, **kwargs):
+        """Remove all elements from the set.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().clear(*args, **kwargs)
+        self._reassign_field()
+
+    def update(self, *args, **kwargs):
+        """Adds elements from an iterable to the set.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().update(*args, **kwargs)
+        self._reassign_field()
+
+
 class ImportVar(Base):
 class ImportVar(Base):
     """An import var."""
     """An import var."""
 
 

+ 3 - 1
tests/conftest.py

@@ -3,7 +3,7 @@ import contextlib
 import os
 import os
 import platform
 import platform
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, Generator, List, Union
+from typing import Dict, Generator, List, Set, Union
 
 
 import pytest
 import pytest
 
 
@@ -561,6 +561,7 @@ def mutable_state():
             "another_key": "another_value",
             "another_key": "another_value",
             "third_key": {"key": "value"},
             "third_key": {"key": "value"},
         }
         }
+        test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
 
 
         def reassign_mutables(self):
         def reassign_mutables(self):
             self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
             self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
@@ -569,5 +570,6 @@ def mutable_state():
                 "mod_another_key": "another_value",
                 "mod_another_key": "another_value",
                 "mod_third_key": {"key": "value"},
                 "mod_third_key": {"key": "value"},
             }
             }
+            self.test_set = {1, 2, 3, 4, "five"}
 
 
     return MutableTestState()
     return MutableTestState()

+ 8 - 1
tests/test_state.py

@@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar
 from reflex.event import Event, EventHandler
 from reflex.event import Event, EventHandler
 from reflex.state import State
 from reflex.state import State
 from reflex.utils import format
 from reflex.utils import format
-from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList
+from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet
 
 
 
 
 class Object(Base):
 class Object(Base):
@@ -1164,6 +1164,7 @@ def test_setattr_of_mutable_types(mutable_state):
     """
     """
     array = mutable_state.array
     array = mutable_state.array
     hashmap = mutable_state.hashmap
     hashmap = mutable_state.hashmap
+    test_set = mutable_state.test_set
 
 
     assert isinstance(array, ReflexList)
     assert isinstance(array, ReflexList)
     assert isinstance(array[1], ReflexList)
     assert isinstance(array[1], ReflexList)
@@ -1173,10 +1174,14 @@ def test_setattr_of_mutable_types(mutable_state):
     assert isinstance(hashmap["key"], ReflexList)
     assert isinstance(hashmap["key"], ReflexList)
     assert isinstance(hashmap["third_key"], ReflexDict)
     assert isinstance(hashmap["third_key"], ReflexDict)
 
 
+    assert isinstance(test_set, set)
+
     mutable_state.reassign_mutables()
     mutable_state.reassign_mutables()
 
 
     array = mutable_state.array
     array = mutable_state.array
     hashmap = mutable_state.hashmap
     hashmap = mutable_state.hashmap
+    test_set = mutable_state.test_set
+
     assert isinstance(array, ReflexList)
     assert isinstance(array, ReflexList)
     assert isinstance(array[1], ReflexList)
     assert isinstance(array[1], ReflexList)
     assert isinstance(array[2], ReflexDict)
     assert isinstance(array[2], ReflexDict)
@@ -1184,3 +1189,5 @@ def test_setattr_of_mutable_types(mutable_state):
     assert isinstance(hashmap, ReflexDict)
     assert isinstance(hashmap, ReflexDict)
     assert isinstance(hashmap["mod_key"], ReflexList)
     assert isinstance(hashmap["mod_key"], ReflexList)
     assert isinstance(hashmap["mod_third_key"], ReflexDict)
     assert isinstance(hashmap["mod_third_key"], ReflexDict)
+
+    assert isinstance(test_set, ReflexSet)

+ 11 - 0
tests/test_var.py

@@ -13,6 +13,7 @@ from reflex.vars import (
     ImportVar,
     ImportVar,
     ReflexDict,
     ReflexDict,
     ReflexList,
     ReflexList,
+    ReflexSet,
     Var,
     Var,
     get_local_storage,
     get_local_storage,
 )
 )
@@ -532,6 +533,16 @@ def test_pickleable_rx_dict():
     assert cloudpickle.loads(pickled_dict) == rx_dict
     assert cloudpickle.loads(pickled_dict) == rx_dict
 
 
 
 
+def test_pickleable_rx_set():
+    """Test that ReflexSet is pickleable."""
+    rx_set = ReflexSet(
+        original_set={1, 2, 3}, reassign_field=lambda x: x, field_name="random"
+    )
+
+    pickled_set = cloudpickle.dumps(rx_set)
+    assert cloudpickle.loads(pickled_set) == rx_set
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "import_var,expected",
     "import_var,expected",
     zip(
     zip(