Ver código fonte

Add async events (#1107)

Nikhil Rao 2 anos atrás
pai
commit
a18c6880b5

+ 29 - 15
pynecone/app.py

@@ -2,7 +2,18 @@
 
 
 import asyncio
 import asyncio
 import inspect
 import inspect
-from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 
 from fastapi import FastAPI, UploadFile
 from fastapi import FastAPI, UploadFile
 from fastapi.middleware import cors
 from fastapi.middleware import cors
@@ -411,7 +422,7 @@ class App(Base):
 
 
 async def process(
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
-) -> StateUpdate:
+) -> AsyncIterator[StateUpdate]:
     """Process an event.
     """Process an event.
 
 
     Args:
     Args:
@@ -421,7 +432,7 @@ async def process(
         headers: The client headers.
         headers: The client headers.
         client_ip: The client_ip.
         client_ip: The client_ip.
 
 
-    Returns:
+    Yields:
         The state updates after processing the event.
         The state updates after processing the event.
     """
     """
     # Get the state for the session.
     # Get the state for the session.
@@ -447,20 +458,23 @@ async def process(
     # Preprocess the event.
     # Preprocess the event.
     update = await app.preprocess(state, event)
     update = await app.preprocess(state, event)
 
 
+    # If there was an update, yield it.
+    if update is not None:
+        yield update
+
     # Only process the event if there is no update.
     # Only process the event if there is no update.
-    if update is None:
-        # Apply the event to the state.
-        update = await state._process(event)
+    else:
+        # Process the event.
+        async for update in state._process(event):
+            yield update
 
 
         # Postprocess the event.
         # Postprocess the event.
+        assert update is not None, "Process did not return an update."
         update = await app.postprocess(state, event, update)
         update = await app.postprocess(state, event, update)
 
 
-    # Update the state.
+    # Set the state for the session.
     app.state_manager.set_state(event.token, state)
     app.state_manager.set_state(event.token, state)
 
 
-    # Return the update.
-    return update
-
 
 
 async def ping() -> str:
 async def ping() -> str:
     """Test API endpoint.
     """Test API endpoint.
@@ -531,7 +545,8 @@ def upload(app: App):
             name=handler,
             name=handler,
             payload={handler_upload_param[0]: files},
             payload={handler_upload_param[0]: files},
         )
         )
-        update = await state._process(event)
+        # TODO: refactor this to handle yields.
+        update = await state._process(event).__anext__()
         return update
         return update
 
 
     return upload_file
     return upload_file
@@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace):
         client_ip = environ["REMOTE_ADDR"]
         client_ip = environ["REMOTE_ADDR"]
 
 
         # Process the events.
         # Process the events.
-        update = await process(self.app, event, sid, headers, client_ip)
-
-        # Emit the event.
-        await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)  # type: ignore
+        async for update in process(self.app, event, sid, headers, client_ip):
+            # Emit the event.
+            await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)  # type: ignore
 
 
     async def on_ping(self, sid):
     async def on_ping(self, sid):
         """Event for testing the API endpoint.
         """Event for testing the API endpoint.

+ 1 - 1
pynecone/event.py

@@ -30,7 +30,7 @@ class EventHandler(Base):
     """An event handler responds to an event to update the state."""
     """An event handler responds to an event to update the state."""
 
 
     # The function to call in response to the event.
     # The function to call in response to the event.
-    fn: Callable
+    fn: Any
 
 
     class Config:
     class Config:
         """The Pydantic config."""
         """The Pydantic config."""

+ 49 - 24
pynecone/state.py

@@ -4,11 +4,13 @@ from __future__ import annotations
 import asyncio
 import asyncio
 import copy
 import copy
 import functools
 import functools
+import inspect
 import traceback
 import traceback
 from abc import ABC
 from abc import ABC
 from collections import defaultdict
 from collections import defaultdict
 from typing import (
 from typing import (
     Any,
     Any,
+    AsyncIterator,
     Callable,
     Callable,
     ClassVar,
     ClassVar,
     Dict,
     Dict,
@@ -26,7 +28,7 @@ from redis import Redis
 
 
 from pynecone import constants
 from pynecone import constants
 from pynecone.base import Base
 from pynecone.base import Base
-from pynecone.event import Event, EventHandler, fix_events, window_alert
+from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
 from pynecone.utils import format, prerequisites, types
 from pynecone.utils import format, prerequisites, types
 from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
 from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
 
 
@@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Invalid path: {path}")
             raise ValueError(f"Invalid path: {path}")
         return self.substates[path[0]].get_substate(path[1:])
         return self.substates[path[0]].get_substate(path[1:])
 
 
-    async def _process(self, event: Event) -> StateUpdate:
+    async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
         """Obtain event info and process event.
         """Obtain event info and process event.
 
 
         Args:
         Args:
             event: The event to process.
             event: The event to process.
 
 
-        Returns:
+        Yields:
             The state update after processing the event.
             The state update after processing the event.
 
 
         Raises:
         Raises:
@@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
                 "The value of state cannot be None when processing an event."
                 "The value of state cannot be None when processing an event."
             )
             )
 
 
-        return await self._process_event(
+        # Get the event generator.
+        event_iter = self._process_event(
             handler=handler,
             handler=handler,
             state=substate,
             state=substate,
             payload=event.payload,
             payload=event.payload,
-            token=event.token,
         )
         )
 
 
+        # Clean the state before processing the event.
+        self.clean()
+
+        # Run the event generator and return state updates.
+        async for events in event_iter:
+            # Fix the returned events.
+            events = fix_events(events, event.token)  # type: ignore
+
+            # Get the delta after processing the event.
+            delta = self.get_delta()
+
+            # Yield the state update.
+            yield StateUpdate(delta=delta, events=events)
+
+            # Clean the state to prepare for the next event.
+            self.clean()
+
     async def _process_event(
     async def _process_event(
-        self, handler: EventHandler, state: State, payload: Dict, token: str
-    ) -> StateUpdate:
+        self, handler: EventHandler, state: State, payload: Dict
+    ) -> AsyncIterator[Optional[List[EventSpec]]]:
         """Process event.
         """Process event.
 
 
         Args:
         Args:
             handler: Eventhandler to process.
             handler: Eventhandler to process.
             state: State to process the handler.
             state: State to process the handler.
             payload: The event payload.
             payload: The event payload.
-            token: Client token.
 
 
-        Returns:
+        Yields:
             The state update after processing the event.
             The state update after processing the event.
         """
         """
+        # Get the function to process the event.
         fn = functools.partial(handler.fn, state)
         fn = functools.partial(handler.fn, state)
+
+        # Wrap the function in a try/except block.
         try:
         try:
+            # Handle async functions.
             if asyncio.iscoroutinefunction(fn.func):
             if asyncio.iscoroutinefunction(fn.func):
                 events = await fn(**payload)
                 events = await fn(**payload)
+
+            # Handle regular functions.
             else:
             else:
                 events = fn(**payload)
                 events = fn(**payload)
-        except Exception:
-            error = traceback.format_exc()
-            print(error)
-            events = fix_events(
-                [window_alert("An error occurred. See logs for details.")], token
-            )
-            return StateUpdate(events=events)
 
 
-        # Fix the returned events.
-        events = fix_events(events, token)
+            # Handle async generators.
+            if inspect.isasyncgen(events):
+                async for event in events:
+                    yield event
 
 
-        # Get the delta after processing the event.
-        delta = self.get_delta()
+            # Handle regular generators.
+            elif inspect.isgenerator(events):
+                for event in events:
+                    yield event
 
 
-        # Reset the dirty vars.
-        self.clean()
+            # Handle regular event chains.
+            else:
+                yield events
 
 
-        # Return the state update.
-        return StateUpdate(delta=delta, events=events)
+        # If an error occurs, throw a window alert.
+        except Exception:
+            error = traceback.format_exc()
+            print(error)
+            yield [window_alert("An error occurred. See logs for details.")]
 
 
     def _always_dirty_computed_vars(self) -> Set[str]:
     def _always_dirty_computed_vars(self) -> Set[str]:
         """The set of ComputedVars that always need to be recalculated.
         """The set of ComputedVars that always need to be recalculated.

+ 6 - 6
tests/middleware/test_hydrate_middleware.py

@@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp
     assert len(events) == 2
     assert len(events) == 2
 
 
     # Apply the on_load event.
     # Apply the on_load event.
-    update = await state._process(events[0])
+    update = await state._process(events[0]).__anext__()
     assert update.delta == expected
     assert update.delta == expected
 
 
     # Apply the hydrate event.
     # Apply the hydrate event.
-    update = await state._process(events[1])
+    update = await state._process(events[1]).__anext__()
     assert update.delta == exp_is_hydrated(state)
     assert update.delta == exp_is_hydrated(state)
 
 
 
 
@@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
 
 
     # Apply the events.
     # Apply the events.
     events = update.events
     events = update.events
-    update = await state._process(events[0])
+    update = await state._process(events[0]).__anext__()
     assert update.delta == {"test_state": {"num": 1}}
     assert update.delta == {"test_state": {"num": 1}}
 
 
-    update = await state._process(events[1])
+    update = await state._process(events[1]).__anext__()
     assert update.delta == {"test_state": {"num": 2}}
     assert update.delta == {"test_state": {"num": 2}}
 
 
-    update = await state._process(events[2])
+    update = await state._process(events[2]).__anext__()
     assert update.delta == exp_is_hydrated(state)
     assert update.delta == exp_is_hydrated(state)
 
 
 
 
@@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
     assert len(update.events) == 1
     assert len(update.events) == 1
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
 
 
-    update = await state._process(update.events[0])
+    update = await state._process(update.events[0]).__anext__()
     assert update.delta == exp_is_hydrated(state)
     assert update.delta == exp_is_hydrated(state)

+ 8 - 7
tests/test_app.py

@@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state):
             router_data={"pathname": "/", "query": {}},
             router_data={"pathname": "/", "query": {}},
             payload={"value": 50},
             payload={"value": 50},
         )
         )
-    )
+    ).__anext__()
     assert result.delta == {"test_state": {"int_val": 50}}
     assert result.delta == {"test_state": {"int_val": 50}}
 
 
 
 
@@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list(
                 router_data={"pathname": "/", "query": {}},
                 router_data={"pathname": "/", "query": {}},
                 payload={},
                 payload={},
             )
             )
-        )
+        ).__anext__()
 
 
         assert result.delta == expected_delta
         assert result.delta == expected_delta
 
 
@@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list(
                 router_data={"pathname": "/", "query": {}},
                 router_data={"pathname": "/", "query": {}},
                 payload={},
                 payload={},
             )
             )
-        )
+        ).__anext__()
 
 
         assert result.delta == expected_delta
         assert result.delta == expected_delta
 
 
@@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        )
+        ).__anext__()
+
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
@@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        )
+        ).__anext__()
         assert on_load_update == StateUpdate(
         assert on_load_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        )
+        ).__anext__()
         assert on_set_is_hydrated_update == StateUpdate(
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        )
+        ).__anext__()
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {

+ 57 - 3
tests/test_state.py

@@ -91,6 +91,25 @@ class GrandchildState(ChildState):
         pass
         pass
 
 
 
 
+class GenState(State):
+    """A state with event handlers that generate multiple updates."""
+
+    value: int
+
+    def go(self, c: int):
+        """Increment the value c times and update each time.
+
+        Args:
+            c: The number of times to increment.
+
+        Yields:
+            After each increment.
+        """
+        for _ in range(c):
+            self.value += 1
+            yield
+
+
 @pytest.fixture
 @pytest.fixture
 def test_state() -> TestState:
 def test_state() -> TestState:
     """A state.
     """A state.
@@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState:
     return grandchild_state
     return grandchild_state
 
 
 
 
+@pytest.fixture
+def gen_state() -> GenState:
+    """A state.
+
+    Returns:
+        A test state.
+    """
+    return GenState()  # type: ignore
+
+
 def test_base_class_vars(test_state):
 def test_base_class_vars(test_state):
     """Test that the class vars are set correctly.
     """Test that the class vars are set correctly.
 
 
@@ -577,7 +606,7 @@ async def test_process_event_simple(test_state):
     assert test_state.num1 == 0
     assert test_state.num1 == 0
 
 
     event = Event(token="t", name="set_num1", payload={"value": 69})
     event = Event(token="t", name="set_num1", payload={"value": 69})
-    update = await test_state._process(event)
+    update = await test_state._process(event).__anext__()
 
 
     # The event should update the value.
     # The event should update the value.
     assert test_state.num1 == 69
     assert test_state.num1 == 69
@@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     event = Event(
     event = Event(
         token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
         token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
     )
     )
-    update = await test_state._process(event)
+    update = await test_state._process(event).__anext__()
     assert child_state.value == "HI"
     assert child_state.value == "HI"
     assert child_state.count == 24
     assert child_state.count == 24
     assert update.delta == {
     assert update.delta == {
@@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
         name="child_state.grandchild_state.set_value2",
         name="child_state.grandchild_state.set_value2",
         payload={"value": "new"},
         payload={"value": "new"},
     )
     )
-    update = await test_state._process(event)
+    update = await test_state._process(event).__anext__()
     assert grandchild_state.value2 == "new"
     assert grandchild_state.value2 == "new"
     assert update.delta == {
     assert update.delta == {
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state": {"sum": 3.14, "upper": ""},
@@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     }
     }
 
 
 
 
+@pytest.mark.asyncio
+async def test_process_event_generator(gen_state):
+    """Test event handlers that generate multiple updates.
+
+    Args:
+        gen_state: A state.
+    """
+    event = Event(
+        token="t",
+        name="go",
+        payload={"c": 5},
+    )
+    gen = gen_state._process(event)
+
+    count = 0
+    async for update in gen:
+        count += 1
+        assert gen_state.value == count
+        assert update.delta == {
+            "gen_state": {"value": count},
+        }
+
+    assert count == 5
+
+
 def test_format_event_handler():
 def test_format_event_handler():
     """Test formatting an event handler."""
     """Test formatting an event handler."""
     assert (
     assert (