|
@@ -4,11 +4,13 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import copy
|
|
|
import functools
|
|
|
+import inspect
|
|
|
import traceback
|
|
|
from abc import ABC
|
|
|
from collections import defaultdict
|
|
|
from typing import (
|
|
|
Any,
|
|
|
+ AsyncIterator,
|
|
|
Callable,
|
|
|
ClassVar,
|
|
|
Dict,
|
|
@@ -26,7 +28,7 @@ from redis import Redis
|
|
|
|
|
|
from pynecone import constants
|
|
|
from pynecone.base import Base
|
|
|
-from pynecone.event import Event, EventHandler, fix_events, window_alert
|
|
|
+from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
|
|
from pynecone.utils import format, prerequisites, types
|
|
|
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
|
|
|
|
|
@@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
raise ValueError(f"Invalid path: {path}")
|
|
|
return self.substates[path[0]].get_substate(path[1:])
|
|
|
|
|
|
- async def _process(self, event: Event) -> StateUpdate:
|
|
|
+ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
|
|
"""Obtain event info and process event.
|
|
|
|
|
|
Args:
|
|
|
event: The event to process.
|
|
|
|
|
|
- Returns:
|
|
|
+ Yields:
|
|
|
The state update after processing the event.
|
|
|
|
|
|
Raises:
|
|
@@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"The value of state cannot be None when processing an event."
|
|
|
)
|
|
|
|
|
|
- return await self._process_event(
|
|
|
+ # Get the event generator.
|
|
|
+ event_iter = self._process_event(
|
|
|
handler=handler,
|
|
|
state=substate,
|
|
|
payload=event.payload,
|
|
|
- token=event.token,
|
|
|
)
|
|
|
|
|
|
+ # Clean the state before processing the event.
|
|
|
+ self.clean()
|
|
|
+
|
|
|
+ # Run the event generator and return state updates.
|
|
|
+ async for events in event_iter:
|
|
|
+ # Fix the returned events.
|
|
|
+ events = fix_events(events, event.token) # type: ignore
|
|
|
+
|
|
|
+ # Get the delta after processing the event.
|
|
|
+ delta = self.get_delta()
|
|
|
+
|
|
|
+ # Yield the state update.
|
|
|
+ yield StateUpdate(delta=delta, events=events)
|
|
|
+
|
|
|
+ # Clean the state to prepare for the next event.
|
|
|
+ self.clean()
|
|
|
+
|
|
|
async def _process_event(
|
|
|
- self, handler: EventHandler, state: State, payload: Dict, token: str
|
|
|
- ) -> StateUpdate:
|
|
|
+ self, handler: EventHandler, state: State, payload: Dict
|
|
|
+ ) -> AsyncIterator[Optional[List[EventSpec]]]:
|
|
|
"""Process event.
|
|
|
|
|
|
Args:
|
|
|
handler: Eventhandler to process.
|
|
|
state: State to process the handler.
|
|
|
payload: The event payload.
|
|
|
- token: Client token.
|
|
|
|
|
|
- Returns:
|
|
|
+ Yields:
|
|
|
The state update after processing the event.
|
|
|
"""
|
|
|
+ # Get the function to process the event.
|
|
|
fn = functools.partial(handler.fn, state)
|
|
|
+
|
|
|
+ # Wrap the function in a try/except block.
|
|
|
try:
|
|
|
+ # Handle async functions.
|
|
|
if asyncio.iscoroutinefunction(fn.func):
|
|
|
events = await fn(**payload)
|
|
|
+
|
|
|
+ # Handle regular functions.
|
|
|
else:
|
|
|
events = fn(**payload)
|
|
|
- except Exception:
|
|
|
- error = traceback.format_exc()
|
|
|
- print(error)
|
|
|
- events = fix_events(
|
|
|
- [window_alert("An error occurred. See logs for details.")], token
|
|
|
- )
|
|
|
- return StateUpdate(events=events)
|
|
|
|
|
|
- # Fix the returned events.
|
|
|
- events = fix_events(events, token)
|
|
|
+ # Handle async generators.
|
|
|
+ if inspect.isasyncgen(events):
|
|
|
+ async for event in events:
|
|
|
+ yield event
|
|
|
|
|
|
- # Get the delta after processing the event.
|
|
|
- delta = self.get_delta()
|
|
|
+ # Handle regular generators.
|
|
|
+ elif inspect.isgenerator(events):
|
|
|
+ for event in events:
|
|
|
+ yield event
|
|
|
|
|
|
- # Reset the dirty vars.
|
|
|
- self.clean()
|
|
|
+ # Handle regular event chains.
|
|
|
+ else:
|
|
|
+ yield events
|
|
|
|
|
|
- # Return the state update.
|
|
|
- return StateUpdate(delta=delta, events=events)
|
|
|
+ # If an error occurs, throw a window alert.
|
|
|
+ except Exception:
|
|
|
+ error = traceback.format_exc()
|
|
|
+ print(error)
|
|
|
+ yield [window_alert("An error occurred. See logs for details.")]
|
|
|
|
|
|
def _always_dirty_computed_vars(self) -> Set[str]:
|
|
|
"""The set of ComputedVars that always need to be recalculated.
|