state.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805
  1. """Define the pynecone state specification."""
from __future__ import annotations

import asyncio
import copy
import functools
import traceback
from abc import ABC
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Type,
    Union,
)

import cloudpickle
from redis import Redis

from pynecone import constants
from pynecone.base import Base
from pynecone.event import Event, EventHandler, fix_events, window_alert
from pynecone.utils import format, prerequisites, types
from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
  26. Delta = Dict[str, Any]
  27. class State(Base, ABC):
  28. """The state of the app."""
  29. # A map from the var name to the var.
  30. vars: ClassVar[Dict[str, Var]] = {}
  31. # The base vars of the class.
  32. base_vars: ClassVar[Dict[str, BaseVar]] = {}
  33. # The computed vars of the class.
  34. computed_vars: ClassVar[Dict[str, ComputedVar]] = {}
  35. # Vars inherited by the parent state.
  36. inherited_vars: ClassVar[Dict[str, Var]] = {}
  37. # Backend vars that are never sent to the client.
  38. backend_vars: ClassVar[Dict[str, Any]] = {}
  39. # Backend vars inherited
  40. inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
  41. # The event handlers.
  42. event_handlers: ClassVar[Dict[str, EventHandler]] = {}
  43. # The parent state.
  44. parent_state: Optional[State] = None
  45. # The substates of the state.
  46. substates: Dict[str, State] = {}
  47. # The set of dirty vars.
  48. dirty_vars: Set[str] = set()
  49. # The set of dirty substates.
  50. dirty_substates: Set[str] = set()
  51. # The routing path that triggered the state
  52. router_data: Dict[str, Any] = {}
  53. def __init__(self, *args, **kwargs):
  54. """Initialize the state.
  55. Args:
  56. *args: The args to pass to the Pydantic init method.
  57. **kwargs: The kwargs to pass to the Pydantic init method.
  58. """
  59. super().__init__(*args, **kwargs)
  60. # Setup the substates.
  61. for substate in self.get_substates():
  62. self.substates[substate.get_name()] = substate().set(parent_state=self)
  63. self._init_mutable_fields()
  64. def _init_mutable_fields(self):
  65. """Initialize mutable fields.
  66. So that mutation to them can be detected by the app:
  67. * list
  68. """
  69. for field in self.base_vars.values():
  70. value = getattr(self, field.name)
  71. value_in_pc_data = _convert_mutable_datatypes(
  72. value, self._reassign_field, field.name
  73. )
  74. if types._issubclass(field.type_, Union[List, Dict]):
  75. setattr(self, field.name, value_in_pc_data)
  76. self.clean()
  77. def _reassign_field(self, field_name: str):
  78. """Reassign the given field.
  79. Primarily for mutation in fields of mutable data types.
  80. Args:
  81. field_name: The name of the field we want to reassign
  82. """
  83. setattr(
  84. self,
  85. field_name,
  86. getattr(self, field_name),
  87. )
  88. def __repr__(self) -> str:
  89. """Get the string representation of the state.
  90. Returns:
  91. The string representation of the state.
  92. """
  93. return f"{self.__class__.__name__}({self.dict()})"
  94. @classmethod
  95. def __init_subclass__(cls, **kwargs):
  96. """Do some magic for the subclass initialization.
  97. Args:
  98. **kwargs: The kwargs to pass to the pydantic init_subclass method.
  99. """
  100. super().__init_subclass__(**kwargs)
  101. # Get the parent vars.
  102. parent_state = cls.get_parent_state()
  103. if parent_state is not None:
  104. cls.inherited_vars = parent_state.vars
  105. cls.inherited_backend_vars = parent_state.backend_vars
  106. cls.new_backend_vars = {
  107. name: value
  108. for name, value in cls.__dict__.items()
  109. if types.is_backend_variable(name)
  110. and name not in cls.inherited_backend_vars
  111. }
  112. cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
  113. # Set the base and computed vars.
  114. skip_vars = set(cls.inherited_vars) | {
  115. "parent_state",
  116. "substates",
  117. "dirty_vars",
  118. "dirty_substates",
  119. "router_data",
  120. }
  121. cls.base_vars = {
  122. f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
  123. for f in cls.get_fields().values()
  124. if f.name not in skip_vars
  125. }
  126. cls.computed_vars = {
  127. v.name: v.set_state(cls)
  128. for v in cls.__dict__.values()
  129. if isinstance(v, ComputedVar)
  130. }
  131. cls.vars = {
  132. **cls.inherited_vars,
  133. **cls.base_vars,
  134. **cls.computed_vars,
  135. }
  136. # Setup the base vars at the class level.
  137. for prop in cls.base_vars.values():
  138. cls._init_var(prop)
  139. # Set up the event handlers.
  140. events = {
  141. name: fn
  142. for name, fn in cls.__dict__.items()
  143. if not name.startswith("_") and isinstance(fn, Callable)
  144. }
  145. for name, fn in events.items():
  146. event_handler = EventHandler(fn=fn)
  147. cls.event_handlers[name] = event_handler
  148. cls.set_handlers()
  149. @classmethod
  150. def convert_handlers_to_fns(cls):
  151. """Convert the event handlers to functions.
  152. This is done so the state functions can be called as normal functions during runtime.
  153. """
  154. for name, event_handler in cls.event_handlers.items():
  155. setattr(cls, name, event_handler.fn)
  156. @classmethod
  157. def set_handlers(cls):
  158. """Set the state class handlers."""
  159. for name, event_handler in cls.event_handlers.items():
  160. setattr(cls, name, event_handler)
  161. @classmethod
  162. @functools.lru_cache()
  163. def get_parent_state(cls) -> Optional[Type[State]]:
  164. """Get the parent state.
  165. Returns:
  166. The parent state.
  167. """
  168. parent_states = [
  169. base
  170. for base in cls.__bases__
  171. if types._issubclass(base, State) and base is not State
  172. ]
  173. assert len(parent_states) < 2, "Only one parent state is allowed."
  174. return parent_states[0] if len(parent_states) == 1 else None # type: ignore
  175. @classmethod
  176. @functools.lru_cache()
  177. def get_substates(cls) -> Set[Type[State]]:
  178. """Get the substates of the state.
  179. Returns:
  180. The substates of the state.
  181. """
  182. return set(cls.__subclasses__())
  183. @classmethod
  184. @functools.lru_cache()
  185. def get_name(cls) -> str:
  186. """Get the name of the state.
  187. Returns:
  188. The name of the state.
  189. """
  190. return format.to_snake_case(cls.__name__)
  191. @classmethod
  192. @functools.lru_cache()
  193. def get_full_name(cls) -> str:
  194. """Get the full name of the state.
  195. Returns:
  196. The full name of the state.
  197. """
  198. name = cls.get_name()
  199. parent_state = cls.get_parent_state()
  200. if parent_state is not None:
  201. name = ".".join((parent_state.get_full_name(), name))
  202. return name
  203. @classmethod
  204. @functools.lru_cache()
  205. def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
  206. """Get the class substate.
  207. Args:
  208. path: The path to the substate.
  209. Returns:
  210. The class substate.
  211. Raises:
  212. ValueError: If the substate is not found.
  213. """
  214. if len(path) == 0:
  215. return cls
  216. if path[0] == cls.get_name():
  217. if len(path) == 1:
  218. return cls
  219. path = path[1:]
  220. for substate in cls.get_substates():
  221. if path[0] == substate.get_name():
  222. return substate.get_class_substate(path[1:])
  223. raise ValueError(f"Invalid path: {path}")
  224. @classmethod
  225. def get_class_var(cls, path: Sequence[str]) -> Any:
  226. """Get the class var.
  227. Args:
  228. path: The path to the var.
  229. Returns:
  230. The class var.
  231. Raises:
  232. ValueError: If the path is invalid.
  233. """
  234. path, name = path[:-1], path[-1]
  235. substate = cls.get_class_substate(tuple(path))
  236. if not hasattr(substate, name):
  237. raise ValueError(f"Invalid path: {path}")
  238. return getattr(substate, name)
  239. @classmethod
  240. def _init_var(cls, prop: BaseVar):
  241. """Initialize a variable.
  242. Args:
  243. prop (BaseVar): The variable to initialize
  244. Raises:
  245. TypeError: if the variable has an incorrect type
  246. """
  247. if not types.is_valid_var_type(prop.type_):
  248. raise TypeError(
  249. "State vars must be primitive Python types, "
  250. "Plotly figures, Pandas dataframes, "
  251. "or subclasses of pc.Base. "
  252. f'Found var "{prop.name}" with type {prop.type_}.'
  253. )
  254. cls._set_var(prop)
  255. cls._create_setter(prop)
  256. cls._set_default_value(prop)
  257. @classmethod
  258. def add_var(cls, name: str, type_: Any, default_value: Any = None):
  259. """Add dynamically a variable to the State.
  260. The variable added this way can be used in the same way as a variable
  261. defined statically in the model.
  262. Args:
  263. name: The name of the variable
  264. type_: The type of the variable
  265. default_value: The default value of the variable
  266. Raises:
  267. NameError: if a variable of this name already exists
  268. """
  269. if name in cls.__fields__:
  270. raise NameError(
  271. f"The variable '{name}' already exist. Use a different name"
  272. )
  273. # create the variable based on name and type
  274. var = BaseVar(name=name, type_=type_)
  275. var.set_state(cls)
  276. # add the pydantic field dynamically (must be done before _init_var)
  277. cls.add_field(var, default_value)
  278. cls._init_var(var)
  279. # update the internal dicts so the new variable is correctly handled
  280. cls.base_vars.update({name: var})
  281. cls.vars.update({name: var})
  282. @classmethod
  283. def _set_var(cls, prop: BaseVar):
  284. """Set the var as a class member.
  285. Args:
  286. prop: The var instance to set.
  287. """
  288. setattr(cls, prop.name, prop)
  289. @classmethod
  290. def _create_setter(cls, prop: BaseVar):
  291. """Create a setter for the var.
  292. Args:
  293. prop: The var to create a setter for.
  294. """
  295. setter_name = prop.get_setter_name(include_state=False)
  296. if setter_name not in cls.__dict__:
  297. setattr(cls, setter_name, prop.get_setter())
  298. @classmethod
  299. def _set_default_value(cls, prop: BaseVar):
  300. """Set the default value for the var.
  301. Args:
  302. prop: The var to set the default value for.
  303. """
  304. # Get the pydantic field for the var.
  305. field = cls.get_fields()[prop.name]
  306. default_value = prop.get_default_value()
  307. if field.required and default_value is not None:
  308. field.required = False
  309. field.default = default_value
  310. def get_token(self) -> str:
  311. """Return the token of the client associated with this state.
  312. Returns:
  313. The token of the client.
  314. """
  315. return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
  316. def get_sid(self) -> str:
  317. """Return the session ID of the client associated with this state.
  318. Returns:
  319. The session ID of the client.
  320. """
  321. return self.router_data.get(constants.RouteVar.SESSION_ID, "")
  322. def get_headers(self) -> Dict:
  323. """Return the headers of the client associated with this state.
  324. Returns:
  325. The headers of the client.
  326. """
  327. return self.router_data.get(constants.RouteVar.HEADERS, {})
  328. def get_client_ip(self) -> str:
  329. """Return the IP of the client associated with this state.
  330. Returns:
  331. The IP of the client.
  332. """
  333. return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
  334. def get_current_page(self) -> str:
  335. """Obtain the path of current page from the router data.
  336. Returns:
  337. The current page.
  338. """
  339. return self.router_data.get(constants.RouteVar.PATH, "")
  340. def get_query_params(self) -> Dict[str, str]:
  341. """Obtain the query parameters for the queried page.
  342. The query object contains both the URI parameters and the GET parameters.
  343. Returns:
  344. The dict of query parameters.
  345. """
  346. return self.router_data.get(constants.RouteVar.QUERY, {})
  347. @classmethod
  348. def setup_dynamic_args(cls, args: dict[str, str]):
  349. """Set up args for easy access in renderer.
  350. Args:
  351. args: a dict of args
  352. """
  353. def argsingle_factory(param):
  354. @ComputedVar
  355. def inner_func(self) -> str:
  356. return self.get_query_params().get(param, "")
  357. return inner_func
  358. def arglist_factory(param):
  359. @ComputedVar
  360. def inner_func(self) -> List:
  361. return self.get_query_params().get(param, [])
  362. return inner_func
  363. for param, value in args.items():
  364. if value == constants.RouteArgType.SINGLE:
  365. func = argsingle_factory(param)
  366. elif value == constants.RouteArgType.LIST:
  367. func = arglist_factory(param)
  368. else:
  369. continue
  370. cls.computed_vars[param] = func.set_state(cls) # type: ignore
  371. setattr(cls, param, func)
  372. def __getattribute__(self, name: str) -> Any:
  373. """Get the state var.
  374. If the var is inherited, get the var from the parent state.
  375. Args:
  376. name: The name of the var.
  377. Returns:
  378. The value of the var.
  379. """
  380. inherited_vars = {
  381. **super().__getattribute__("inherited_vars"),
  382. **super().__getattribute__("inherited_backend_vars"),
  383. }
  384. if name in inherited_vars:
  385. return getattr(super().__getattribute__("parent_state"), name)
  386. elif name in super().__getattribute__("backend_vars"):
  387. return super().__getattribute__("backend_vars").__getitem__(name)
  388. return super().__getattribute__(name)
  389. def __setattr__(self, name: str, value: Any):
  390. """Set the attribute.
  391. If the attribute is inherited, set the attribute on the parent state.
  392. Args:
  393. name: The name of the attribute.
  394. value: The value of the attribute.
  395. """
  396. # Set the var on the parent state.
  397. inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
  398. if name in inherited_vars:
  399. setattr(self.parent_state, name, value)
  400. return
  401. if types.is_backend_variable(name):
  402. self.backend_vars.__setitem__(name, value)
  403. self.mark_dirty()
  404. return
  405. # Set the attribute.
  406. super().__setattr__(name, value)
  407. # Add the var to the dirty list.
  408. if name in self.vars:
  409. self.dirty_vars.add(name)
  410. self.mark_dirty()
  411. def reset(self):
  412. """Reset all the base vars to their default values."""
  413. # Reset the base vars.
  414. fields = self.get_fields()
  415. for prop_name in self.base_vars:
  416. setattr(self, prop_name, fields[prop_name].default)
  417. # Recursively reset the substates.
  418. for substate in self.substates.values():
  419. substate.reset()
  420. # Clean the state.
  421. self.clean()
  422. def get_substate(self, path: Sequence[str]) -> Optional[State]:
  423. """Get the substate.
  424. Args:
  425. path: The path to the substate.
  426. Returns:
  427. The substate.
  428. Raises:
  429. ValueError: If the substate is not found.
  430. """
  431. if len(path) == 0:
  432. return self
  433. if path[0] == self.get_name():
  434. if len(path) == 1:
  435. return self
  436. path = path[1:]
  437. if path[0] not in self.substates:
  438. raise ValueError(f"Invalid path: {path}")
  439. return self.substates[path[0]].get_substate(path[1:])
  440. async def process(self, event: Event) -> StateUpdate:
  441. """Obtain event info and process event.
  442. Args:
  443. event: The event to process.
  444. Returns:
  445. The state update after processing the event.
  446. Raises:
  447. ValueError: If the state value is None.
  448. """
  449. # Get the event handler.
  450. path = event.name.split(".")
  451. path, name = path[:-1], path[-1]
  452. substate = self.get_substate(path)
  453. handler = substate.event_handlers[name] # type: ignore
  454. if not substate:
  455. raise ValueError(
  456. "The value of state cannot be None when processing an event."
  457. )
  458. return await self.process_event(
  459. handler=handler,
  460. state=substate,
  461. payload=event.payload,
  462. token=event.token,
  463. )
  464. async def process_event(
  465. self, handler: EventHandler, state: State, payload: Dict, token: str
  466. ) -> StateUpdate:
  467. """Process event.
  468. Args:
  469. handler: Eventhandler to process.
  470. state: State to process the handler.
  471. payload: The event payload.
  472. token: Client token.
  473. Returns:
  474. The state update after processing the event.
  475. """
  476. fn = functools.partial(handler.fn, state)
  477. try:
  478. if asyncio.iscoroutinefunction(fn.func):
  479. events = await fn(**payload)
  480. else:
  481. events = fn(**payload)
  482. except Exception:
  483. error = traceback.format_exc()
  484. print(error)
  485. events = fix_events(
  486. [window_alert("An error occurred. See logs for details.")], token
  487. )
  488. return StateUpdate(events=events)
  489. # Fix the returned events.
  490. events = fix_events(events, token)
  491. # Get the delta after processing the event.
  492. delta = self.get_delta()
  493. # Reset the dirty vars.
  494. self.clean()
  495. # Return the state update.
  496. return StateUpdate(delta=delta, events=events)
  497. def get_delta(self) -> Delta:
  498. """Get the delta for the state.
  499. Returns:
  500. The delta for the state.
  501. """
  502. delta = {}
  503. # Return the dirty vars, as well as all computed vars.
  504. subdelta = {
  505. prop: getattr(self, prop)
  506. for prop in self.dirty_vars | self.computed_vars.keys()
  507. }
  508. if len(subdelta) > 0:
  509. delta[self.get_full_name()] = subdelta
  510. # Recursively find the substate deltas.
  511. substates = self.substates
  512. for substate in self.dirty_substates:
  513. delta.update(substates[substate].get_delta())
  514. # Format the delta.
  515. delta = format.format_state(delta)
  516. # Return the delta.
  517. return delta
  518. def mark_dirty(self):
  519. """Mark the substate and all parent states as dirty."""
  520. if self.parent_state is not None:
  521. self.parent_state.dirty_substates.add(self.get_name())
  522. self.parent_state.mark_dirty()
  523. def clean(self):
  524. """Reset the dirty vars."""
  525. # Recursively clean the substates.
  526. for substate in self.dirty_substates:
  527. self.substates[substate].clean()
  528. # Clean this state.
  529. self.dirty_vars = set()
  530. self.dirty_substates = set()
  531. def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
  532. """Convert the object to a dictionary.
  533. Args:
  534. include_computed: Whether to include computed vars.
  535. **kwargs: Kwargs to pass to the pydantic dict method.
  536. Returns:
  537. The object as a dictionary.
  538. """
  539. base_vars = {
  540. prop_name: self.get_value(getattr(self, prop_name))
  541. for prop_name in self.base_vars
  542. }
  543. computed_vars = (
  544. {
  545. # Include the computed vars.
  546. prop_name: self.get_value(getattr(self, prop_name))
  547. for prop_name in self.computed_vars
  548. }
  549. if include_computed
  550. else {}
  551. )
  552. substate_vars = {
  553. k: v.dict(include_computed=include_computed, **kwargs)
  554. for k, v in self.substates.items()
  555. }
  556. variables = {**base_vars, **computed_vars, **substate_vars}
  557. return {k: variables[k] for k in sorted(variables)}
  558. class DefaultState(State):
  559. """The default empty state."""
  560. pass
  561. class StateUpdate(Base):
  562. """A state update sent to the frontend."""
  563. # The state delta.
  564. delta: Delta = {}
  565. # Events to be added to the event queue.
  566. events: List[Event] = []
  567. class StateManager(Base):
  568. """A class to manage many client states."""
  569. # The state class to use.
  570. state: Type[State] = DefaultState
  571. # The mapping of client ids to states.
  572. states: Dict[str, State] = {}
  573. # The token expiration time (s).
  574. token_expiration: int = constants.TOKEN_EXPIRATION
  575. # The redis client to use.
  576. redis: Optional[Redis] = None
  577. def setup(self, state: Type[State]):
  578. """Set up the state manager.
  579. Args:
  580. state: The state class to use.
  581. """
  582. self.state = state
  583. self.redis = prerequisites.get_redis()
  584. def get_state(self, token: str) -> State:
  585. """Get the state for a token.
  586. Args:
  587. token: The token to get the state for.
  588. Returns:
  589. The state for the token.
  590. """
  591. if self.redis is not None:
  592. redis_state = self.redis.get(token)
  593. if redis_state is None:
  594. self.set_state(token, self.state())
  595. return self.get_state(token)
  596. return cloudpickle.loads(redis_state)
  597. if token not in self.states:
  598. self.states[token] = self.state()
  599. return self.states[token]
  600. def set_state(self, token: str, state: State):
  601. """Set the state for a token.
  602. Args:
  603. token: The token to set the state for.
  604. state: The state to set.
  605. """
  606. if self.redis is None:
  607. return
  608. self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  609. def _convert_mutable_datatypes(
  610. field_value: Any, reassign_field: Callable, field_name: str
  611. ) -> Any:
  612. """Recursively convert mutable data to the Pc data types.
  613. Note: right now only list & dict would be handled recursively.
  614. Args:
  615. field_value: The target field_value.
  616. reassign_field:
  617. The function to reassign the field in the parent state.
  618. field_name: the name of the field in the parent state
  619. Returns:
  620. The converted field_value
  621. """
  622. if isinstance(field_value, list):
  623. for index in range(len(field_value)):
  624. field_value[index] = _convert_mutable_datatypes(
  625. field_value[index], reassign_field, field_name
  626. )
  627. field_value = PCList(
  628. field_value, reassign_field=reassign_field, field_name=field_name
  629. )
  630. if isinstance(field_value, dict):
  631. for key, value in field_value.items():
  632. field_value[key] = _convert_mutable_datatypes(
  633. value, reassign_field, field_name
  634. )
  635. field_value = PCDict(
  636. field_value, reassign_field=reassign_field, field_name=field_name
  637. )
  638. return field_value