浏览代码

Create PCDict (#503)

Elijah Ahianyo 2 年之前
父节点
当前提交
8de2163d84
共有 5 个文件被更改,包括 380 次插入89 次删除
  1. 17 4
      pynecone/state.py
  2. 104 21
      pynecone/var.py
  3. 121 0
      tests/conftest.py
  4. 127 63
      tests/test_app.py
  5. 11 1
      tests/test_var.py

+ 17 - 4
pynecone/state.py

@@ -5,7 +5,18 @@ import asyncio
 import functools
 import traceback
 from abc import ABC
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Set, Type
+from typing import (
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Type,
+    Union,
+)
 
 import cloudpickle
 from redis import Redis
@@ -13,7 +24,7 @@ from redis import Redis
 from pynecone import constants, utils
 from pynecone.base import Base
 from pynecone.event import Event, EventHandler, window_alert
-from pynecone.var import BaseVar, ComputedVar, PCList, Var
+from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
 
 Delta = Dict[str, Any]
 
@@ -79,7 +90,7 @@ class State(Base, ABC):
                 value, self._reassign_field, field.name
             )
 
-            if utils._issubclass(field.type_, List):
+            if utils._issubclass(field.type_, Union[List, Dict]):
                 setattr(self, field.name, value_in_pc_data)
 
         self.clean()
@@ -727,5 +738,7 @@ def _convert_mutable_datatypes(
             field_value[key] = _convert_mutable_datatypes(
                 value, reassign_field, field_name
             )
-
+        field_value = PCDict(
+            field_value, reassign_field=reassign_field, field_name=field_name
+        )
     return field_value

+ 104 - 21
pynecone/var.py

@@ -773,72 +773,155 @@ class PCList(list):
 
         super().__init__(original_list)
 
-    def append(self, *args, **kargs):
+    def append(self, *args, **kwargs):
         """Append.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().append(*args, **kargs)
+        super().append(*args, **kwargs)
         self._reassign_field()
 
-    def __setitem__(self, *args, **kargs):
+    def __setitem__(self, *args, **kwargs):
         """Set item.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().__setitem__(*args, **kargs)
+        super().__setitem__(*args, **kwargs)
         self._reassign_field()
 
-    def __delitem__(self, *args, **kargs):
+    def __delitem__(self, *args, **kwargs):
         """Delete item.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().__delitem__(*args, **kargs)
+        super().__delitem__(*args, **kwargs)
         self._reassign_field()
 
-    def clear(self, *args, **kargs):
+    def clear(self, *args, **kwargs):
         """Remove all item from the list.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().clear(*args, **kargs)
+        super().clear(*args, **kwargs)
         self._reassign_field()
 
-    def extend(self, *args, **kargs):
+    def extend(self, *args, **kwargs):
         """Add all item of a list to the end of the list.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().extend(*args, **kargs)
+        super().extend(*args, **kwargs)
         self._reassign_field() if hasattr(self, "_reassign_field") else None
 
-    def pop(self, *args, **kargs):
+    def pop(self, *args, **kwargs):
         """Remove an element.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().pop(*args, **kargs)
+        super().pop(*args, **kwargs)
         self._reassign_field()
 
-    def remove(self, *args, **kargs):
+    def remove(self, *args, **kwargs):
         """Remove an element.
 
         Args:
             args: The args passed.
-            kargs: The kwargs passed.
+            kwargs: The kwargs passed.
         """
-        super().remove(*args, **kargs)
+        super().remove(*args, **kwargs)
+        self._reassign_field()
+
+
+class PCDict(dict):
+    """A custom dict that pynecone can detect its mutation."""
+
+    def __init__(
+        self,
+        original_dict: Dict,
+        reassign_field: Callable = lambda _field_name: None,
+        field_name: str = "",
+    ):
+        """Initialize PCDict.
+
+        Args:
+            original_dict: The original dict
+            reassign_field:
+                The method in the parent state to reassign the field.
+                Default to be a no-op function
+            field_name: the name of field in the parent state
+        """
+        super().__init__(original_dict)
+        self._reassign_field = lambda: reassign_field(field_name)
+
+    def clear(self):
+        """Remove all item from the list."""
+        super().clear()
+
+        self._reassign_field()
+
+    def setdefault(self, *args, **kwargs):
+        """set default.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().setdefault(*args, **kwargs)
+        self._reassign_field()
+
+    def popitem(self):
+        """Pop last item."""
+        super().popitem()
+        self._reassign_field()
+
+    def pop(self, k, d=None):
+        """Remove an element.
+
+        Args:
+            k: The args passed.
+            d: The kwargs passed.
+        """
+        super().pop(k, d)
+        self._reassign_field()
+
+    def update(self, *args, **kwargs):
+        """update dict.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().update(*args, **kwargs)
+        self._reassign_field()
+
+    def __setitem__(self, *args, **kwargs):
+        """set item.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().__setitem__(*args, **kwargs)
+        self._reassign_field() if hasattr(self, "_reassign_field") else None
+
+    def __delitem__(self, *args, **kwargs):
+        """delete item.
+
+        Args:
+            args: The args passed.
+            kwargs: The kwargs passed.
+        """
+        super().__delitem__(*args, **kwargs)
         self._reassign_field()

+ 121 - 0
tests/conftest.py

@@ -4,6 +4,8 @@ from typing import Generator
 
 import pytest
 
+from pynecone.state import State
+
 
 @pytest.fixture(scope="function")
 def windows_platform() -> Generator:
@@ -13,3 +15,122 @@ def windows_platform() -> Generator:
         whether system is windows.
     """
     yield platform.system() == "Windows"
+
+
+@pytest.fixture
+def list_mutation_state():
+    """A fixture to create a state with list mutation features.
+
+    Returns:
+        A state with list mutation features.
+    """
+
+    class TestState(State):
+        """The test state."""
+
+        # plain list
+        plain_friends = ["Tommy"]
+
+        def make_friend(self):
+            self.plain_friends.append("another-fd")
+
+        def change_first_friend(self):
+            self.plain_friends[0] = "Jenny"
+
+        def unfriend_all_friends(self):
+            self.plain_friends.clear()
+
+        def unfriend_first_friend(self):
+            del self.plain_friends[0]
+
+        def remove_last_friend(self):
+            self.plain_friends.pop()
+
+        def make_friends_with_colleagues(self):
+            colleagues = ["Peter", "Jimmy"]
+            self.plain_friends.extend(colleagues)
+
+        def remove_tommy(self):
+            self.plain_friends.remove("Tommy")
+
+        # list in dict
+        friends_in_dict = {"Tommy": ["Jenny"]}
+
+        def remove_jenny_from_tommy(self):
+            self.friends_in_dict["Tommy"].remove("Jenny")
+
+        def add_jimmy_to_tommy_friends(self):
+            self.friends_in_dict["Tommy"].append("Jimmy")
+
+        def tommy_has_no_fds(self):
+            self.friends_in_dict["Tommy"].clear()
+
+        # nested list
+        friends_in_nested_list = [["Tommy"], ["Jenny"]]
+
+        def remove_first_group(self):
+            self.friends_in_nested_list.pop(0)
+
+        def remove_first_person_from_first_group(self):
+            self.friends_in_nested_list[0].pop(0)
+
+        def add_jimmy_to_second_group(self):
+            self.friends_in_nested_list[1].append("Jimmy")
+
+    return TestState()
+
+
+@pytest.fixture
+def dict_mutation_state():
+    """A fixture to create a state with dict mutation features.
+
+    Returns:
+        A state with dict mutation features.
+    """
+
+    class TestState(State):
+        """The test state."""
+
+        # plain dict
+        details = {"name": "Tommy"}
+
+        def add_age(self):
+            self.details.update({"age": 20})  # type: ignore
+
+        def change_name(self):
+            self.details["name"] = "Jenny"
+
+        def remove_last_detail(self):
+            self.details.popitem()
+
+        def clear_details(self):
+            self.details.clear()
+
+        def remove_name(self):
+            del self.details["name"]
+
+        def pop_out_age(self):
+            self.details.pop("age")
+
+        # dict in list
+        address = [{"home": "home address"}, {"work": "work address"}]
+
+        def remove_home_address(self):
+            self.address[0].pop("home")
+
+        def add_street_to_home_address(self):
+            self.address[0]["street"] = "street address"
+
+        # nested dict
+        friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
+
+        def change_friend_name(self):
+            self.friend_in_nested_dict["friend"]["name"] = "Tommy"
+
+        def remove_friend(self):
+            self.friend_in_nested_dict.pop("friend")
+
+        def add_friend_age(self):
+            self.friend_in_nested_dict["friend"]["age"] = 30
+
+    return TestState()

+ 127 - 63
tests/test_app.py

@@ -164,69 +164,6 @@ def test_set_and_get_state(TestState: Type[State]):
     assert state2.var == 2  # type: ignore
 
 
-@pytest.fixture
-def list_mutation_state():
-    """A fixture to create a state with list mutation features.
-
-    Returns:
-        A state with list mutation features.
-    """
-
-    class TestState(State):
-        """The test state."""
-
-        # plain list
-        plain_friends = ["Tommy"]
-
-        def make_friend(self):
-            self.plain_friends.append("another-fd")
-
-        def change_first_friend(self):
-            self.plain_friends[0] = "Jenny"
-
-        def unfriend_all_friends(self):
-            self.plain_friends.clear()
-
-        def unfriend_first_friend(self):
-            del self.plain_friends[0]
-
-        def remove_last_friend(self):
-            self.plain_friends.pop()
-
-        def make_friends_with_colleagues(self):
-            colleagues = ["Peter", "Jimmy"]
-            self.plain_friends.extend(colleagues)
-
-        def remove_tommy(self):
-            self.plain_friends.remove("Tommy")
-
-        # list in dict
-        friends_in_dict = {"Tommy": ["Jenny"]}
-
-        def remove_jenny_from_tommy(self):
-            self.friends_in_dict["Tommy"].remove("Jenny")
-
-        def add_jimmy_to_tommy_friends(self):
-            self.friends_in_dict["Tommy"].append("Jimmy")
-
-        def tommy_has_no_fds(self):
-            self.friends_in_dict["Tommy"].clear()
-
-        # nested list
-        friends_in_nested_list = [["Tommy"], ["Jenny"]]
-
-        def remove_first_group(self):
-            self.friends_in_nested_list.pop(0)
-
-        def remove_first_person_from_first_group(self):
-            self.friends_in_nested_list[0].pop(0)
-
-        def add_jimmy_to_second_group(self):
-            self.friends_in_nested_list[1].append("Jimmy")
-
-    return TestState()
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "event_tuples",
@@ -343,3 +280,130 @@ async def test_list_mutation_detection__plain_list(
         )
 
         assert result.delta == expected_delta
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "event_tuples",
+    [
+        pytest.param(
+            [
+                (
+                    "test_state.add_age",
+                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                ),
+                (
+                    "test_state.change_name",
+                    {"test_state": {"details": {"name": "Jenny", "age": 20}}},
+                ),
+                (
+                    "test_state.remove_last_detail",
+                    {"test_state": {"details": {"name": "Jenny"}}},
+                ),
+            ],
+            id="update then __setitem__",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.clear_details",
+                    {"test_state": {"details": {}}},
+                ),
+                (
+                    "test_state.add_age",
+                    {"test_state": {"details": {"age": 20}}},
+                ),
+            ],
+            id="delitem then update",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.add_age",
+                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                ),
+                (
+                    "test_state.remove_name",
+                    {"test_state": {"details": {"age": 20}}},
+                ),
+                (
+                    "test_state.pop_out_age",
+                    {"test_state": {"details": {}}},
+                ),
+            ],
+            id="add, remove, pop",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.remove_home_address",
+                    {"test_state": {"address": [{}, {"work": "work address"}]}},
+                ),
+                (
+                    "test_state.add_street_to_home_address",
+                    {
+                        "test_state": {
+                            "address": [
+                                {"street": "street address"},
+                                {"work": "work address"},
+                            ]
+                        }
+                    },
+                ),
+            ],
+            id="dict in list",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.change_friend_name",
+                    {
+                        "test_state": {
+                            "friend_in_nested_dict": {
+                                "name": "Nikhil",
+                                "friend": {"name": "Tommy"},
+                            }
+                        }
+                    },
+                ),
+                (
+                    "test_state.add_friend_age",
+                    {
+                        "test_state": {
+                            "friend_in_nested_dict": {
+                                "name": "Nikhil",
+                                "friend": {"name": "Tommy", "age": 30},
+                            }
+                        }
+                    },
+                ),
+                (
+                    "test_state.remove_friend",
+                    {"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
+                ),
+            ],
+            id="nested dict",
+        ),
+    ],
+)
+async def test_dict_mutation_detection__plain_list(
+    event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
+):
+    """Test dict mutation detection
+    when reassignment is not explicitly included in the logic.
+
+    Args:
+        event_tuples: From parametrization.
+        dict_mutation_state: A state with dict mutation features.
+    """
+    for event_name, expected_delta in event_tuples:
+        result = await dict_mutation_state.process(
+            Event(
+                token="fake-token",
+                name=event_name,
+                router_data={"pathname": "/", "query": {}},
+                payload={},
+            )
+        )
+
+        assert result.delta == expected_delta

+ 11 - 1
tests/test_var.py

@@ -4,7 +4,7 @@ import cloudpickle
 import pytest
 
 from pynecone.base import Base
-from pynecone.var import BaseVar, PCList, Var
+from pynecone.var import BaseVar, PCDict, PCList, Var
 
 test_vars = [
     BaseVar(name="prop1", type_=int),
@@ -218,3 +218,13 @@ def test_pickleable_pc_list():
 
     pickled_list = cloudpickle.dumps(pc_list)
     assert cloudpickle.loads(pickled_list) == pc_list
+
+
+def test_pickleable_pc_dict():
+    """Test that PCDict is pickleable."""
+    pc_dict = PCDict(
+        original_dict={1: 2, 3: 4}, reassign_field=lambda x: x, field_name="random"
+    )
+
+    pickled_dict = cloudpickle.dumps(pc_dict)
+    assert cloudpickle.loads(pickled_dict) == pc_dict