Parcourir la source

Fix substate event handlers (#171)

Nikhil Rao il y a 2 ans
Parent
commit
43bd4784dc
2 fichiers modifiés avec 213 ajouts et 248 suppressions
  1. 17 9
      pynecone/utils.py
  2. 196 239
      tests/test_state.py

+ 17 - 9
pynecone/utils.py

@@ -816,20 +816,28 @@ def format_cond(
     return expr
 
 
-def format_event_fn(fn: Callable) -> str:
-    """Format a function as an event.
+def format_event_handler(handler: EventHandler) -> str:
+    """Format an event handler.
 
     Args:
-        fn: The function to format.
+        handler: The event handler to format.
 
     Returns:
         The formatted function.
     """
-    from pynecone.event import EventHandler
+    # Get the class that defines the event handler.
+    parts = handler.fn.__qualname__.split(".")
 
-    if isinstance(fn, EventHandler):
-        fn = fn.fn
-    return to_snake_case(fn.__qualname__)
+    # If there's no enclosing class, just return the function name.
+    if len(parts) == 1:
+        return parts[-1]
+
+    # Get the state and the function name.
+    state_name, name = parts[-2:]
+
+    # Construct the full event handler name.
+    state = vars(sys.modules[handler.fn.__module__])[state_name]
+    return ".".join([state.get_full_name(), name])
 
 
 def format_event(event_spec: EventSpec) -> str:
@@ -842,7 +850,7 @@ def format_event(event_spec: EventSpec) -> str:
         The compiled event.
     """
     args = ",".join([":".join((name, val)) for name, val in event_spec.args])
-    return f"E(\"{format_event_fn(event_spec.handler.fn)}\", {wrap(args, '{')})"
+    return f"E(\"{format_event_handler(event_spec.handler)}\", {wrap(args, '{')})"
 
 
 USED_VARIABLES = set()
@@ -1058,7 +1066,7 @@ def fix_events(events: Optional[List[Event]], token: str) -> List[Event]:
             if isinstance(e, EventHandler):
                 e = e()
             assert isinstance(e, EventSpec), f"Unexpected event type, {type(e)}."
-            name = format_event_fn(e.handler.fn)
+            name = format_event_handler(e.handler)
             payload = dict(e.args)
 
         # Create an event and append it to the list.

+ 196 - 239
tests/test_state.py

@@ -2,123 +2,150 @@ from typing import Dict, List
 
 import pytest
 
+from pynecone import utils
 from pynecone.base import Base
 from pynecone.event import Event
 from pynecone.state import State
 from pynecone.var import BaseVar, ComputedVar
 
 
-@pytest.fixture
-def TestObject():
-    class TestObject(Base):
-        """A test object fixture."""
+class Object(Base):
+    """A test object fixture."""
 
-        prop1: int = 42
-        prop2: str = "hello"
+    prop1: int = 42
+    prop2: str = "hello"
 
-    return TestObject
 
+class TestState(State):
+    """A test state."""
 
-@pytest.fixture
-def TestState(TestObject):
-    class TestState(State):
-        """A test state."""
+    num1: int
+    num2: float = 3.14
+    key: str
+    array: List[float] = [1, 2, 3.14]
+    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    obj: Object = Object()
+    complex: Dict[int, Object] = {1: Object(), 2: Object()}
 
-        num1: int
-        num2: float = 3.14
-        key: str
-        array: List[float] = [1, 2, 3.14]
-        mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
-        obj: TestObject = TestObject()
-        complex: Dict[int, TestObject] = {1: TestObject(), 2: TestObject()}
+    @ComputedVar
+    def sum(self) -> float:
+        """Dynamically sum the numbers.
 
-        @ComputedVar
-        def sum(self) -> float:
-            """Dynamically sum the numbers.
+        Returns:
+            The sum of the numbers.
+        """
+        return self.num1 + self.num2
 
-            Returns:
-                The sum of the numbers.
-            """
-            return self.num1 + self.num2
+    @ComputedVar
+    def upper(self) -> str:
+        """Uppercase the key.
 
-        @ComputedVar
-        def upper(self) -> str:
-            """Uppercase the key.
+        Returns:
+            The uppercased key.
+        """
+        return self.key.upper()
 
-            Returns:
-                The uppercased key.
-            """
-            return self.key.upper()
+    def do_something(self):
+        """Do something."""
+        pass
 
-        def do_something(self):
-            """Do something."""
-            pass
 
-    return TestState
+class ChildState(TestState):
+    """A child state fixture."""
 
+    value: str
+    count: int = 23
 
-@pytest.fixture
-def ChildState(TestState):
-    class ChildState(TestState):
-        """A child state fixture."""
+    def change_both(self, value: str, count: int):
+        """Change both the value and count.
+
+        Args:
+            value: The new value.
+            count: The new count.
+        """
+        self.value = value.upper()
+        self.count = count * 2
+
+
+class ChildState2(TestState):
+    """A child state fixture."""
 
-        value: str
-        count: int = 23
+    value: str
 
-        def change_both(self, value: str, count: int):
-            """Change both the value and count.
 
-            Args:
-                value: The new value.
-                count: The new count.
-            """
-            self.value = value.upper()
-            self.count = count * 2
+class GrandchildState(ChildState):
+    """A grandchild state fixture."""
 
-    return ChildState
+    value2: str
+
+    def do_nothing(self):
+        """Do something."""
+        pass
 
 
 @pytest.fixture
-def ChildState2(TestState):
-    class ChildState2(TestState):
-        """A child state fixture."""
+def test_state() -> TestState:
+    """A state.
 
-        value: str
+    Returns:
+        A test state.
+    """
+    return TestState()  # type: ignore
+
+
+@pytest.fixture
+def child_state(test_state) -> ChildState:
+    """A child state.
+
+    Args:
+        test_state: A test state.
 
-    return ChildState2
+    Returns:
+        A test child state.
+    """
+    child_state = test_state.get_substate(["child_state"])
+    assert child_state is not None
+    return child_state
 
 
 @pytest.fixture
-def GrandchildState(ChildState):
-    class GrandchildState(ChildState):
-        """A grandchild state fixture."""
+def child_state2(test_state) -> ChildState2:
+    """A second child state.
 
-        value2: str
+    Args:
+        test_state: A test state.
 
-    return GrandchildState
+    Returns:
+        A second test child state.
+    """
+    child_state2 = test_state.get_substate(["child_state2"])
+    assert child_state2 is not None
+    return child_state2
 
 
 @pytest.fixture
-def state(TestState) -> State:
+def grandchild_state(child_state) -> GrandchildState:
     """A state.
 
     Args:
-        TestState: The state class.
+        child_state: A child state.
 
     Returns:
         A test state.
     """
-    return TestState()  # type: ignore
+    grandchild_state = child_state.get_substate(["grandchild_state"])
+    assert grandchild_state is not None
+    return grandchild_state
 
 
-def test_base_class_vars(state):
+def test_base_class_vars(test_state):
     """Test that the class vars are set correctly.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    fields = state.get_fields()
-    cls = type(state)
+    fields = test_state.get_fields()
+    cls = type(test_state)
 
     for field in fields:
         if field in (
@@ -137,25 +164,25 @@ def test_base_class_vars(state):
     assert cls.key.type_ == str
 
 
-def test_computed_class_var(state):
+def test_computed_class_var(test_state):
     """Test that the class computed vars are set correctly.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    cls = type(state)
+    cls = type(test_state)
     vars = [(prop.name, prop.type_) for prop in cls.computed_vars.values()]
     assert ("sum", float) in vars
     assert ("upper", str) in vars
 
 
-def test_class_vars(state):
+def test_class_vars(test_state):
     """Test that the class vars are set correctly.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    cls = type(state)
+    cls = type(test_state)
     assert set(cls.vars.keys()) == {
         "num1",
         "num2",
@@ -169,60 +196,59 @@ def test_class_vars(state):
     }
 
 
-def test_default_value(state):
+def test_default_value(test_state):
     """Test that the default value of a var is correct.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    assert state.num1 == 0
-    assert state.num2 == 3.14
-    assert state.key == ""
-    assert state.sum == 3.14
-    assert state.upper == ""
+    assert test_state.num1 == 0
+    assert test_state.num2 == 3.14
+    assert test_state.key == ""
+    assert test_state.sum == 3.14
+    assert test_state.upper == ""
 
 
-def test_computed_vars(state):
+def test_computed_vars(test_state):
     """Test that the computed var is computed correctly.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    state.num1 = 1
-    state.num2 = 4
-    assert state.sum == 5
-    state.key = "hello world"
-    assert state.upper == "HELLO WORLD"
+    test_state.num1 = 1
+    test_state.num2 = 4
+    assert test_state.sum == 5
+    test_state.key = "hello world"
+    assert test_state.upper == "HELLO WORLD"
 
 
-def test_dict(state):
+def test_dict(test_state):
     """Test that the dict representation of a state is correct.
 
     Args:
-        state: A state.
+        test_state: A state.
     """
-    assert set(state.dict().keys()) == set(state.vars.keys())
-    assert set(state.dict(include_computed=False).keys()) == set(state.base_vars)
+    substates = {"child_state", "child_state2"}
+    assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
+    assert (
+        set(test_state.dict(include_computed=False).keys())
+        == set(test_state.base_vars) | substates
+    )
 
 
-def test_default_setters(TestState):
+def test_default_setters(test_state):
     """Test that we can set default values.
 
     Args:
-        TestState: The state class.
+        test_state: A state.
     """
-    state = TestState()
-    for prop_name in state.base_vars:
+    for prop_name in test_state.base_vars:
         # Each base var should have a default setter.
-        assert hasattr(state, f"set_{prop_name}")
-
+        assert hasattr(test_state, f"set_{prop_name}")
 
-def test_class_indexing_with_vars(TestState):
-    """Test that we can index into a state var with another var.
 
-    Args:
-        TestState: The state class.
-    """
+def test_class_indexing_with_vars():
+    """Test that we can index into a state var with another var."""
     prop = TestState.array[TestState.num1]
     assert str(prop) == "{test_state.array[test_state.num1]}"
 
@@ -230,12 +256,8 @@ def test_class_indexing_with_vars(TestState):
     assert str(prop) == '{test_state.mapping["a"][test_state.num1]}'
 
 
-def test_class_attributes(TestState):
-    """Test that we can get class attributes.
-
-    Args:
-        TestState: The state class.
-    """
+def test_class_attributes():
+    """Test that we can get class attributes."""
     prop = TestState.obj.prop1
     assert str(prop) == "{test_state.obj.prop1}"
 
@@ -243,75 +265,40 @@ def test_class_attributes(TestState):
     assert str(prop) == "{test_state.complex[1].prop1}"
 
 
-def test_get_parent_state(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the parent state.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_parent_state():
+    """Test getting the parent state."""
     assert TestState.get_parent_state() is None
     assert ChildState.get_parent_state() == TestState
     assert ChildState2.get_parent_state() == TestState
     assert GrandchildState.get_parent_state() == ChildState
 
 
-def test_get_substates(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the substates.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_substates():
+    """Test getting the substates."""
     assert TestState.get_substates() == {ChildState, ChildState2}
     assert ChildState.get_substates() == {GrandchildState}
     assert ChildState2.get_substates() == set()
     assert GrandchildState.get_substates() == set()
 
 
-def test_get_name(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the name of a state.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_name():
+    """Test getting the name of a state."""
     assert TestState.get_name() == "test_state"
     assert ChildState.get_name() == "child_state"
     assert ChildState2.get_name() == "child_state2"
     assert GrandchildState.get_name() == "grandchild_state"
 
 
-def test_get_full_name(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the full name.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_full_name():
+    """Test getting the full name."""
     assert TestState.get_full_name() == "test_state"
     assert ChildState.get_full_name() == "test_state.child_state"
     assert ChildState2.get_full_name() == "test_state.child_state2"
     assert GrandchildState.get_full_name() == "test_state.child_state.grandchild_state"
 
 
-def test_get_class_substate(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the substate of a class.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_class_substate():
+    """Test getting the substate of a class."""
     assert TestState.get_class_substate(("child_state",)) == ChildState
     assert TestState.get_class_substate(("child_state2",)) == ChildState2
     assert ChildState.get_class_substate(("grandchild_state",)) == GrandchildState
@@ -330,15 +317,8 @@ def test_get_class_substate(TestState, ChildState, ChildState2, GrandchildState)
         )
 
 
-def test_get_class_var(TestState, ChildState, ChildState2, GrandchildState):
-    """Test getting the var of a class.
-
-    Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
-    """
+def test_get_class_var():
+    """Test getting the var of a class."""
     assert TestState.get_class_var(("num1",)) == TestState.num1
     assert TestState.get_class_var(("num2",)) == TestState.num2
     assert ChildState.get_class_var(("value",)) == ChildState.value
@@ -363,58 +343,45 @@ def test_get_class_var(TestState, ChildState, ChildState2, GrandchildState):
         )
 
 
-def test_set_class_var(TestState):
-    """Test setting the var of a class.
-
-    Args:
-        TestState: The state class.
-    """
+def test_set_class_var():
+    """Test setting the var of a class."""
     with pytest.raises(AttributeError):
-        TestState.num3
+        TestState.num3  # type: ignore
     TestState._set_var(BaseVar(name="num3", type_=int).set_state(TestState))
-    var = TestState.num3
+    var = TestState.num3  # type: ignore
     assert var.name == "num3"
     assert var.type_ == int
     assert var.state == TestState.get_full_name()
 
 
-def test_set_parent_and_substates(TestState, ChildState, ChildState2, GrandchildState):
+def test_set_parent_and_substates(test_state, child_state, grandchild_state):
     """Test setting the parent and substates.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
     assert len(test_state.substates) == 2
     assert set(test_state.substates) == {"child_state", "child_state2"}
 
-    child_state = test_state.substates["child_state"]
     assert child_state.parent_state == test_state
     assert len(child_state.substates) == 1
     assert set(child_state.substates) == {"grandchild_state"}
 
-    grandchild_state = child_state.substates["grandchild_state"]
     assert grandchild_state.parent_state == child_state
     assert len(grandchild_state.substates) == 0
 
 
-def test_get_child_attribute(TestState, ChildState, ChildState2, GrandchildState):
+def test_get_child_attribute(test_state, child_state, child_state2, grandchild_state):
     """Test getting the attribute of a state.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        child_state2: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
-    child_state = test_state.get_substate(["child_state"])
-    child_state2 = test_state.get_substate(["child_state2"])
-    grandchild_state = child_state.get_substate(["grandchild_state"])
-
     assert test_state.num1 == 0
     assert child_state.value == ""
     assert child_state2.value == ""
@@ -428,19 +395,14 @@ def test_get_child_attribute(TestState, ChildState, ChildState2, GrandchildState
         test_state.child_state.grandchild_state.invalid
 
 
-def test_set_child_attribute(TestState, ChildState, ChildState2, GrandchildState):
+def test_set_child_attribute(test_state, child_state, grandchild_state):
     """Test setting the attribute of a state.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
-    child_state = test_state.get_substate(["child_state"])
-    grandchild_state = child_state.get_substate(["grandchild_state"])
-
     test_state.num1 = 10
     assert test_state.num1 == 10
     child_state.value = "test"
@@ -449,20 +411,15 @@ def test_set_child_attribute(TestState, ChildState, ChildState2, GrandchildState
     assert grandchild_state.value2 == "test2"
 
 
-def test_get_substate(TestState, ChildState, ChildState2, GrandchildState):
+def test_get_substate(test_state, child_state, child_state2, grandchild_state):
     """Test getting the substate of a state.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        child_state2: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
-    child_state = test_state.substates["child_state"]
-    child_state2 = test_state.substates["child_state2"]
-    grandchild_state = child_state.substates["grandchild_state"]
-
     assert test_state.get_substate(("child_state",)) == child_state
     assert test_state.get_substate(("child_state2",)) == child_state2
     assert (
@@ -477,14 +434,12 @@ def test_get_substate(TestState, ChildState, ChildState2, GrandchildState):
         test_state.get_substate(("child_state", "grandchild_state", "invalid"))
 
 
-def test_set_dirty_var(TestState):
+def test_set_dirty_var(test_state):
     """Test changing state vars marks the value as dirty.
 
     Args:
-        TestState: The state class.
+        test_state: A state.
     """
-    test_state = TestState()
-
     # Initially there should be no dirty vars.
     assert test_state.dirty_vars == set()
 
@@ -501,20 +456,15 @@ def test_set_dirty_var(TestState):
     assert test_state.dirty_vars == set()
 
 
-def test_set_dirty_substate(TestState, ChildState, ChildState2, GrandchildState):
+def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state):
     """Test changing substate vars marks the value as dirty.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        ChildState2: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        child_state2: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
-    child_state = test_state.get_substate(["child_state"])
-    child_state2 = test_state.get_substate(["child_state2"])
-    grandchild_state = child_state.get_substate(["grandchild_state"])
-
     # Initially there should be no dirty vars.
     assert test_state.dirty_vars == set()
     assert child_state.dirty_vars == set()
@@ -544,16 +494,13 @@ def test_set_dirty_substate(TestState, ChildState, ChildState2, GrandchildState)
     assert grandchild_state.dirty_vars == set()
 
 
-def test_reset(TestState, ChildState):
+def test_reset(test_state, child_state):
     """Test resetting the state.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
+        test_state: A state.
+        child_state: A child state.
     """
-    test_state = TestState()
-    child_state = test_state.get_substate(["child_state"])
-
     # Set some values.
     test_state.num1 = 1
     test_state.num2 = 2
@@ -576,13 +523,12 @@ def test_reset(TestState, ChildState):
 
 
 @pytest.mark.asyncio
-async def test_process_event_simple(TestState):
+async def test_process_event_simple(test_state):
     """Test processing an event.
 
     Args:
-        TestState: The state class.
+        test_state: A state.
     """
-    test_state = TestState()
     assert test_state.num1 == 0
 
     event = Event(token="t", name="set_num1", payload={"value": 69})
@@ -597,18 +543,14 @@ async def test_process_event_simple(TestState):
 
 
 @pytest.mark.asyncio
-async def test_process_event_substate(TestState, ChildState, GrandchildState):
+async def test_process_event_substate(test_state, child_state, grandchild_state):
     """Test processing an event on a substate.
 
     Args:
-        TestState: The state class.
-        ChildState: The child state class.
-        GrandchildState: The grandchild state class.
+        test_state: A state.
+        child_state: A child state.
+        grandchild_state: A grandchild state.
     """
-    test_state = TestState()
-    child_state = test_state.get_substate(["child_state"])
-    grandchild_state = child_state.get_substate(["grandchild_state"])
-
     # Events should bubble down to the substate.
     assert child_state.value == ""
     assert child_state.count == 23
@@ -637,3 +579,18 @@ async def test_process_event_substate(TestState, ChildState, GrandchildState):
         "test_state.child_state.grandchild_state": {"value2": "new"},
         "test_state": {"sum": 3.14, "upper": ""},
     }
+
+
+def test_format_event_handler():
+    """Test formatting an event handler."""
+    assert (
+        utils.format_event_handler(TestState.do_something) == "test_state.do_something"  # type: ignore
+    )
+    assert (
+        utils.format_event_handler(ChildState.change_both)  # type: ignore
+        == "test_state.child_state.change_both"
+    )
+    assert (
+        utils.format_event_handler(GrandchildState.do_nothing)  # type: ignore
+        == "test_state.child_state.grandchild_state.do_nothing"
+    )