state.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884
  1. """Define the pynecone state specification."""
  2. from __future__ import annotations
  3. import asyncio
  4. import copy
  5. import functools
  6. import traceback
  7. from abc import ABC
  8. from collections import defaultdict
  9. from typing import (
  10. Any,
  11. Callable,
  12. ClassVar,
  13. Dict,
  14. List,
  15. Optional,
  16. Sequence,
  17. Set,
  18. Type,
  19. Union,
  20. )
  21. import cloudpickle
  22. import pydantic
  23. from redis import Redis
  24. from pynecone import constants
  25. from pynecone.base import Base
  26. from pynecone.event import Event, EventHandler, fix_events, window_alert
  27. from pynecone.utils import format, prerequisites, types
  28. from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
# A state delta sent to the client. State.get_delta builds one keyed by each
# dirty state's full dotted name, whose values hold that state's changed vars.
Delta = Dict[str, Any]
  30. class State(Base, ABC, extra=pydantic.Extra.allow):
  31. """The state of the app."""
  32. # A map from the var name to the var.
  33. vars: ClassVar[Dict[str, Var]] = {}
  34. # The base vars of the class.
  35. base_vars: ClassVar[Dict[str, BaseVar]] = {}
  36. # The computed vars of the class.
  37. computed_vars: ClassVar[Dict[str, ComputedVar]] = {}
  38. # Vars inherited by the parent state.
  39. inherited_vars: ClassVar[Dict[str, Var]] = {}
  40. # Backend vars that are never sent to the client.
  41. backend_vars: ClassVar[Dict[str, Any]] = {}
  42. # Backend vars inherited
  43. inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
  44. # The event handlers.
  45. event_handlers: ClassVar[Dict[str, EventHandler]] = {}
  46. # The parent state.
  47. parent_state: Optional[State] = None
  48. # The substates of the state.
  49. substates: Dict[str, State] = {}
  50. # The set of dirty vars.
  51. dirty_vars: Set[str] = set()
  52. # The set of dirty substates.
  53. dirty_substates: Set[str] = set()
  54. # The routing path that triggered the state
  55. router_data: Dict[str, Any] = {}
  56. # Mapping of var name to set of computed variables that depend on it
  57. computed_var_dependencies: Dict[str, Set[str]] = {}
  58. # Per-instance copy of backend variable values
  59. _backend_vars: Dict[str, Any] = {}
  60. def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
  61. """Initialize the state.
  62. Args:
  63. *args: The args to pass to the Pydantic init method.
  64. parent_state: The parent state.
  65. **kwargs: The kwargs to pass to the Pydantic init method.
  66. """
  67. kwargs["parent_state"] = parent_state
  68. super().__init__(*args, **kwargs)
  69. # Setup the substates.
  70. for substate in self.get_substates():
  71. self.substates[substate.get_name()] = substate(parent_state=self)
  72. # Convert the event handlers to functions.
  73. for name, event_handler in self.event_handlers.items():
  74. fn = functools.partial(event_handler.fn, self)
  75. fn.__module__ = event_handler.fn.__module__ # type: ignore
  76. fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
  77. setattr(self, name, fn)
  78. # Initialize computed vars dependencies.
  79. self.computed_var_dependencies = defaultdict(set)
  80. for cvar_name, cvar in self.computed_vars.items():
  81. # Add the dependencies.
  82. for var in cvar.deps():
  83. self.computed_var_dependencies[var].add(cvar_name)
  84. # Initialize the mutable fields.
  85. self._init_mutable_fields()
  86. # Create a fresh copy of the backend variables for this instance
  87. self._backend_vars = copy.deepcopy(self.backend_vars)
  88. def _init_mutable_fields(self):
  89. """Initialize mutable fields.
  90. So that mutation to them can be detected by the app:
  91. * list
  92. """
  93. for field in self.base_vars.values():
  94. value = getattr(self, field.name)
  95. value_in_pc_data = _convert_mutable_datatypes(
  96. value, self._reassign_field, field.name
  97. )
  98. if types._issubclass(field.type_, Union[List, Dict]):
  99. setattr(self, field.name, value_in_pc_data)
  100. self.clean()
  101. def _reassign_field(self, field_name: str):
  102. """Reassign the given field.
  103. Primarily for mutation in fields of mutable data types.
  104. Args:
  105. field_name: The name of the field we want to reassign
  106. """
  107. setattr(
  108. self,
  109. field_name,
  110. getattr(self, field_name),
  111. )
  112. def __repr__(self) -> str:
  113. """Get the string representation of the state.
  114. Returns:
  115. The string representation of the state.
  116. """
  117. return f"{self.__class__.__name__}({self.dict()})"
  118. @classmethod
  119. def __init_subclass__(cls, **kwargs):
  120. """Do some magic for the subclass initialization.
  121. Args:
  122. **kwargs: The kwargs to pass to the pydantic init_subclass method.
  123. """
  124. super().__init_subclass__(**kwargs)
  125. # Get the parent vars.
  126. parent_state = cls.get_parent_state()
  127. if parent_state is not None:
  128. cls.inherited_vars = parent_state.vars
  129. cls.inherited_backend_vars = parent_state.backend_vars
  130. cls.new_backend_vars = {
  131. name: value
  132. for name, value in cls.__dict__.items()
  133. if types.is_backend_variable(name)
  134. and name not in cls.inherited_backend_vars
  135. }
  136. cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
  137. # Set the base and computed vars.
  138. cls.base_vars = {
  139. f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
  140. for f in cls.get_fields().values()
  141. if f.name not in cls.get_skip_vars()
  142. }
  143. cls.computed_vars = {
  144. v.name: v.set_state(cls)
  145. for v in cls.__dict__.values()
  146. if isinstance(v, ComputedVar)
  147. }
  148. cls.vars = {
  149. **cls.inherited_vars,
  150. **cls.base_vars,
  151. **cls.computed_vars,
  152. }
  153. cls.event_handlers = {}
  154. # Setup the base vars at the class level.
  155. for prop in cls.base_vars.values():
  156. cls._init_var(prop)
  157. # Set up the event handlers.
  158. events = {
  159. name: fn
  160. for name, fn in cls.__dict__.items()
  161. if not name.startswith("_")
  162. and isinstance(fn, Callable)
  163. and not isinstance(fn, EventHandler)
  164. }
  165. for name, fn in events.items():
  166. handler = EventHandler(fn=fn)
  167. cls.event_handlers[name] = handler
  168. setattr(cls, name, handler)
  169. @classmethod
  170. def get_skip_vars(cls) -> Set[str]:
  171. """Get the vars to skip when serializing.
  172. Returns:
  173. The vars to skip when serializing.
  174. """
  175. return set(cls.inherited_vars) | {
  176. "parent_state",
  177. "substates",
  178. "dirty_vars",
  179. "dirty_substates",
  180. "router_data",
  181. "computed_var_dependencies",
  182. "_backend_vars",
  183. }
  184. @classmethod
  185. @functools.lru_cache()
  186. def get_parent_state(cls) -> Optional[Type[State]]:
  187. """Get the parent state.
  188. Returns:
  189. The parent state.
  190. """
  191. parent_states = [
  192. base
  193. for base in cls.__bases__
  194. if types._issubclass(base, State) and base is not State
  195. ]
  196. assert len(parent_states) < 2, "Only one parent state is allowed."
  197. return parent_states[0] if len(parent_states) == 1 else None # type: ignore
  198. @classmethod
  199. @functools.lru_cache()
  200. def get_substates(cls) -> Set[Type[State]]:
  201. """Get the substates of the state.
  202. Returns:
  203. The substates of the state.
  204. """
  205. return set(cls.__subclasses__())
  206. @classmethod
  207. @functools.lru_cache()
  208. def get_name(cls) -> str:
  209. """Get the name of the state.
  210. Returns:
  211. The name of the state.
  212. """
  213. return format.to_snake_case(cls.__name__)
  214. @classmethod
  215. @functools.lru_cache()
  216. def get_full_name(cls) -> str:
  217. """Get the full name of the state.
  218. Returns:
  219. The full name of the state.
  220. """
  221. name = cls.get_name()
  222. parent_state = cls.get_parent_state()
  223. if parent_state is not None:
  224. name = ".".join((parent_state.get_full_name(), name))
  225. return name
  226. @classmethod
  227. @functools.lru_cache()
  228. def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
  229. """Get the class substate.
  230. Args:
  231. path: The path to the substate.
  232. Returns:
  233. The class substate.
  234. Raises:
  235. ValueError: If the substate is not found.
  236. """
  237. if len(path) == 0:
  238. return cls
  239. if path[0] == cls.get_name():
  240. if len(path) == 1:
  241. return cls
  242. path = path[1:]
  243. for substate in cls.get_substates():
  244. if path[0] == substate.get_name():
  245. return substate.get_class_substate(path[1:])
  246. raise ValueError(f"Invalid path: {path}")
  247. @classmethod
  248. def get_class_var(cls, path: Sequence[str]) -> Any:
  249. """Get the class var.
  250. Args:
  251. path: The path to the var.
  252. Returns:
  253. The class var.
  254. Raises:
  255. ValueError: If the path is invalid.
  256. """
  257. path, name = path[:-1], path[-1]
  258. substate = cls.get_class_substate(tuple(path))
  259. if not hasattr(substate, name):
  260. raise ValueError(f"Invalid path: {path}")
  261. return getattr(substate, name)
  262. @classmethod
  263. def _init_var(cls, prop: BaseVar):
  264. """Initialize a variable.
  265. Args:
  266. prop (BaseVar): The variable to initialize
  267. Raises:
  268. TypeError: if the variable has an incorrect type
  269. """
  270. if not types.is_valid_var_type(prop.type_):
  271. raise TypeError(
  272. "State vars must be primitive Python types, "
  273. "Plotly figures, Pandas dataframes, "
  274. "or subclasses of pc.Base. "
  275. f'Found var "{prop.name}" with type {prop.type_}.'
  276. )
  277. cls._set_var(prop)
  278. cls._create_setter(prop)
  279. cls._set_default_value(prop)
  280. @classmethod
  281. def add_var(cls, name: str, type_: Any, default_value: Any = None):
  282. """Add dynamically a variable to the State.
  283. The variable added this way can be used in the same way as a variable
  284. defined statically in the model.
  285. Args:
  286. name: The name of the variable
  287. type_: The type of the variable
  288. default_value: The default value of the variable
  289. Raises:
  290. NameError: if a variable of this name already exists
  291. """
  292. if name in cls.__fields__:
  293. raise NameError(
  294. f"The variable '{name}' already exist. Use a different name"
  295. )
  296. # create the variable based on name and type
  297. var = BaseVar(name=name, type_=type_)
  298. var.set_state(cls)
  299. # add the pydantic field dynamically (must be done before _init_var)
  300. cls.add_field(var, default_value)
  301. cls._init_var(var)
  302. # update the internal dicts so the new variable is correctly handled
  303. cls.base_vars.update({name: var})
  304. cls.vars.update({name: var})
  305. @classmethod
  306. def _set_var(cls, prop: BaseVar):
  307. """Set the var as a class member.
  308. Args:
  309. prop: The var instance to set.
  310. """
  311. setattr(cls, prop.name, prop)
  312. @classmethod
  313. def _create_setter(cls, prop: BaseVar):
  314. """Create a setter for the var.
  315. Args:
  316. prop: The var to create a setter for.
  317. """
  318. setter_name = prop.get_setter_name(include_state=False)
  319. if setter_name not in cls.__dict__:
  320. event_handler = EventHandler(fn=prop.get_setter())
  321. cls.event_handlers[setter_name] = event_handler
  322. setattr(cls, setter_name, event_handler)
  323. @classmethod
  324. def _set_default_value(cls, prop: BaseVar):
  325. """Set the default value for the var.
  326. Args:
  327. prop: The var to set the default value for.
  328. """
  329. # Get the pydantic field for the var.
  330. field = cls.get_fields()[prop.name]
  331. default_value = prop.get_default_value()
  332. if field.required and default_value is not None:
  333. field.required = False
  334. field.default = default_value
  335. def get_token(self) -> str:
  336. """Return the token of the client associated with this state.
  337. Returns:
  338. The token of the client.
  339. """
  340. return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
  341. def get_sid(self) -> str:
  342. """Return the session ID of the client associated with this state.
  343. Returns:
  344. The session ID of the client.
  345. """
  346. return self.router_data.get(constants.RouteVar.SESSION_ID, "")
  347. def get_headers(self) -> Dict:
  348. """Return the headers of the client associated with this state.
  349. Returns:
  350. The headers of the client.
  351. """
  352. return self.router_data.get(constants.RouteVar.HEADERS, {})
  353. def get_client_ip(self) -> str:
  354. """Return the IP of the client associated with this state.
  355. Returns:
  356. The IP of the client.
  357. """
  358. return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
  359. def get_current_page(self) -> str:
  360. """Obtain the path of current page from the router data.
  361. Returns:
  362. The current page.
  363. """
  364. return self.router_data.get(constants.RouteVar.PATH, "")
  365. def get_query_params(self) -> Dict[str, str]:
  366. """Obtain the query parameters for the queried page.
  367. The query object contains both the URI parameters and the GET parameters.
  368. Returns:
  369. The dict of query parameters.
  370. """
  371. return self.router_data.get(constants.RouteVar.QUERY, {})
  372. @classmethod
  373. def setup_dynamic_args(cls, args: dict[str, str]):
  374. """Set up args for easy access in renderer.
  375. Args:
  376. args: a dict of args
  377. """
  378. def argsingle_factory(param):
  379. @ComputedVar
  380. def inner_func(self) -> str:
  381. return self.get_query_params().get(param, "")
  382. return inner_func
  383. def arglist_factory(param):
  384. @ComputedVar
  385. def inner_func(self) -> List:
  386. return self.get_query_params().get(param, [])
  387. return inner_func
  388. for param, value in args.items():
  389. if value == constants.RouteArgType.SINGLE:
  390. func = argsingle_factory(param)
  391. elif value == constants.RouteArgType.LIST:
  392. func = arglist_factory(param)
  393. else:
  394. continue
  395. # link dynamically created ComputedVar to this state class for dep determination
  396. func.__objclass__ = cls
  397. func.fget.__name__ = param
  398. cls.vars[param] = cls.computed_vars[param] = func.set_state(cls) # type: ignore
  399. setattr(cls, param, func)
  400. def __getattribute__(self, name: str) -> Any:
  401. """Get the state var.
  402. If the var is inherited, get the var from the parent state.
  403. Args:
  404. name: The name of the var.
  405. Returns:
  406. The value of the var.
  407. """
  408. # If the state hasn't been initialized yet, return the default value.
  409. if not super().__getattribute__("__dict__"):
  410. return super().__getattribute__(name)
  411. inherited_vars = {
  412. **super().__getattribute__("inherited_vars"),
  413. **super().__getattribute__("inherited_backend_vars"),
  414. }
  415. if name in inherited_vars:
  416. return getattr(super().__getattribute__("parent_state"), name)
  417. elif name in super().__getattribute__("_backend_vars"):
  418. return super().__getattribute__("_backend_vars").__getitem__(name)
  419. return super().__getattribute__(name)
  420. def __setattr__(self, name: str, value: Any):
  421. """Set the attribute.
  422. If the attribute is inherited, set the attribute on the parent state.
  423. Args:
  424. name: The name of the attribute.
  425. value: The value of the attribute.
  426. """
  427. # Set the var on the parent state.
  428. inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
  429. if name in inherited_vars:
  430. setattr(self.parent_state, name, value)
  431. return
  432. if types.is_backend_variable(name) and name != "_backend_vars":
  433. self._backend_vars.__setitem__(name, value)
  434. self.dirty_vars.add(name)
  435. self.mark_dirty()
  436. return
  437. # Set the attribute.
  438. super().__setattr__(name, value)
  439. # Add the var to the dirty list.
  440. if name in self.vars or name in self.computed_var_dependencies:
  441. self.dirty_vars.add(name)
  442. self.mark_dirty()
  443. def reset(self):
  444. """Reset all the base vars to their default values."""
  445. # Reset the base vars.
  446. fields = self.get_fields()
  447. for prop_name in self.base_vars:
  448. setattr(self, prop_name, fields[prop_name].default)
  449. # Recursively reset the substates.
  450. for substate in self.substates.values():
  451. substate.reset()
  452. # Clean the state.
  453. self.clean()
  454. def get_substate(self, path: Sequence[str]) -> Optional[State]:
  455. """Get the substate.
  456. Args:
  457. path: The path to the substate.
  458. Returns:
  459. The substate.
  460. Raises:
  461. ValueError: If the substate is not found.
  462. """
  463. if len(path) == 0:
  464. return self
  465. if path[0] == self.get_name():
  466. if len(path) == 1:
  467. return self
  468. path = path[1:]
  469. if path[0] not in self.substates:
  470. raise ValueError(f"Invalid path: {path}")
  471. return self.substates[path[0]].get_substate(path[1:])
  472. async def _process(self, event: Event) -> StateUpdate:
  473. """Obtain event info and process event.
  474. Args:
  475. event: The event to process.
  476. Returns:
  477. The state update after processing the event.
  478. Raises:
  479. ValueError: If the state value is None.
  480. """
  481. # Get the event handler.
  482. path = event.name.split(".")
  483. path, name = path[:-1], path[-1]
  484. substate = self.get_substate(path)
  485. handler = substate.event_handlers[name] # type: ignore
  486. if not substate:
  487. raise ValueError(
  488. "The value of state cannot be None when processing an event."
  489. )
  490. return await self._process_event(
  491. handler=handler,
  492. state=substate,
  493. payload=event.payload,
  494. token=event.token,
  495. )
  496. async def _process_event(
  497. self, handler: EventHandler, state: State, payload: Dict, token: str
  498. ) -> StateUpdate:
  499. """Process event.
  500. Args:
  501. handler: Eventhandler to process.
  502. state: State to process the handler.
  503. payload: The event payload.
  504. token: Client token.
  505. Returns:
  506. The state update after processing the event.
  507. """
  508. fn = functools.partial(handler.fn, state)
  509. try:
  510. if asyncio.iscoroutinefunction(fn.func):
  511. events = await fn(**payload)
  512. else:
  513. events = fn(**payload)
  514. except Exception:
  515. error = traceback.format_exc()
  516. print(error)
  517. events = fix_events(
  518. [window_alert("An error occurred. See logs for details.")], token
  519. )
  520. return StateUpdate(events=events)
  521. # Fix the returned events.
  522. events = fix_events(events, token)
  523. # Get the delta after processing the event.
  524. delta = self.get_delta()
  525. # Reset the dirty vars.
  526. self.clean()
  527. # Return the state update.
  528. return StateUpdate(delta=delta, events=events)
  529. def _mark_dirty_computed_vars(self) -> None:
  530. """Mark ComputedVars that need to be recalculated based on dirty_vars."""
  531. # Mark all ComputedVars as dirty.
  532. for cvar in self.computed_vars.values():
  533. cvar.mark_dirty(instance=self)
  534. # TODO: Uncomment the actual implementation below.
  535. # dirty_vars = self.dirty_vars
  536. # while dirty_vars:
  537. # calc_vars, dirty_vars = dirty_vars, set()
  538. # for cvar in self._dirty_computed_vars(from_vars=calc_vars):
  539. # self.dirty_vars.add(cvar)
  540. # dirty_vars.add(cvar)
  541. # actual_var = self.computed_vars.get(cvar)
  542. # if actual_var:
  543. # actual_var.mark_dirty(instance=self)
  544. def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
  545. """Determine ComputedVars that need to be recalculated based on the given vars.
  546. Args:
  547. from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
  548. Returns:
  549. Set of computed vars to include in the delta.
  550. """
  551. return set(self.computed_vars)
  552. # TODO: Uncomment the actual implementation below.
  553. # return set(
  554. # cvar
  555. # for dirty_var in from_vars or self.dirty_vars
  556. # for cvar in self.computed_var_dependencies[dirty_var]
  557. # )
  558. def get_delta(self) -> Delta:
  559. """Get the delta for the state.
  560. Returns:
  561. The delta for the state.
  562. """
  563. delta = {}
  564. # Recursively find the substate deltas.
  565. substates = self.substates
  566. for substate in self.dirty_substates:
  567. delta.update(substates[substate].get_delta())
  568. # Return the dirty vars and dependent computed vars
  569. self._mark_dirty_computed_vars()
  570. delta_vars = self.dirty_vars.intersection(self.base_vars).union(
  571. self._dirty_computed_vars()
  572. )
  573. subdelta = {
  574. prop: getattr(self, prop)
  575. for prop in delta_vars
  576. if not types.is_backend_variable(prop)
  577. }
  578. if len(subdelta) > 0:
  579. delta[self.get_full_name()] = subdelta
  580. # Format the delta.
  581. delta = format.format_state(delta)
  582. # Return the delta.
  583. return delta
  584. def mark_dirty(self):
  585. """Mark the substate and all parent states as dirty."""
  586. if self.parent_state is not None:
  587. self.parent_state.dirty_substates.add(self.get_name())
  588. self.parent_state.mark_dirty()
  589. # have to mark computed vars dirty to allow access to newly computed
  590. # values within the same ComputedVar function
  591. self._mark_dirty_computed_vars()
  592. def clean(self):
  593. """Reset the dirty vars."""
  594. # Recursively clean the substates.
  595. for substate in self.dirty_substates:
  596. self.substates[substate].clean()
  597. # Clean this state.
  598. self.dirty_vars = set()
  599. self.dirty_substates = set()
  600. def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
  601. """Convert the object to a dictionary.
  602. Args:
  603. include_computed: Whether to include computed vars.
  604. **kwargs: Kwargs to pass to the pydantic dict method.
  605. Returns:
  606. The object as a dictionary.
  607. """
  608. base_vars = {
  609. prop_name: self.get_value(getattr(self, prop_name))
  610. for prop_name in self.base_vars
  611. }
  612. computed_vars = (
  613. {
  614. # Include the computed vars.
  615. prop_name: self.get_value(getattr(self, prop_name))
  616. for prop_name in self.computed_vars
  617. }
  618. if include_computed
  619. else {}
  620. )
  621. substate_vars = {
  622. k: v.dict(include_computed=include_computed, **kwargs)
  623. for k, v in self.substates.items()
  624. }
  625. variables = {**base_vars, **computed_vars, **substate_vars}
  626. return {k: variables[k] for k in sorted(variables)}
  627. class DefaultState(State):
  628. """The default empty state."""
  629. pass
class StateUpdate(Base):
    """A state update sent to the frontend.

    State._process_event returns one per processed event: on success it
    carries the delta and the returned events; if the handler raised, it
    carries only an alert event.
    """

    # The state delta (empty when the handler raised).
    delta: Delta = {}

    # Events to be added to the event queue.
    events: List[Event] = []
  636. class StateManager(Base):
  637. """A class to manage many client states."""
  638. # The state class to use.
  639. state: Type[State] = DefaultState
  640. # The mapping of client ids to states.
  641. states: Dict[str, State] = {}
  642. # The token expiration time (s).
  643. token_expiration: int = constants.TOKEN_EXPIRATION
  644. # The redis client to use.
  645. redis: Optional[Redis] = None
  646. def setup(self, state: Type[State]):
  647. """Set up the state manager.
  648. Args:
  649. state: The state class to use.
  650. """
  651. self.state = state
  652. self.redis = prerequisites.get_redis()
  653. def get_state(self, token: str) -> State:
  654. """Get the state for a token.
  655. Args:
  656. token: The token to get the state for.
  657. Returns:
  658. The state for the token.
  659. """
  660. if self.redis is not None:
  661. redis_state = self.redis.get(token)
  662. if redis_state is None:
  663. self.set_state(token, self.state())
  664. return self.get_state(token)
  665. return cloudpickle.loads(redis_state)
  666. if token not in self.states:
  667. self.states[token] = self.state()
  668. return self.states[token]
  669. def set_state(self, token: str, state: State):
  670. """Set the state for a token.
  671. Args:
  672. token: The token to set the state for.
  673. state: The state to set.
  674. """
  675. if self.redis is None:
  676. return
  677. self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  678. def _convert_mutable_datatypes(
  679. field_value: Any, reassign_field: Callable, field_name: str
  680. ) -> Any:
  681. """Recursively convert mutable data to the Pc data types.
  682. Note: right now only list & dict would be handled recursively.
  683. Args:
  684. field_value: The target field_value.
  685. reassign_field:
  686. The function to reassign the field in the parent state.
  687. field_name: the name of the field in the parent state
  688. Returns:
  689. The converted field_value
  690. """
  691. if isinstance(field_value, list):
  692. for index in range(len(field_value)):
  693. field_value[index] = _convert_mutable_datatypes(
  694. field_value[index], reassign_field, field_name
  695. )
  696. field_value = PCList(
  697. field_value, reassign_field=reassign_field, field_name=field_name
  698. )
  699. if isinstance(field_value, dict):
  700. for key, value in field_value.items():
  701. field_value[key] = _convert_mutable_datatypes(
  702. value, reassign_field, field_name
  703. )
  704. field_value = PCDict(
  705. field_value, reassign_field=reassign_field, field_name=field_name
  706. )
  707. return field_value