Browse Source

Fix processing flag for generator event handlers (#1136)

Nikhil Rao 1 year ago
parent
commit
895719cf68

+ 1 - 1
pynecone/.templates/jinja/web/pages/index.js.jinja2

@@ -47,7 +47,7 @@ export default function Component() {
         {{const.result|react_setter}}({
           {{const.state}}: null,
           {{const.events}}: [],
-          {{const.processing}}: false,
+          {{const.processing}}: {{const.result}}.{{const.processing}},
         })
       }
 

+ 1 - 1
pynecone/.templates/web/utils/state.js

@@ -214,7 +214,7 @@ export const connect = async (
     update = JSON5.parse(update);
     applyDelta(state, update.delta);
     setResult({
-      processing: true,
+      processing: update.processing,
       state: state,
       events: update.events,
     });

+ 4 - 4
pynecone/app.py

@@ -466,11 +466,11 @@ async def process(
     else:
         # Process the event.
         async for update in state._process(event):
-            yield update
+            # Postprocess the event.
+            update = await app.postprocess(state, event, update)
 
-        # Postprocess the event.
-        assert update is not None, "Process did not return an update."
-        update = await app.postprocess(state, event, update)
+            # Yield the update.
+            yield update
 
     # Set the state for the session.
     app.state_manager.set_state(event.token, state)

+ 16 - 8
pynecone/state.py

@@ -18,6 +18,7 @@ from typing import (
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     Union,
 )
@@ -654,7 +655,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         self.clean()
 
         # Run the event generator and return state updates.
-        async for events in event_iter:
+        async for events, processing in event_iter:
             # Fix the returned events.
             events = fix_events(events, event.token)  # type: ignore
 
@@ -662,14 +663,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             delta = self.get_delta()
 
             # Yield the state update.
-            yield StateUpdate(delta=delta, events=events)
+            yield StateUpdate(delta=delta, events=events, processing=processing)
 
             # Clean the state to prepare for the next event.
             self.clean()
 
     async def _process_event(
         self, handler: EventHandler, state: State, payload: Dict
-    ) -> AsyncIterator[Optional[List[EventSpec]]]:
+    ) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
         """Process event.
 
         Args:
@@ -678,7 +679,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             payload: The event payload.
 
         Yields:
-            The state update after processing the event.
+            Tuple containing:
+                0: The state update after processing the event.
+                1: Whether the event is being processed.
         """
         # Get the function to process the event.
         fn = functools.partial(handler.fn, state)
@@ -696,22 +699,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             if inspect.isasyncgen(events):
                 async for event in events:
-                    yield event
+                    yield event, True
+                yield None, False
 
             # Handle regular generators.
             elif inspect.isgenerator(events):
                 for event in events:
-                    yield event
+                    yield event, True
+                yield None, False
 
             # Handle regular event chains.
             else:
-                yield events
+                yield events, False
 
         # If an error occurs, throw a window alert.
         except Exception:
             error = traceback.format_exc()
             print(error)
-            yield [window_alert("An error occurred. See logs for details.")]
+            yield [window_alert("An error occurred. See logs for details.")], False
 
     def _always_dirty_computed_vars(self) -> Set[str]:
         """The set of ComputedVars that always need to be recalculated.
@@ -876,6 +881,9 @@ class StateUpdate(Base):
     # Events to be added to the event queue.
     events: List[Event] = []
 
+    # Whether the event is still processing.
+    processing: bool = False
+
 
 class StateManager(Base):
     """A class to manage many client states."""

+ 9 - 6
tests/test_state.py

@@ -673,12 +673,15 @@ async def test_process_event_generator(gen_state):
     count = 0
     async for update in gen:
         count += 1
-        assert gen_state.value == count
-        assert update.delta == {
-            "gen_state": {"value": count},
-        }
-
-    assert count == 5
+        if count == 6:
+            assert update.delta == {}
+        else:
+            assert gen_state.value == count
+            assert update.delta == {
+                "gen_state": {"value": count},
+            }
+
+    assert count == 6
 
 
 def test_format_event_handler():