Преглед на файлове

Fix Event chaining in the on_load event handler return not working (#773)

* Fix Event chaining in the on_load event handler return not working

* added async tests

* addressed comments
Elijah Ahianyo преди 2 години
родител
ревизия
e8387c8e26

+ 28 - 17
pynecone/app.py

@@ -148,7 +148,9 @@ class App(Base):
             allow_origins=["*"],
         )
 
-    def preprocess(self, state: State, event: Event) -> Optional[Delta]:
+    async def preprocess(
+        self, state: State, event: Event
+    ) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
         """Preprocess the event.
 
         This is where middleware can modify the event before it is processed.
@@ -165,11 +167,13 @@ class App(Base):
             An optional state to return.
         """
         for middleware in self.middleware:
-            out = middleware.preprocess(app=self, state=state, event=event)
+            out = await middleware.preprocess(app=self, state=state, event=event)
             if out is not None:
                 return out
 
-    def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]:
+    async def postprocess(
+        self, state: State, event: Event, delta: Delta
+    ) -> Optional[Delta]:
         """Postprocess the event.
 
         This is where middleware can modify the delta after it is processed.
@@ -187,7 +191,7 @@ class App(Base):
             An optional state to return.
         """
         for middleware in self.middleware:
-            out = middleware.postprocess(
+            out = await middleware.postprocess(
                 app=self, state=state, event=event, delta=delta
             )
             if out is not None:
@@ -400,7 +404,7 @@ class App(Base):
 
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
-) -> StateUpdate:
+) -> Union[StateUpdate, List[StateUpdate]]:
     """Process an event.
 
     Args:
@@ -411,7 +415,7 @@ async def process(
         client_ip: The client_ip.
 
     Returns:
-        The state update after processing the event.
+        The state update(s) after processing the event.
     """
     # Get the state for the session.
     state = app.state_manager.get_state(event.token)
@@ -430,21 +434,27 @@ async def process(
     state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
 
     # Preprocess the event.
-    pre = app.preprocess(state, event)
-    if pre is not None:
-        return StateUpdate(delta=pre)
+    pre = await app.preprocess(state, event)
+    if pre is not None and not isinstance(pre, List):
+        return pre
 
     # Apply the event to the state.
-    update = await state.process(event)
+    updates = pre if pre else await state.process(event)
     app.state_manager.set_state(event.token, state)
 
+    updates = updates if isinstance(updates, List) else [updates]
+
     # Postprocess the event.
-    post = app.postprocess(state, event, update.delta)
-    if post is not None:
-        return StateUpdate(delta=post)
+    post_list = []
+    for update in updates:
+        post = await app.postprocess(state, event, update.delta)  # type: ignore
+        post_list.append(post) if post else None
+
+    if post_list:
+        return [StateUpdate(delta=post) for post in post_list]
 
     # Return the update.
-    return update
+    return updates
 
 
 async def ping() -> str:
@@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace):
         # Get the client IP
         client_ip = environ["REMOTE_ADDR"]
 
-        # Process the event.
-        update = await process(self.app, event, sid, headers, client_ip)
+        # Process the events.
+        updates = await process(self.app, event, sid, headers, client_ip)
 
         # Emit the event.
-        await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
+        for update in updates:
+            await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)  # type: ignore
 
     async def on_ping(self, sid):
         """Event for testing the API endpoint.

+ 1 - 1
pynecone/components/typography/markdown.py

@@ -69,7 +69,7 @@ class Markdown(Component):
                     "li": "{ListItem}",
                     "p": "{Text}",
                     "a": "{Link}",
-                    "code": """{({node, inline, className, children, ...props}) => 
+                    "code": """{({node, inline, className, children, ...props}) =>
                     {
         const match = (className || '').match(/language-(?<lang>.*)/);
         return !inline ? (

+ 38 - 13
pynecone/middleware/hydrate_middleware.py

@@ -1,12 +1,12 @@
 """Middleware to hydrate the state."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from pynecone import constants
 from pynecone.event import Event, EventHandler, get_hydrate_event
 from pynecone.middleware.middleware import Middleware
-from pynecone.state import Delta, State
+from pynecone.state import State, StateUpdate
 from pynecone.utils import format
 
 if TYPE_CHECKING:
@@ -16,7 +16,9 @@ if TYPE_CHECKING:
 class HydrateMiddleware(Middleware):
     """Middleware to handle initial app hydration."""
 
-    def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
+    async def preprocess(
+        self, app: App, state: State, event: Event
+    ) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
         """Preprocess the event.
 
         Args:
@@ -25,7 +27,7 @@ class HydrateMiddleware(Middleware):
             event: The event to preprocess.
 
         Returns:
-            An optional state to return.
+            An optional delta or list of state updates to return.
         """
         if event.name == get_hydrate_event(state):
             route = event.router_data.get(constants.RouteVar.PATH, "")
@@ -37,20 +39,43 @@ class HydrateMiddleware(Middleware):
                 load_event = None
 
             if load_event:
-                if isinstance(load_event, list):
-                    for single_event in load_event:
-                        self.execute_load_event(state, single_event)
-                else:
-                    self.execute_load_event(state, load_event)
-            return format.format_state({state.get_name(): state.dict()})
-
-    def execute_load_event(self, state: State, load_event: EventHandler) -> None:
+                if not isinstance(load_event, List):
+                    load_event = [load_event]
+                updates = []
+                for single_event in load_event:
+                    updates.append(
+                        await self.execute_load_event(
+                            state, single_event, event.token, event.payload
+                        )
+                    )
+                return updates
+            delta = format.format_state({state.get_name(): state.dict()})
+            return StateUpdate(delta=delta) if delta else None
+
+    async def execute_load_event(
+        self, state: State, load_event: EventHandler, token: str, payload: Dict
+    ) -> StateUpdate:
         """Execute single load event.
 
         Args:
             state: The client state.
             load_event: A single load event to execute.
+            token: Client token
+            payload: The event payload
+
+        Returns:
+            A state Update.
+
+        Raises:
+            ValueError: If the state value is None.
         """
         substate_path = format.format_event_handler(load_event).split(".")
         ex_state = state.get_substate(substate_path[:-1])
-        load_event.fn(ex_state)
+        if not ex_state:
+            raise ValueError(
+                "The value of state cannot be None when processing an on-load event."
+            )
+
+        return await state.process_event(
+            handler=load_event, state=ex_state, payload=payload, token=token
+        )

+ 2 - 2
pynecone/middleware/logging_middleware.py

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 class LoggingMiddleware(Middleware):
     """Middleware to log requests and responses."""
 
-    def preprocess(self, app: App, state: State, event: Event):
+    async def preprocess(self, app: App, state: State, event: Event):
         """Preprocess the event.
 
         Args:
@@ -24,7 +24,7 @@ class LoggingMiddleware(Middleware):
         """
         print(f"Event {event}")
 
-    def postprocess(self, app: App, state: State, event: Event, delta: Delta):
+    async def postprocess(self, app: App, state: State, event: Event, delta: Delta):
         """Postprocess the event.
 
         Args:

+ 6 - 4
pynecone/middleware/middleware.py

@@ -2,11 +2,11 @@
 from __future__ import annotations
 
 from abc import ABC
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 from pynecone.base import Base
 from pynecone.event import Event
-from pynecone.state import Delta, State
+from pynecone.state import Delta, State, StateUpdate
 
 if TYPE_CHECKING:
     from pynecone.app import App
@@ -15,7 +15,9 @@ if TYPE_CHECKING:
 class Middleware(Base, ABC):
     """Middleware to preprocess and postprocess requests."""
 
-    def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
+    async def preprocess(
+        self, app: App, state: State, event: Event
+    ) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
         """Preprocess the event.
 
         Args:
@@ -28,7 +30,7 @@ class Middleware(Base, ABC):
         """
         return None
 
-    def postprocess(
+    async def postprocess(
         self, app: App, state: State, event: Event, delta
     ) -> Optional[Delta]:
         """Postprocess the event.

+ 2 - 2
pynecone/pc.py

@@ -203,12 +203,12 @@ def export(
 
     if zipping:
         console.rule(
-            """Backend & Frontend compiled. See [green bold]backend.zip[/green bold] 
+            """Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
             and [green bold]frontend.zip[/green bold]."""
         )
     else:
         console.rule(
-            """Backend & Frontend compiled. See [green bold]app[/green bold] 
+            """Backend & Frontend compiled. See [green bold]app[/green bold]
             and [green bold].web/_static[/green bold] directories."""
         )
 

+ 35 - 7
pynecone/state.py

@@ -546,13 +546,16 @@ class State(Base, ABC):
         return self.substates[path[0]].get_substate(path[1:])
 
     async def process(self, event: Event) -> StateUpdate:
-        """Process an event.
+        """Obtain event info and process event.
 
         Args:
             event: The event to process.
 
         Returns:
             The state update after processing the event.
+
+        Raises:
+            ValueError: If the state value is None.
         """
         # Get the event handler.
         path = event.name.split(".")
@@ -560,23 +563,48 @@ class State(Base, ABC):
         substate = self.get_substate(path)
         handler = substate.event_handlers[name]  # type: ignore
 
-        # Process the event.
-        fn = functools.partial(handler.fn, substate)
+        if not substate:
+            raise ValueError(
+                "The value of state cannot be None when processing an event."
+            )
+
+        return await self.process_event(
+            handler=handler,
+            state=substate,
+            payload=event.payload,
+            token=event.token,
+        )
+
+    async def process_event(
+        self, handler: EventHandler, state: State, payload: Dict, token: str
+    ) -> StateUpdate:
+        """Process event.
+
+        Args:
+            handler: Eventhandler to process.
+            state: State to process the handler.
+            payload: The event payload.
+            token: Client token.
+
+        Returns:
+            The state update after processing the event.
+        """
+        fn = functools.partial(handler.fn, state)
         try:
             if asyncio.iscoroutinefunction(fn.func):
-                events = await fn(**event.payload)
+                events = await fn(**payload)
             else:
-                events = fn(**event.payload)
+                events = fn(**payload)
         except Exception:
             error = traceback.format_exc()
             print(error)
             events = fix_events(
-                [window_alert("An error occurred. See logs for details.")], event.token
+                [window_alert("An error occurred. See logs for details.")], token
             )
             return StateUpdate(events=events)
 
         # Fix the returned events.
-        events = fix_events(events, event.token)
+        events = fix_events(events, token)
 
         # Get the delta after processing the event.
         delta = self.get_delta()

+ 0 - 1
tests/components/datadisplay/test_datatable.py

@@ -94,7 +94,6 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
         is_data_frame: whether data field is a pandas dataframe.
     """
     with pytest.raises(ValueError) as err:
-
         if is_data_frame:
             data_table(data=request.getfixturevalue(fixture).data)
         else:

+ 0 - 0
tests/middleware/__init__.py


+ 34 - 0
tests/middleware/conftest.py

@@ -0,0 +1,34 @@
+import pytest
+
+from pynecone.event import Event
+
+
+def create_event(name):
+    return Event(
+        token="<token>",
+        name=name,
+        router_data={
+            "pathname": "/",
+            "query": {},
+            "token": "<token>",
+            "sid": "<sid>",
+            "headers": {},
+            "ip": "127.0.0.1",
+        },
+        payload={},
+    )
+
+
+@pytest.fixture
+def event1():
+    return create_event("test_state.hydrate")
+
+
+@pytest.fixture
+def event2():
+    return create_event("test_state2.hydrate")
+
+
+@pytest.fixture
+def event3():
+    return create_event("test_state3.hydrate")

+ 96 - 0
tests/middleware/test_hydrate_middleware.py

@@ -0,0 +1,96 @@
+from typing import List
+
+import pytest
+
+from pynecone.app import App
+from pynecone.middleware.hydrate_middleware import HydrateMiddleware
+from pynecone.state import State
+
+
+class TestState(State):
+    """A test state with no return in handler."""
+
+    num: int = 0
+
+    def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+class TestState2(State):
+    """A test state with return in handler."""
+
+    num: int = 0
+    name: str
+
+    def test_handler(self):
+        """Test handler that calls another handler.
+
+        Returns:
+            Chain of EventHandlers
+        """
+        self.num += 1
+        return [self.change_name()]
+
+    def change_name(self):
+        """Test handler to change name."""
+        self.name = "random"
+
+
+class TestState3(State):
+    """A test state with async handler."""
+
+    num: int = 0
+
+    async def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "state, expected, event_fixture",
+    [
+        (TestState, {"test_state": {"num": 1}}, "event1"),
+        (TestState2, {"test_state2": {"num": 1}}, "event2"),
+        (TestState3, {"test_state3": {"num": 1}}, "event3"),
+    ],
+)
+async def test_preprocess(state, request, event_fixture, expected):
+    """Test that a state hydrate event is processed correctly.
+
+    Args:
+        state: state to process event
+        request: pytest fixture request
+        event_fixture: The event fixture(an Event)
+        expected: expected delta
+    """
+    app = App(state=state, load_events={"index": state.test_handler})
+
+    hydrate_middleware = HydrateMiddleware()
+    result = await hydrate_middleware.preprocess(
+        app=app, event=request.getfixturevalue(event_fixture), state=state()
+    )
+    assert isinstance(result, List)
+    assert result[0].delta == expected
+
+
+@pytest.mark.asyncio
+async def test_preprocess_multiple_load_events(event1):
+    """Test that a state hydrate event for multiple on-load events is processed correctly.
+
+    Args:
+        event1: an Event.
+    """
+    app = App(
+        state=TestState,
+        load_events={"index": [TestState.test_handler, TestState.test_handler]},
+    )
+
+    hydrate_middleware = HydrateMiddleware()
+    result = await hydrate_middleware.preprocess(
+        app=app, event=event1, state=TestState()
+    )
+    assert isinstance(result, List)
+    assert result[0].delta == {"test_state": {"num": 1}}
+    assert result[1].delta == {"test_state": {"num": 2}}