Ver Fonte

List and Dict mutation on setattr (#1428)

Elijah Ahianyo há 1 ano atrás
pai
commit
3fa33bd644
3 ficheiros alterados com 78 adições e 10 exclusões
  1. 13 8
      reflex/state.py
  2. 34 1
      tests/conftest.py
  3. 31 1
      tests/test_state.py

+ 13 - 8
reflex/state.py

@@ -595,6 +595,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             self.mark_dirty()
             self.mark_dirty()
             return
             return
 
 
+        # Make sure lists and dicts are converted to ReflexList and ReflexDict.
+        if name in self.vars and types._isinstance(value, Union[List, Dict]):
+            value = _convert_mutable_datatypes(value, self._reassign_field, name)
+
         # Set the attribute.
         # Set the attribute.
         super().__setattr__(name, value)
         super().__setattr__(name, value)
 
 
@@ -990,21 +994,22 @@ def _convert_mutable_datatypes(
         The converted field_value
         The converted field_value
     """
     """
     if isinstance(field_value, list):
     if isinstance(field_value, list):
-        for index in range(len(field_value)):
-            field_value[index] = _convert_mutable_datatypes(
-                field_value[index], reassign_field, field_name
-            )
+        field_value = [
+            _convert_mutable_datatypes(value, reassign_field, field_name)
+            for value in field_value
+        ]
 
 
         field_value = ReflexList(
         field_value = ReflexList(
             field_value, reassign_field=reassign_field, field_name=field_name
             field_value, reassign_field=reassign_field, field_name=field_name
         )
         )
 
 
     if isinstance(field_value, dict):
     if isinstance(field_value, dict):
-        for key, value in field_value.items():
-            field_value[key] = _convert_mutable_datatypes(
-                value, reassign_field, field_name
-            )
+        field_value = {
+            key: _convert_mutable_datatypes(value, reassign_field, field_name)
+            for key, value in field_value.items()
+        }
         field_value = ReflexDict(
         field_value = ReflexDict(
             field_value, reassign_field=reassign_field, field_name=field_name
             field_value, reassign_field=reassign_field, field_name=field_name
         )
         )
+
     return field_value
     return field_value

+ 34 - 1
tests/conftest.py

@@ -3,7 +3,7 @@ import contextlib
 import os
 import os
 import platform
 import platform
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, Generator, List
+from typing import Dict, Generator, List, Union
 
 
 import pytest
 import pytest
 
 
@@ -538,3 +538,36 @@ def tmp_working_dir(tmp_path):
     working_dir.mkdir()
     working_dir.mkdir()
     with chdir(working_dir):
     with chdir(working_dir):
         yield working_dir
         yield working_dir
+
+
+@pytest.fixture
+def mutable_state():
+    """Create a Test state containing mutable types.
+
+    Returns:
+        A state object.
+    """
+
+    class MutableTestState(rx.State):
+        """A test state."""
+
+        array: List[Union[str, List, Dict[str, str]]] = [
+            "value",
+            [1, 2, 3],
+            {"key": "value"},
+        ]
+        hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
+            "key": ["list", "of", "values"],
+            "another_key": "another_value",
+            "third_key": {"key": "value"},
+        }
+
+        def reassign_mutables(self):
+            self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
+            self.hashmap = {
+                "mod_key": ["list", "of", "values"],
+                "mod_another_key": "another_value",
+                "mod_third_key": {"key": "value"},
+            }
+
+    return MutableTestState()

+ 31 - 1
tests/test_state.py

@@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar
 from reflex.event import Event, EventHandler
 from reflex.event import Event, EventHandler
 from reflex.state import State
 from reflex.state import State
 from reflex.utils import format
 from reflex.utils import format
-from reflex.vars import BaseVar, ComputedVar
+from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList
 
 
 
 
 class Object(Base):
 class Object(Base):
@@ -1140,3 +1140,33 @@ def test_backend_method():
     bms = BackendMethodState()
     bms = BackendMethodState()
     bms.handler()
     bms.handler()
     assert bms._be_method()
     assert bms._be_method()
+
+
+def test_setattr_of_mutable_types(mutable_state):
+    """Test that mutable types are converted to corresponding Reflex wrappers.
+
+    Args:
+        mutable_state: A test state.
+    """
+    array = mutable_state.array
+    hashmap = mutable_state.hashmap
+
+    assert isinstance(array, ReflexList)
+    assert isinstance(array[1], ReflexList)
+    assert isinstance(array[2], ReflexDict)
+
+    assert isinstance(hashmap, ReflexDict)
+    assert isinstance(hashmap["key"], ReflexList)
+    assert isinstance(hashmap["third_key"], ReflexDict)
+
+    mutable_state.reassign_mutables()
+
+    array = mutable_state.array
+    hashmap = mutable_state.hashmap
+    assert isinstance(array, ReflexList)
+    assert isinstance(array[1], ReflexList)
+    assert isinstance(array[2], ReflexDict)
+
+    assert isinstance(hashmap, ReflexDict)
+    assert isinstance(hashmap["mod_key"], ReflexList)
+    assert isinstance(hashmap["mod_third_key"], ReflexDict)