浏览代码

Add `pc.list` for mutation detection (#339)

Tommy Dew 2 年之前
父节点
当前提交
d5a76f103a
共有 3 个文件被更改,包括 346 次插入2 次删除
  1. 70 1
      pynecone/state.py
  2. 93 0
      pynecone/var.py
  3. 183 1
      tests/test_app.py

+ 70 - 1
pynecone/state.py

@@ -13,7 +13,7 @@ from redis import Redis
 from pynecone import constants, utils
 from pynecone.base import Base
 from pynecone.event import Event, EventHandler, window_alert
-from pynecone.var import BaseVar, ComputedVar, Var
+from pynecone.var import BaseVar, ComputedVar, PCList, Var
 
 Delta = Dict[str, Any]
 
@@ -61,6 +61,40 @@ class State(Base, ABC):
         for substate in self.get_substates():
             self.substates[substate.get_name()] = substate().set(parent_state=self)
 
+        self._init_mutable_fields()
+
+    def _init_mutable_fields(self):
+        """Initialize mutable fields.
+
+        So that mutation to them can be detected by the app:
+        * list
+        """
+        for field in self.base_vars.values():
+            value = getattr(self, field.name)
+
+            value_in_pc_data = _convert_mutable_datatypes(
+                value, self._reassign_field, field.name
+            )
+
+            if utils._issubclass(field.type_, List):
+                setattr(self, field.name, value_in_pc_data)
+
+        self.clean()
+
+    def _reassign_field(self, field_name: str):
+        """Reassign the given field.
+
+        Primarily for mutation in fields of mutable data types.
+
+        Args:
+            field_name (str): The name of the field we want to reassign
+        """
+        setattr(
+            self,
+            field_name,
+            getattr(self, field_name),
+        )
+
     def __repr__(self) -> str:
         """Get the string representation of the state.
 
@@ -578,3 +612,38 @@ class StateManager(Base):
         if self.redis is None:
             return
         self.redis.set(token, pickle.dumps(state), ex=self.token_expiration)
+
+
+def _convert_mutable_datatypes(
+    field_value: Any, reassign_field: Callable, field_name: str
+) -> Any:
+    """Recursively convert mutable data to the Pc data types.
+
+    Note: right now only list & dict would be handled recursively.
+
+    Args:
+        field_value: The target field_value.
+        reassign_field:
+            The function to reassign the field in the parent state.
+        field_name: the name of the field in the parent state
+
+    Returns:
+        The converted field_value
+    """
+    if isinstance(field_value, list):
+        for index in range(len(field_value)):
+            field_value[index] = _convert_mutable_datatypes(
+                field_value[index], reassign_field, field_name
+            )
+
+        field_value = PCList(
+            field_value, reassign_field=reassign_field, field_name=field_name
+        )
+
+    elif isinstance(field_value, dict):
+        for key, value in field_value.items():
+            field_value[key] = _convert_mutable_datatypes(
+                value, reassign_field, field_name
+            )
+
+    return field_value

+ 93 - 0
pynecone/var.py

@@ -748,3 +748,96 @@ class ComputedVar(property, Var):
         if "return" in self.fget.__annotations__:
             return self.fget.__annotations__["return"]
         return Any
+
+
+class PCList(list):
+    """A custom list that pynecone can detect its mutation."""
+
+    def __init__(
+        self,
+        original_list: List,
+        reassign_field: Callable = lambda _field_name: None,
+        field_name: str = "",
+    ):
+        """Initialize PCList.
+
+        Args:
+            original_list (List): The original list
+            reassign_field (Callable):
+                The method in the parent state to reassign the field.
+                Default to be a no-op function
+            field_name (str): the name of field in the parent state
+        """
+        self._reassign_field = lambda: reassign_field(field_name)
+
+        super().__init__(original_list)
+
+    def append(self, *args, **kargs):
+        """Append.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().append(*args, **kargs)
+        self._reassign_field()
+
+    def __setitem__(self, *args, **kargs):
+        """Set item.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().__setitem__(*args, **kargs)
+        self._reassign_field()
+
+    def __delitem__(self, *args, **kargs):
+        """Delete item.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().__delitem__(*args, **kargs)
+        self._reassign_field()
+
+    def clear(self, *args, **kargs):
+        """Remove all item from the list.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().clear(*args, **kargs)
+        self._reassign_field()
+
+    def extend(self, *args, **kargs):
+        """Add all item of a list to the end of the list.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().extend(*args, **kargs)
+        self._reassign_field()
+
+    def pop(self, *args, **kargs):
+        """Remove an element.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().pop(*args, **kargs)
+        self._reassign_field()
+
+    def remove(self, *args, **kargs):
+        """Remove an element.
+
+        Args:
+            args: The args passed.
+            kargs: The kwargs passed.
+        """
+        super().remove(*args, **kargs)
+        self._reassign_field()

+ 183 - 1
tests/test_app.py

@@ -1,9 +1,10 @@
-from typing import Type
+from typing import List, Tuple, Type
 
 import pytest
 
 from pynecone.app import App, DefaultState
 from pynecone.components import Box
+from pynecone.event import Event
 from pynecone.middleware import HydrateMiddleware
 from pynecone.state import State
 from pynecone.style import Style
@@ -156,3 +157,184 @@ def test_set_and_get_state(TestState: Type[State]):
     state2 = app.state_manager.get_state(token2)
     assert state1.var == 1  # type: ignore
     assert state2.var == 2  # type: ignore
+
+
+@pytest.fixture
+def list_mutation_state():
+    """A fixture to create a state with list mutation features.
+
+    Returns:
+        A state with list mutation features.
+    """
+
+    class TestState(State):
+        """The test state."""
+
+        # plain list
+        plain_friends = ["Tommy"]
+
+        def make_friend(self):
+            self.plain_friends.append("another-fd")
+
+        def change_first_friend(self):
+            self.plain_friends[0] = "Jenny"
+
+        def unfriend_all_friends(self):
+            self.plain_friends.clear()
+
+        def unfriend_first_friend(self):
+            del self.plain_friends[0]
+
+        def remove_last_friend(self):
+            self.plain_friends.pop()
+
+        def make_friends_with_colleagues(self):
+            colleagues = ["Peter", "Jimmy"]
+            self.plain_friends.extend(colleagues)
+
+        def remove_tommy(self):
+            self.plain_friends.remove("Tommy")
+
+        # list in dict
+        friends_in_dict = {"Tommy": ["Jenny"]}
+
+        def remove_jenny_from_tommy(self):
+            self.friends_in_dict["Tommy"].remove("Jenny")
+
+        def add_jimmy_to_tommy_friends(self):
+            self.friends_in_dict["Tommy"].append("Jimmy")
+
+        def tommy_has_no_fds(self):
+            self.friends_in_dict["Tommy"].clear()
+
+        # nested list
+        friends_in_nested_list = [["Tommy"], ["Jenny"]]
+
+        def remove_first_group(self):
+            self.friends_in_nested_list.pop(0)
+
+        def remove_first_person_from_first_group(self):
+            self.friends_in_nested_list[0].pop(0)
+
+        def add_jimmy_to_second_group(self):
+            self.friends_in_nested_list[1].append("Jimmy")
+
+    return TestState()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "event_tuples",
+    [
+        pytest.param(
+            [
+                (
+                    "test_state.make_friend",
+                    {"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
+                ),
+                (
+                    "test_state.change_first_friend",
+                    {"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
+                ),
+            ],
+            id="append then __setitem__",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.unfriend_first_friend",
+                    {"test_state": {"plain_friends": []}},
+                ),
+                (
+                    "test_state.make_friend",
+                    {"test_state": {"plain_friends": ["another-fd"]}},
+                ),
+            ],
+            id="delitem then append",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.make_friends_with_colleagues",
+                    {"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
+                ),
+                (
+                    "test_state.remove_tommy",
+                    {"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
+                ),
+                (
+                    "test_state.remove_last_friend",
+                    {"test_state": {"plain_friends": ["Peter"]}},
+                ),
+                (
+                    "test_state.unfriend_all_friends",
+                    {"test_state": {"plain_friends": []}},
+                ),
+            ],
+            id="extend, remove, pop, clear",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.add_jimmy_to_second_group",
+                    {
+                        "test_state": {
+                            "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
+                        }
+                    },
+                ),
+                (
+                    "test_state.remove_first_person_from_first_group",
+                    {
+                        "test_state": {
+                            "friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
+                        }
+                    },
+                ),
+                (
+                    "test_state.remove_first_group",
+                    {"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
+                ),
+            ],
+            id="nested list",
+        ),
+        pytest.param(
+            [
+                (
+                    "test_state.add_jimmy_to_tommy_friends",
+                    {"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
+                ),
+                (
+                    "test_state.remove_jenny_from_tommy",
+                    {"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
+                ),
+                (
+                    "test_state.tommy_has_no_fds",
+                    {"test_state": {"friends_in_dict": {"Tommy": []}}},
+                ),
+            ],
+            id="list in dict",
+        ),
+    ],
+)
+async def test_list_mutation_detection__plain_list(
+    event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
+):
+    """Test list mutation detection
+    when reassignment is not explicitly included in the logic.
+
+    Args:
+        event_tuples: From parametrization.
+        list_mutation_state: A state with list mutation features.
+    """
+    for event_name, expected_delta in event_tuples:
+        result = await list_mutation_state.process(
+            Event(
+                token="fake-token",
+                name=event_name,
+                router_data={"pathname": "/", "query": {}},
+                payload={},
+            )
+        )
+
+        assert result.delta == expected_delta