state.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. """Define the pynecone state specification."""
  2. from __future__ import annotations
  3. import asyncio
  4. import functools
  5. import traceback
  6. from abc import ABC
  7. from collections import defaultdict
  8. from typing import (
  9. Any,
  10. Callable,
  11. ClassVar,
  12. Dict,
  13. List,
  14. Optional,
  15. Sequence,
  16. Set,
  17. Type,
  18. Union,
  19. )
  20. import cloudpickle
  21. import pydantic
  22. from redis import Redis
  23. from pynecone import constants
  24. from pynecone.base import Base
  25. from pynecone.event import Event, EventHandler, fix_events, window_alert
  26. from pynecone.utils import format, prerequisites, types
  27. from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
  # Shape of a state delta as built by ``State.get_delta``: keys are state
  # names, values are the changed vars of that state.
  Delta = Dict[str, Any]
  class State(Base, ABC, extra=pydantic.Extra.allow):
      """The state of the app.

      Subclasses declare base vars (pydantic fields), computed vars, backend
      vars (names accepted by ``types.is_backend_variable``) and event handlers
      (public callables). A class that subclasses another ``State`` subclass
      becomes a substate of it.
      """

      # A map from the var name to the var.
      vars: ClassVar[Dict[str, Var]] = {}

      # The base vars of the class.
      base_vars: ClassVar[Dict[str, BaseVar]] = {}

      # The computed vars of the class.
      computed_vars: ClassVar[Dict[str, ComputedVar]] = {}

      # Vars inherited by the parent state.
      inherited_vars: ClassVar[Dict[str, Var]] = {}

      # Backend vars that are never sent to the client.
      # NOTE(review): this dict lives on the class and ``__setattr__`` writes
      # backend var values into it, so those values are shared by every
      # instance of the class (e.g. across client tokens) — confirm intended.
      backend_vars: ClassVar[Dict[str, Any]] = {}

      # Backend vars inherited from the parent state.
      inherited_backend_vars: ClassVar[Dict[str, Any]] = {}

      # The event handlers.
      event_handlers: ClassVar[Dict[str, EventHandler]] = {}

      # The parent state.
      parent_state: Optional[State] = None

      # The substates of the state.
      substates: Dict[str, State] = {}

      # The set of dirty vars.
      dirty_vars: Set[str] = set()

      # The set of dirty substates.
      dirty_substates: Set[str] = set()

      # The routing path that triggered the state
      router_data: Dict[str, Any] = {}

      # Mapping of var name to set of computed variables that depend on it
      computed_var_dependencies: Dict[str, Set[str]] = {}

      # Whether to track accessed vars.
      track_vars: bool = False

      # The current set of accessed vars during tracking.
      tracked_vars: Set[str] = set()

      def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
          """Initialize the state.

          Args:
              *args: The args to pass to the Pydantic init method.
              parent_state: The parent state.
              **kwargs: The kwargs to pass to the Pydantic init method.
          """
          kwargs["parent_state"] = parent_state
          super().__init__(*args, **kwargs)

          # Setup the substates.
          for substate in self.get_substates():
              self.substates[substate.get_name()] = substate(parent_state=self)

          # Convert the event handlers to functions bound to this instance.
          for name, event_handler in self.event_handlers.items():
              fn = functools.partial(event_handler.fn, self)
              fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
              setattr(self, name, fn)

          # Initialize the mutable fields.
          self._init_mutable_fields()

          # Initialize computed vars dependencies: evaluate each computed var
          # once with access tracking on and record which vars it read.
          self.computed_var_dependencies = defaultdict(set)
          for cvar in self.computed_vars:
              self.tracked_vars = set()
              self.track_vars = True
              try:
                  self.__getattribute__(cvar)
              finally:
                  # Never leave tracking enabled if a computed var raises.
                  self.track_vars = False
              # Add the dependencies.
              for var in self.tracked_vars:
                  self.computed_var_dependencies[var].add(cvar)

      def _init_mutable_fields(self):
          """Initialize mutable fields.

          So that mutation to them can be detected by the app:
          * list
          """
          for field in self.base_vars.values():
              value = getattr(self, field.name)
              value_in_pc_data = _convert_mutable_datatypes(
                  value, self._reassign_field, field.name
              )
              if types._issubclass(field.type_, Union[List, Dict]):
                  setattr(self, field.name, value_in_pc_data)
          self.clean()

      def _reassign_field(self, field_name: str):
          """Reassign the given field.

          Primarily for mutation in fields of mutable data types.

          Args:
              field_name: The name of the field we want to reassign
          """
          setattr(
              self,
              field_name,
              getattr(self, field_name),
          )

      def __repr__(self) -> str:
          """Get the string representation of the state.

          Returns:
              The string representation of the state.
          """
          return f"{self.__class__.__name__}({self.dict()})"

      @classmethod
      def __init_subclass__(cls, **kwargs):
          """Do some magic for the subclass initialization.

          Args:
              **kwargs: The kwargs to pass to the pydantic init_subclass method.
          """
          super().__init_subclass__(**kwargs)

          # Get the parent vars.
          parent_state = cls.get_parent_state()
          if parent_state is not None:
              cls.inherited_vars = parent_state.vars
              cls.inherited_backend_vars = parent_state.backend_vars

          cls.new_backend_vars = {
              name: value
              for name, value in cls.__dict__.items()
              if types.is_backend_variable(name)
              and name not in cls.inherited_backend_vars
          }

          cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}

          # Set the base and computed vars.
          cls.base_vars = {
              f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
              for f in cls.get_fields().values()
              if f.name not in cls.get_skip_vars()
          }
          cls.computed_vars = {
              v.name: v.set_state(cls)
              for v in cls.__dict__.values()
              if isinstance(v, ComputedVar)
          }
          cls.vars = {
              **cls.inherited_vars,
              **cls.base_vars,
              **cls.computed_vars,
          }
          cls.computed_var_dependencies = {}
          cls.event_handlers = {}

          # Setup the base vars at the class level.
          for prop in cls.base_vars.values():
              cls._init_var(prop)

          # Set up the event handlers: every public callable defined directly
          # on this class that is not already an EventHandler.
          events = {
              name: fn
              for name, fn in cls.__dict__.items()
              if not name.startswith("_")
              and isinstance(fn, Callable)
              and not isinstance(fn, EventHandler)
          }
          for name, fn in events.items():
              handler = EventHandler(fn=fn)
              cls.event_handlers[name] = handler
              setattr(cls, name, handler)

      @classmethod
      def get_skip_vars(cls) -> Set[str]:
          """Get the vars to skip when serializing.

          Returns:
              The vars to skip when serializing.
          """
          return set(cls.inherited_vars) | {
              "parent_state",
              "substates",
              "dirty_vars",
              "dirty_substates",
              "router_data",
              "computed_var_dependencies",
              "track_vars",
              "tracked_vars",
          }

      @classmethod
      @functools.lru_cache()
      def get_parent_state(cls) -> Optional[Type[State]]:
          """Get the parent state.

          Returns:
              The parent state.
          """
          parent_states = [
              base
              for base in cls.__bases__
              if types._issubclass(base, State) and base is not State
          ]
          assert len(parent_states) < 2, "Only one parent state is allowed."
          return parent_states[0] if len(parent_states) == 1 else None  # type: ignore

      @classmethod
      @functools.lru_cache()
      def get_substates(cls) -> Set[Type[State]]:
          """Get the substates of the state.

          Returns:
              The substates of the state.
          """
          return set(cls.__subclasses__())

      @classmethod
      @functools.lru_cache()
      def get_name(cls) -> str:
          """Get the name of the state.

          Returns:
              The name of the state.
          """
          return format.to_snake_case(cls.__name__)

      @classmethod
      @functools.lru_cache()
      def get_full_name(cls) -> str:
          """Get the full name of the state.

          Returns:
              The full name of the state.
          """
          name = cls.get_name()
          parent_state = cls.get_parent_state()
          if parent_state is not None:
              name = ".".join((parent_state.get_full_name(), name))
          return name

      @classmethod
      @functools.lru_cache()
      def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
          """Get the class substate.

          Args:
              path: The path to the substate (must be hashable, e.g. a tuple,
                  because the result is cached).

          Returns:
              The class substate.

          Raises:
              ValueError: If the substate is not found.
          """
          if len(path) == 0:
              return cls
          if path[0] == cls.get_name():
              if len(path) == 1:
                  return cls
              path = path[1:]
          for substate in cls.get_substates():
              if path[0] == substate.get_name():
                  return substate.get_class_substate(path[1:])
          raise ValueError(f"Invalid path: {path}")

      @classmethod
      def get_class_var(cls, path: Sequence[str]) -> Any:
          """Get the class var.

          Args:
              path: The path to the var.

          Returns:
              The class var.

          Raises:
              ValueError: If the path is invalid.
          """
          path, name = path[:-1], path[-1]
          substate = cls.get_class_substate(tuple(path))
          if not hasattr(substate, name):
              raise ValueError(f"Invalid path: {path}")
          return getattr(substate, name)

      @classmethod
      def _init_var(cls, prop: BaseVar):
          """Initialize a variable.

          Args:
              prop (BaseVar): The variable to initialize

          Raises:
              TypeError: if the variable has an incorrect type
          """
          if not types.is_valid_var_type(prop.type_):
              raise TypeError(
                  "State vars must be primitive Python types, "
                  "Plotly figures, Pandas dataframes, "
                  "or subclasses of pc.Base. "
                  f'Found var "{prop.name}" with type {prop.type_}.'
              )
          cls._set_var(prop)
          cls._create_setter(prop)
          cls._set_default_value(prop)

      @classmethod
      def add_var(cls, name: str, type_: Any, default_value: Any = None):
          """Add dynamically a variable to the State.

          The variable added this way can be used in the same way as a variable
          defined statically in the model.

          Args:
              name: The name of the variable
              type_: The type of the variable
              default_value: The default value of the variable

          Raises:
              NameError: if a variable of this name already exists
          """
          if name in cls.__fields__:
              raise NameError(
                  f"The variable '{name}' already exist. Use a different name"
              )

          # create the variable based on name and type
          var = BaseVar(name=name, type_=type_)
          var.set_state(cls)

          # add the pydantic field dynamically (must be done before _init_var)
          cls.add_field(var, default_value)

          cls._init_var(var)

          # update the internal dicts so the new variable is correctly handled
          cls.base_vars.update({name: var})
          cls.vars.update({name: var})

      @classmethod
      def _set_var(cls, prop: BaseVar):
          """Set the var as a class member.

          Args:
              prop: The var instance to set.
          """
          setattr(cls, prop.name, prop)

      @classmethod
      def _create_setter(cls, prop: BaseVar):
          """Create a setter for the var.

          An explicitly defined setter of the same name on the class wins.

          Args:
              prop: The var to create a setter for.
          """
          setter_name = prop.get_setter_name(include_state=False)
          if setter_name not in cls.__dict__:
              event_handler = EventHandler(fn=prop.get_setter())
              cls.event_handlers[setter_name] = event_handler
              setattr(cls, setter_name, event_handler)

      @classmethod
      def _set_default_value(cls, prop: BaseVar):
          """Set the default value for the var.

          Args:
              prop: The var to set the default value for.
          """
          # Get the pydantic field for the var.
          field = cls.get_fields()[prop.name]
          default_value = prop.get_default_value()
          if field.required and default_value is not None:
              field.required = False
              field.default = default_value

      def get_token(self) -> str:
          """Return the token of the client associated with this state.

          Returns:
              The token of the client.
          """
          return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")

      def get_sid(self) -> str:
          """Return the session ID of the client associated with this state.

          Returns:
              The session ID of the client.
          """
          return self.router_data.get(constants.RouteVar.SESSION_ID, "")

      def get_headers(self) -> Dict:
          """Return the headers of the client associated with this state.

          Returns:
              The headers of the client.
          """
          return self.router_data.get(constants.RouteVar.HEADERS, {})

      def get_client_ip(self) -> str:
          """Return the IP of the client associated with this state.

          Returns:
              The IP of the client.
          """
          return self.router_data.get(constants.RouteVar.CLIENT_IP, "")

      def get_current_page(self) -> str:
          """Obtain the path of current page from the router data.

          Returns:
              The current page.
          """
          return self.router_data.get(constants.RouteVar.PATH, "")

      def get_query_params(self) -> Dict[str, str]:
          """Obtain the query parameters for the queried page.

          The query object contains both the URI parameters and the GET parameters.

          Returns:
              The dict of query parameters.
          """
          return self.router_data.get(constants.RouteVar.QUERY, {})

      @classmethod
      def setup_dynamic_args(cls, args: dict[str, str]):
          """Set up args for easy access in renderer.

          Args:
              args: a dict of args
          """

          def argsingle_factory(param):
              @ComputedVar
              def inner_func(self) -> str:
                  return self.get_query_params().get(param, "")

              return inner_func

          def arglist_factory(param):
              @ComputedVar
              def inner_func(self) -> List:
                  return self.get_query_params().get(param, [])

              return inner_func

          for param, value in args.items():
              if value == constants.RouteArgType.SINGLE:
                  func = argsingle_factory(param)
              elif value == constants.RouteArgType.LIST:
                  func = arglist_factory(param)
              else:
                  continue
              cls.computed_vars[param] = func.set_state(cls)  # type: ignore
              setattr(cls, param, func)

      def __getattribute__(self, name: str) -> Any:
          """Get the state var.

          If the var is inherited, get the var from the parent state.
          If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.

          Args:
              name: The name of the var.

          Returns:
              The value of the var.
          """
          # If the state hasn't been initialized yet, return the default value.
          if not super().__getattribute__("__dict__"):
              return super().__getattribute__(name)

          # Check if tracking is enabled.
          if super().__getattribute__("track_vars"):
              # Get the non-computed vars.
              all_vars = {
                  **super().__getattribute__("vars"),
                  **super().__getattribute__("backend_vars"),
              }
              # Add the var to the tracked vars.
              if name in all_vars:
                  super().__getattribute__("tracked_vars").add(name)

          inherited_vars = {
              **super().__getattribute__("inherited_vars"),
              **super().__getattribute__("inherited_backend_vars"),
          }
          if name in inherited_vars:
              return getattr(super().__getattribute__("parent_state"), name)
          elif name in super().__getattribute__("backend_vars"):
              return super().__getattribute__("backend_vars").__getitem__(name)
          return super().__getattribute__(name)

      def __setattr__(self, name: str, value: Any):
          """Set the attribute.

          If the attribute is inherited, set the attribute on the parent state.

          Args:
              name: The name of the attribute.
              value: The value of the attribute.
          """
          # Set the var on the parent state.
          inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
          if name in inherited_vars:
              setattr(self.parent_state, name, value)
              return

          if types.is_backend_variable(name):
              # NOTE(review): ``backend_vars`` is the class-level dict, so this
              # write is visible to every instance of the class.
              self.backend_vars.__setitem__(name, value)
              self.dirty_vars.add(name)
              self.mark_dirty()
              return

          # Set the attribute.
          super().__setattr__(name, value)

          # Add the var to the dirty list.
          if name in self.vars:
              self.dirty_vars.add(name)
              self.mark_dirty()

      def reset(self):
          """Reset all the base vars to their default values.

          Each default is deep-copied so that mutating the reset value cannot
          alter the field's default, and list/dict values are re-wrapped (as in
          ``_init_mutable_fields``) so later mutations are still detected. The
          reset vars are left marked dirty so the change reaches the client
          when ``reset`` is called from an event handler.
          """
          import copy

          # Reset the base vars.
          fields = self.get_fields()
          for prop_name, prop in self.base_vars.items():
              default = copy.deepcopy(fields[prop_name].default)
              if types._issubclass(prop.type_, Union[List, Dict]):
                  default = _convert_mutable_datatypes(
                      default, self._reassign_field, prop_name
                  )
              setattr(self, prop_name, default)

          # Recursively reset the substates.
          for substate in self.substates.values():
              substate.reset()

      def get_substate(self, path: Sequence[str]) -> Optional[State]:
          """Get the substate.

          Args:
              path: The path to the substate.

          Returns:
              The substate.

          Raises:
              ValueError: If the substate is not found.
          """
          if len(path) == 0:
              return self
          if path[0] == self.get_name():
              if len(path) == 1:
                  return self
              path = path[1:]
          if path[0] not in self.substates:
              raise ValueError(f"Invalid path: {path}")
          return self.substates[path[0]].get_substate(path[1:])

      async def _process(self, event: Event) -> StateUpdate:
          """Obtain event info and process event.

          Args:
              event: The event to process.

          Returns:
              The state update after processing the event.

          Raises:
              ValueError: If the state value is None.
          """
          # Get the event handler.
          path = event.name.split(".")
          path, name = path[:-1], path[-1]
          substate = self.get_substate(path)

          # Validate the substate before dereferencing it.
          if substate is None:
              raise ValueError(
                  "The value of state cannot be None when processing an event."
              )
          handler = substate.event_handlers[name]  # type: ignore

          return await self._process_event(
              handler=handler,
              state=substate,
              payload=event.payload,
              token=event.token,
          )

      async def _process_event(
          self, handler: EventHandler, state: State, payload: Dict, token: str
      ) -> StateUpdate:
          """Process event.

          Args:
              handler: Eventhandler to process.
              state: State to process the handler.
              payload: The event payload.
              token: Client token.

          Returns:
              The state update after processing the event.
          """
          fn = functools.partial(handler.fn, state)
          try:
              if asyncio.iscoroutinefunction(fn.func):
                  events = await fn(**payload)
              else:
                  events = fn(**payload)
          except Exception:
              error = traceback.format_exc()
              print(error)
              events = fix_events(
                  [window_alert("An error occurred. See logs for details.")], token
              )
              return StateUpdate(events=events)

          # Fix the returned events.
          events = fix_events(events, token)

          # Get the delta after processing the event.
          delta = self.get_delta()

          # Reset the dirty vars.
          self.clean()

          # Return the state update.
          return StateUpdate(delta=delta, events=events)

      def _dirty_computed_vars(
          self, from_vars: Optional[Set[str]] = None, check: bool = False
      ) -> Set[str]:
          """Get ComputedVars that need to be recomputed based on dirty_vars.

          Args:
              from_vars: find ComputedVar that depend on this set of vars. If
                  unspecified (None), will use the dirty_vars. An explicitly
                  empty set yields no computed vars.
              check: Whether to perform the check.

          Returns:
              Set of computed vars to include in the delta.
          """
          # If checking is disabled, return all computed vars.
          if not check:
              return set(self.computed_vars)

          # Fall back to dirty_vars only when nothing was passed; an empty set
          # is a legitimate "no vars changed".
          changed_vars = self.dirty_vars if from_vars is None else from_vars

          # Return only the computed vars that depend on the changed vars.
          return {
              cvar
              for var in changed_vars
              for cvar in self.computed_var_dependencies.get(var, set())
              if cvar in self.computed_vars
          }

      def get_delta(self, check: bool = False) -> Delta:
          """Get the delta for the state.

          Args:
              check: Whether to check for dirty computed vars.

          Returns:
              The delta for the state.
          """
          delta = {}

          # Return the dirty vars, as well as computed vars depending on dirty vars.
          subdelta = {
              prop: getattr(self, prop)
              for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
              if not types.is_backend_variable(prop)
          }
          if len(subdelta) > 0:
              delta[self.get_full_name()] = subdelta

          # Recursively find the substate deltas.
          # NOTE(review): ``check`` is not forwarded, so dirty substates always
          # report all of their computed vars — confirm this is intended.
          substates = self.substates
          for substate in self.dirty_substates:
              delta.update(substates[substate].get_delta())

          # Format the delta.
          delta = format.format_state(delta)

          # Return the delta.
          return delta

      def mark_dirty(self):
          """Mark the substate and all parent states as dirty."""
          if self.parent_state is not None:
              self.parent_state.dirty_substates.add(self.get_name())
              self.parent_state.mark_dirty()

      def clean(self):
          """Reset the dirty vars."""
          # Recursively clean the substates.
          for substate in self.dirty_substates:
              self.substates[substate].clean()

          # Clean this state.
          self.dirty_vars = set()
          self.dirty_substates = set()

      def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
          """Convert the object to a dictionary.

          Args:
              include_computed: Whether to include computed vars.
              **kwargs: Kwargs to pass to the pydantic dict method.

          Returns:
              The object as a dictionary.
          """
          base_vars = {
              prop_name: self.get_value(getattr(self, prop_name))
              for prop_name in self.base_vars
          }
          computed_vars = (
              {
                  # Include the computed vars.
                  prop_name: self.get_value(getattr(self, prop_name))
                  for prop_name in self.computed_vars
              }
              if include_computed
              else {}
          )
          substate_vars = {
              k: v.dict(include_computed=include_computed, **kwargs)
              for k, v in self.substates.items()
          }
          variables = {**base_vars, **computed_vars, **substate_vars}
          return {k: variables[k] for k in sorted(variables)}
  class DefaultState(State):
      """Fallback state that declares no vars or event handlers of its own."""
  class StateUpdate(Base):
      """Payload pushed to the frontend after an event has been processed."""

      # Changed state values, in the shape returned by ``State.get_delta``.
      delta: Delta = pydantic.Field(default_factory=dict)

      # Follow-up events for the frontend to append to its event queue.
      events: List[Event] = pydantic.Field(default_factory=list)
  class StateManager(Base):
      """A class to manage many client states.

      States are kept in redis (pickled with cloudpickle) when a redis client
      is configured, otherwise in the in-memory ``states`` mapping.
      """

      # The state class to use.
      state: Type[State] = DefaultState

      # The mapping of client ids to states (used only without redis).
      states: Dict[str, State] = {}

      # The token expiration time (s).
      token_expiration: int = constants.TOKEN_EXPIRATION

      # The redis client to use.
      redis: Optional[Redis] = None

      def setup(self, state: Type[State]):
          """Set up the state manager.

          Args:
              state: The state class to use.
          """
          self.state = state
          self.redis = prerequisites.get_redis()

      def get_state(self, token: str) -> State:
          """Get the state for a token, creating a fresh one if none exists.

          Args:
              token: The token to get the state for.

          Returns:
              The state for the token.
          """
          if self.redis is not None:
              redis_state = self.redis.get(token)
              if redis_state is None:
                  # Persist and return the new state directly instead of
                  # re-reading it: saves a round trip and cannot recurse
                  # forever if the write does not stick.
                  state = self.state()
                  self.set_state(token, state)
                  return state
              # NOTE(review): unpickling trusts whatever is stored in redis;
              # the redis instance must not be writable by untrusted parties.
              return cloudpickle.loads(redis_state)

          if token not in self.states:
              self.states[token] = self.state()
          return self.states[token]

      def set_state(self, token: str, state: State):
          """Set the state for a token.

          Args:
              token: The token to set the state for.
              state: The state to set.
          """
          if self.redis is None:
              # In-memory mode: store the instance so get_state returns it.
              self.states[token] = state
              return
          self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  677. def _convert_mutable_datatypes(
  678. field_value: Any, reassign_field: Callable, field_name: str
  679. ) -> Any:
  680. """Recursively convert mutable data to the Pc data types.
  681. Note: right now only list & dict would be handled recursively.
  682. Args:
  683. field_value: The target field_value.
  684. reassign_field:
  685. The function to reassign the field in the parent state.
  686. field_name: the name of the field in the parent state
  687. Returns:
  688. The converted field_value
  689. """
  690. if isinstance(field_value, list):
  691. for index in range(len(field_value)):
  692. field_value[index] = _convert_mutable_datatypes(
  693. field_value[index], reassign_field, field_name
  694. )
  695. field_value = PCList(
  696. field_value, reassign_field=reassign_field, field_name=field_name
  697. )
  698. if isinstance(field_value, dict):
  699. for key, value in field_value.items():
  700. field_value[key] = _convert_mutable_datatypes(
  701. value, reassign_field, field_name
  702. )
  703. field_value = PCDict(
  704. field_value, reassign_field=reassign_field, field_name=field_name
  705. )
  706. return field_value