state.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. """Define the pynecone state specification."""
  2. from __future__ import annotations
  3. import asyncio
  4. import copy
  5. import functools
  6. import traceback
  7. from abc import ABC
  8. from collections import defaultdict
  9. from typing import (
  10. Any,
  11. Callable,
  12. ClassVar,
  13. Dict,
  14. List,
  15. Optional,
  16. Sequence,
  17. Set,
  18. Type,
  19. Union,
  20. )
  21. import cloudpickle
  22. import pydantic
  23. from redis import Redis
  24. from pynecone import constants
  25. from pynecone.base import Base
  26. from pynecone.event import Event, EventHandler, fix_events, window_alert
  27. from pynecone.utils import format, prerequisites, types
  28. from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
  # Type of a state delta: `get_delta` builds it keyed by a state's full name,
  # each value being a mapping of changed var names to their new values.
  Delta = Dict[str, Any]
  class State(Base, ABC, extra=pydantic.Extra.allow):
      """The state of the app."""

      # A map from the var name to the var.
      vars: ClassVar[Dict[str, Var]] = {}

      # The base vars of the class.
      base_vars: ClassVar[Dict[str, BaseVar]] = {}

      # The computed vars of the class.
      computed_vars: ClassVar[Dict[str, ComputedVar]] = {}

      # Vars inherited by the parent state.
      inherited_vars: ClassVar[Dict[str, Var]] = {}

      # Backend vars that are never sent to the client.
      backend_vars: ClassVar[Dict[str, Any]] = {}

      # Backend vars inherited
      inherited_backend_vars: ClassVar[Dict[str, Any]] = {}

      # The event handlers.
      event_handlers: ClassVar[Dict[str, EventHandler]] = {}

      # The parent state.
      parent_state: Optional[State] = None

      # The substates of the state.
      substates: Dict[str, State] = {}

      # The set of dirty vars.
      dirty_vars: Set[str] = set()

      # The set of dirty substates.
      dirty_substates: Set[str] = set()

      # The routing path that triggered the state
      router_data: Dict[str, Any] = {}

      # Mapping of var name to set of computed variables that depend on it
      computed_var_dependencies: Dict[str, Set[str]] = {}

      # Mapping of var name to set of substates that depend on it
      substate_var_dependencies: Dict[str, Set[str]] = {}

      # Per-instance copy of backend variable values
      _backend_vars: Dict[str, Any] = {}

      def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
          """Initialize the state.

          Args:
              *args: The args to pass to the Pydantic init method.
              parent_state: The parent state.
              **kwargs: The kwargs to pass to the Pydantic init method.
          """
          kwargs["parent_state"] = parent_state
          super().__init__(*args, **kwargs)

          # initialize per-instance var dependency tracking
          self.computed_var_dependencies = defaultdict(set)
          self.substate_var_dependencies = defaultdict(set)

          # Setup the substates.
          for substate in self.get_substates():
              self.substates[substate.get_name()] = substate(parent_state=self)

          # Convert the event handlers to functions bound to this instance.
          for name, event_handler in self.event_handlers.items():
              fn = functools.partial(event_handler.fn, self)
              fn.__module__ = event_handler.fn.__module__  # type: ignore
              fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
              setattr(self, name, fn)

          # Initialize computed vars dependencies.
          inherited_vars = set(self.inherited_vars).union(
              set(self.inherited_backend_vars),
          )
          for cvar_name, cvar in self.computed_vars.items():
              # Add the dependencies.
              for var in cvar.deps(objclass=type(self)):
                  self.computed_var_dependencies[var].add(cvar_name)
                  if var in inherited_vars:
                      # track that this substate depends on its parent for this var
                      state_name = self.get_name()
                      parent_state = self.parent_state
                      while parent_state is not None and var in parent_state.vars:
                          parent_state.substate_var_dependencies[var].add(state_name)
                          state_name, parent_state = (
                              parent_state.get_name(),
                              parent_state.parent_state,
                          )

          # Initialize the mutable fields.
          self._init_mutable_fields()

          # Create a fresh copy of the backend variables for this instance
          self._backend_vars = copy.deepcopy(self.backend_vars)

      def _init_mutable_fields(self):
          """Initialize mutable fields.

          So that mutation to them can be detected by the app:
          * list
          """
          for field in self.base_vars.values():
              value = getattr(self, field.name)

              value_in_pc_data = _convert_mutable_datatypes(
                  value, self._reassign_field, field.name
              )

              if types._issubclass(field.type_, Union[List, Dict]):
                  setattr(self, field.name, value_in_pc_data)

          self.clean()

      def _reassign_field(self, field_name: str):
          """Reassign the given field.

          Primarily for mutation in fields of mutable data types.

          Args:
              field_name: The name of the field we want to reassign
          """
          setattr(
              self,
              field_name,
              getattr(self, field_name),
          )

      def __repr__(self) -> str:
          """Get the string representation of the state.

          Returns:
              The string representation of the state.
          """
          return f"{self.__class__.__name__}({self.dict()})"

      @classmethod
      def __init_subclass__(cls, **kwargs):
          """Do some magic for the subclass initialization.

          Args:
              **kwargs: The kwargs to pass to the pydantic init_subclass method.
          """
          super().__init_subclass__(**kwargs)

          # Get the parent vars.
          parent_state = cls.get_parent_state()
          if parent_state is not None:
              cls.inherited_vars = parent_state.vars
              cls.inherited_backend_vars = parent_state.backend_vars

          cls.new_backend_vars = {
              name: value
              for name, value in cls.__dict__.items()
              if types.is_backend_variable(name)
              and name not in cls.inherited_backend_vars
          }

          cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}

          # Set the base and computed vars.
          cls.base_vars = {
              f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
              for f in cls.get_fields().values()
              if f.name not in cls.get_skip_vars()
          }
          cls.computed_vars = {
              v.name: v.set_state(cls)
              for v in cls.__dict__.values()
              if isinstance(v, ComputedVar)
          }
          cls.vars = {
              **cls.inherited_vars,
              **cls.base_vars,
              **cls.computed_vars,
          }
          cls.event_handlers = {}

          # Setup the base vars at the class level.
          for prop in cls.base_vars.values():
              cls._init_var(prop)

          # Set up the event handlers.
          events = {
              name: fn
              for name, fn in cls.__dict__.items()
              if not name.startswith("_")
              and isinstance(fn, Callable)
              and not isinstance(fn, EventHandler)
          }
          for name, fn in events.items():
              handler = EventHandler(fn=fn)
              cls.event_handlers[name] = handler
              setattr(cls, name, handler)

      @classmethod
      def get_skip_vars(cls) -> Set[str]:
          """Get the vars to skip when serializing.

          Returns:
              The vars to skip when serializing.
          """
          return set(cls.inherited_vars) | {
              "parent_state",
              "substates",
              "dirty_vars",
              "dirty_substates",
              "router_data",
              "computed_var_dependencies",
              "substate_var_dependencies",
              "_backend_vars",
          }

      @classmethod
      @functools.lru_cache()
      def get_parent_state(cls) -> Optional[Type[State]]:
          """Get the parent state.

          Returns:
              The parent state.
          """
          parent_states = [
              base
              for base in cls.__bases__
              if types._issubclass(base, State) and base is not State
          ]
          assert len(parent_states) < 2, "Only one parent state is allowed."
          return parent_states[0] if len(parent_states) == 1 else None  # type: ignore

      @classmethod
      @functools.lru_cache()
      def get_substates(cls) -> Set[Type[State]]:
          """Get the substates of the state.

          Returns:
              The substates of the state.
          """
          return set(cls.__subclasses__())

      @classmethod
      @functools.lru_cache()
      def get_name(cls) -> str:
          """Get the name of the state.

          Returns:
              The name of the state.
          """
          return format.to_snake_case(cls.__name__)

      @classmethod
      @functools.lru_cache()
      def get_full_name(cls) -> str:
          """Get the full name of the state.

          Returns:
              The full name of the state.
          """
          name = cls.get_name()
          parent_state = cls.get_parent_state()
          if parent_state is not None:
              name = ".".join((parent_state.get_full_name(), name))
          return name

      @classmethod
      @functools.lru_cache()
      def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
          """Get the class substate.

          Args:
              path: The path to the substate. Must be hashable (e.g. a tuple)
                  because the result is memoized with lru_cache.

          Returns:
              The class substate.

          Raises:
              ValueError: If the substate is not found.
          """
          if len(path) == 0:
              return cls
          if path[0] == cls.get_name():
              if len(path) == 1:
                  return cls
              path = path[1:]
          for substate in cls.get_substates():
              if path[0] == substate.get_name():
                  return substate.get_class_substate(path[1:])
          raise ValueError(f"Invalid path: {path}")

      @classmethod
      def get_class_var(cls, path: Sequence[str]) -> Any:
          """Get the class var.

          Args:
              path: The path to the var.

          Returns:
              The class var.

          Raises:
              ValueError: If the path is invalid.
          """
          path, name = path[:-1], path[-1]
          substate = cls.get_class_substate(tuple(path))
          if not hasattr(substate, name):
              raise ValueError(f"Invalid path: {path}")
          return getattr(substate, name)

      @classmethod
      def _init_var(cls, prop: BaseVar):
          """Initialize a variable.

          Args:
              prop (BaseVar): The variable to initialize

          Raises:
              TypeError: if the variable has an incorrect type
          """
          if not types.is_valid_var_type(prop.type_):
              raise TypeError(
                  "State vars must be primitive Python types, "
                  "Plotly figures, Pandas dataframes, "
                  "or subclasses of pc.Base. "
                  f'Found var "{prop.name}" with type {prop.type_}.'
              )
          cls._set_var(prop)
          cls._create_setter(prop)
          cls._set_default_value(prop)

      @classmethod
      def add_var(cls, name: str, type_: Any, default_value: Any = None):
          """Add dynamically a variable to the State.

          The variable added this way can be used in the same way as a variable
          defined statically in the model.

          Args:
              name: The name of the variable
              type_: The type of the variable
              default_value: The default value of the variable

          Raises:
              NameError: if a variable of this name already exists
          """
          if name in cls.__fields__:
              raise NameError(
                  f"The variable '{name}' already exist. Use a different name"
              )

          # create the variable based on name and type
          var = BaseVar(name=name, type_=type_)
          var.set_state(cls)

          # add the pydantic field dynamically (must be done before _init_var)
          cls.add_field(var, default_value)

          cls._init_var(var)

          # update the internal dicts so the new variable is correctly handled
          cls.base_vars.update({name: var})
          cls.vars.update({name: var})

          # let substates know about the new variable
          for substate_class in cls.__subclasses__():
              substate_class.vars.setdefault(name, var)

      @classmethod
      def _set_var(cls, prop: BaseVar):
          """Set the var as a class member.

          Args:
              prop: The var instance to set.
          """
          setattr(cls, prop.name, prop)

      @classmethod
      def _create_setter(cls, prop: BaseVar):
          """Create a setter for the var.

          A user-defined method with the setter's name takes precedence.

          Args:
              prop: The var to create a setter for.
          """
          setter_name = prop.get_setter_name(include_state=False)
          if setter_name not in cls.__dict__:
              event_handler = EventHandler(fn=prop.get_setter())
              cls.event_handlers[setter_name] = event_handler
              setattr(cls, setter_name, event_handler)

      @classmethod
      def _set_default_value(cls, prop: BaseVar):
          """Set the default value for the var.

          Args:
              prop: The var to set the default value for.
          """
          # Get the pydantic field for the var.
          field = cls.get_fields()[prop.name]
          default_value = prop.get_default_value()
          if field.required and default_value is not None:
              field.required = False
              field.default = default_value

      def get_token(self) -> str:
          """Return the token of the client associated with this state.

          Returns:
              The token of the client.
          """
          return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")

      def get_sid(self) -> str:
          """Return the session ID of the client associated with this state.

          Returns:
              The session ID of the client.
          """
          return self.router_data.get(constants.RouteVar.SESSION_ID, "")

      def get_headers(self) -> Dict:
          """Return the headers of the client associated with this state.

          Returns:
              The headers of the client.
          """
          return self.router_data.get(constants.RouteVar.HEADERS, {})

      def get_client_ip(self) -> str:
          """Return the IP of the client associated with this state.

          Returns:
              The IP of the client.
          """
          return self.router_data.get(constants.RouteVar.CLIENT_IP, "")

      def get_current_page(self) -> str:
          """Obtain the path of current page from the router data.

          Returns:
              The current page.
          """
          return self.router_data.get(constants.RouteVar.PATH, "")

      def get_query_params(self) -> Dict[str, str]:
          """Obtain the query parameters for the queried page.

          The query object contains both the URI parameters and the GET parameters.

          Returns:
              The dict of query parameters.
          """
          return self.router_data.get(constants.RouteVar.QUERY, {})

      @classmethod
      def setup_dynamic_args(cls, args: Dict[str, str]):
          """Set up args for easy access in renderer.

          Args:
              args: a dict of args
          """

          def argsingle_factory(param):
              @ComputedVar
              def inner_func(self) -> str:
                  return self.get_query_params().get(param, "")

              return inner_func

          def arglist_factory(param):
              @ComputedVar
              def inner_func(self) -> List:
                  return self.get_query_params().get(param, [])

              return inner_func

          for param, value in args.items():
              if value == constants.RouteArgType.SINGLE:
                  func = argsingle_factory(param)
              elif value == constants.RouteArgType.LIST:
                  func = arglist_factory(param)
              else:
                  continue
              func.fget.__name__ = param  # to allow passing as a prop
              cls.vars[param] = cls.computed_vars[param] = func.set_state(cls)  # type: ignore
              setattr(cls, param, func)

      def __getattribute__(self, name: str) -> Any:
          """Get the state var.

          If the var is inherited, get the var from the parent state.

          Args:
              name: The name of the var.

          Returns:
              The value of the var.
          """
          # If the state hasn't been initialized yet, return the default value.
          if not super().__getattribute__("__dict__"):
              return super().__getattribute__(name)

          # Check both inherited maps directly: this method runs on every
          # attribute access, so avoid building a merged dict each time.
          inherited_vars = super().__getattribute__("inherited_vars")
          inherited_backend_vars = super().__getattribute__("inherited_backend_vars")
          if name in inherited_vars or name in inherited_backend_vars:
              return getattr(super().__getattribute__("parent_state"), name)

          backend_vars = super().__getattribute__("_backend_vars")
          if name in backend_vars:
              return backend_vars[name]

          return super().__getattribute__(name)

      def __setattr__(self, name: str, value: Any):
          """Set the attribute.

          If the attribute is inherited, set the attribute on the parent state.

          Args:
              name: The name of the attribute.
              value: The value of the attribute.
          """
          # Set the var on the parent state.
          if name in self.inherited_vars or name in self.inherited_backend_vars:
              setattr(self.parent_state, name, value)
              return

          # Backend vars live in the per-instance _backend_vars dict.
          if types.is_backend_variable(name) and name != "_backend_vars":
              self._backend_vars.__setitem__(name, value)
              self.dirty_vars.add(name)
              self.mark_dirty()
              return

          # Set the attribute.
          super().__setattr__(name, value)

          # Add the var to the dirty list.
          if name in self.vars or name in self.computed_var_dependencies:
              self.dirty_vars.add(name)
              self.mark_dirty()

          # For now, handle router_data updates as a special case
          if name == constants.ROUTER_DATA:
              self.dirty_vars.add(name)
              self.mark_dirty()
              # propagate router_data updates down the state tree
              for substate in self.substates.values():
                  setattr(substate, name, value)

      def reset(self):
          """Reset all the base vars to their default values.

          Each default is deep-copied before assignment. Assigning the pydantic
          field's default object directly would let a later in-place mutation
          (e.g. ``self.items.append(...)``) modify the default shared by every
          instance. Mutable fields are then re-wrapped so in-place mutations
          are detected again, as they are after ``__init__``.
          """
          # Reset the base vars.
          fields = self.get_fields()
          for prop_name in self.base_vars:
              setattr(self, prop_name, copy.deepcopy(fields[prop_name].default))

          # Re-wrap list/dict vars so mutation tracking works after a reset.
          self._init_mutable_fields()

          # Recursively reset the substates.
          for substate in self.substates.values():
              substate.reset()

          # Clean the state.
          self.clean()

      def get_substate(self, path: Sequence[str]) -> Optional[State]:
          """Get the substate.

          Args:
              path: The path to the substate.

          Returns:
              The substate.

          Raises:
              ValueError: If the substate is not found.
          """
          if len(path) == 0:
              return self
          if path[0] == self.get_name():
              if len(path) == 1:
                  return self
              path = path[1:]
          if path[0] not in self.substates:
              raise ValueError(f"Invalid path: {path}")
          return self.substates[path[0]].get_substate(path[1:])

      async def _process(self, event: Event) -> StateUpdate:
          """Obtain event info and process event.

          Args:
              event: The event to process.

          Returns:
              The state update after processing the event.

          Raises:
              ValueError: If the state value is None.
          """
          # Get the event handler.
          path = event.name.split(".")
          path, name = path[:-1], path[-1]
          substate = self.get_substate(path)
          # Validate before touching substate attributes; otherwise a missing
          # substate would surface as an AttributeError instead of this error.
          if substate is None:
              raise ValueError(
                  "The value of state cannot be None when processing an event."
              )
          handler = substate.event_handlers[name]

          return await self._process_event(
              handler=handler,
              state=substate,
              payload=event.payload,
              token=event.token,
          )

      async def _process_event(
          self, handler: EventHandler, state: State, payload: Dict, token: str
      ) -> StateUpdate:
          """Process event.

          Args:
              handler: Eventhandler to process.
              state: State to process the handler.
              payload: The event payload.
              token: Client token.

          Returns:
              The state update after processing the event.
          """
          fn = functools.partial(handler.fn, state)
          try:
              if asyncio.iscoroutinefunction(fn.func):
                  events = await fn(**payload)
              else:
                  events = fn(**payload)
          except Exception:
              # Log the traceback and tell the client something went wrong;
              # no delta is sent in this case.
              error = traceback.format_exc()
              print(error)
              events = fix_events(
                  [window_alert("An error occurred. See logs for details.")], token
              )
              return StateUpdate(events=events)

          # Fix the returned events.
          events = fix_events(events, token)

          # Get the delta after processing the event.
          delta = self.get_delta()

          # Reset the dirty vars.
          self.clean()

          # Return the state update.
          return StateUpdate(delta=delta, events=events)

      def _always_dirty_computed_vars(self) -> Set[str]:
          """The set of ComputedVars that always need to be recalculated.

          Returns:
              Set of all ComputedVar in this state where cache=False
          """
          return {
              cvar_name
              for cvar_name, cvar in self.computed_vars.items()
              if not cvar.cache
          }

      def _mark_dirty_computed_vars(self) -> None:
          """Mark ComputedVars that need to be recalculated based on dirty_vars."""
          # Iterate until no new computed vars become dirty, so computed vars
          # that depend on other computed vars are marked transitively.
          dirty_vars = self.dirty_vars
          while dirty_vars:
              calc_vars, dirty_vars = dirty_vars, set()
              for cvar in self._dirty_computed_vars(from_vars=calc_vars):
                  self.dirty_vars.add(cvar)
                  dirty_vars.add(cvar)
                  actual_var = self.computed_vars.get(cvar)
                  if actual_var:
                      actual_var.mark_dirty(instance=self)

      def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
          """Determine ComputedVars that need to be recalculated based on the given vars.

          Args:
              from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.

          Returns:
              Set of computed vars to include in the delta.
          """
          return {
              cvar
              for dirty_var in from_vars or self.dirty_vars
              for cvar in self.computed_var_dependencies[dirty_var]
          }

      def get_delta(self) -> Delta:
          """Get the delta for the state.

          Returns:
              The delta for the state.
          """
          delta = {}

          # Apply dirty variables down into substates
          self.dirty_vars.update(self._always_dirty_computed_vars())
          self.mark_dirty()

          # Return the dirty vars for this instance, any cached/dependent computed vars,
          # and always dirty computed vars (cache=False)
          delta_vars = (
              self.dirty_vars.intersection(self.base_vars)
              .union(self._dirty_computed_vars())
              .union(self._always_dirty_computed_vars())
          )

          subdelta = {
              prop: getattr(self, prop)
              for prop in delta_vars
              if not types.is_backend_variable(prop)
          }
          if len(subdelta) > 0:
              delta[self.get_full_name()] = subdelta

          # Recursively find the substate deltas.
          substates = self.substates
          for substate in self.dirty_substates:
              delta.update(substates[substate].get_delta())

          # Format the delta.
          delta = format.format_state(delta)

          # Return the delta.
          return delta

      def mark_dirty(self):
          """Mark the substate and all parent states as dirty."""
          state_name = self.get_name()
          if (
              self.parent_state is not None
              and state_name not in self.parent_state.dirty_substates
          ):
              self.parent_state.dirty_substates.add(state_name)
              self.parent_state.mark_dirty()

          # have to mark computed vars dirty to allow access to newly computed
          # values within the same ComputedVar function
          self._mark_dirty_computed_vars()

          # Propagate dirty var / computed var status into substates
          substates = self.substates
          for var in self.dirty_vars:
              for substate_name in self.substate_var_dependencies[var]:
                  self.dirty_substates.add(substate_name)
                  substate = substates[substate_name]
                  substate.dirty_vars.add(var)
                  substate.mark_dirty()

      def clean(self):
          """Reset the dirty vars."""
          # Recursively clean the substates.
          for substate in self.dirty_substates:
              self.substates[substate].clean()

          # Clean this state.
          self.dirty_vars = set()
          self.dirty_substates = set()

      def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
          """Convert the object to a dictionary.

          Args:
              include_computed: Whether to include computed vars.
              **kwargs: Kwargs to pass to the pydantic dict method.

          Returns:
              The object as a dictionary.
          """
          if include_computed:
              # Apply dirty variables down into substates to allow never-cached ComputedVar to
              # trigger recalculation of dependent vars
              self.dirty_vars.update(self._always_dirty_computed_vars())
              self.mark_dirty()

          base_vars = {
              prop_name: self.get_value(getattr(self, prop_name))
              for prop_name in self.base_vars
          }
          computed_vars = (
              {
                  # Include the computed vars.
                  prop_name: self.get_value(getattr(self, prop_name))
                  for prop_name in self.computed_vars
              }
              if include_computed
              else {}
          )
          substate_vars = {
              k: v.dict(include_computed=include_computed, **kwargs)
              for k, v in self.substates.items()
          }
          variables = {**base_vars, **computed_vars, **substate_vars}
          return {k: variables[k] for k in sorted(variables)}
  class DefaultState(State):
      """The default empty state (the StateManager's state class until setup)."""
  class StateUpdate(Base):
      """A state update sent to the frontend."""

      # Changed vars produced by State.get_delta; empty when nothing changed.
      delta: Delta = {}

      # Events for the frontend to enqueue after applying the delta.
      events: List[Event] = []
  class StateManager(Base):
      """A class to manage many client states."""

      # The state class to use.
      state: Type[State] = DefaultState

      # The mapping of client ids to states (used only when redis is not set).
      states: Dict[str, State] = {}

      # The token expiration time (s).
      token_expiration: int = constants.TOKEN_EXPIRATION

      # The redis client to use.
      redis: Optional[Redis] = None

      def setup(self, state: Type[State]):
          """Set up the state manager.

          Args:
              state: The state class to use.
          """
          self.state = state
          self.redis = prerequisites.get_redis()

      def get_state(self, token: str) -> State:
          """Get the state for a token.

          The first time a token is seen, a fresh state is created; with redis
          configured it is also persisted.

          Args:
              token: The token to get the state for.

          Returns:
              The state for the token.
          """
          if self.redis is not None:
              redis_state = self.redis.get(token)
              if redis_state is None:
                  # Persist the fresh state and return it directly, rather than
                  # recursing and re-reading (and unpickling) it from redis.
                  state = self.state()
                  self.set_state(token, state)
                  return state
              # NOTE(review): unpickling is only safe while the redis instance
              # is trusted (written solely by this server) — confirm deployment.
              return cloudpickle.loads(redis_state)

          if token not in self.states:
              self.states[token] = self.state()
          return self.states[token]

      def set_state(self, token: str, state: State):
          """Set the state for a token.

          Without redis this is a no-op, because get_state hands out the
          instance stored in ``self.states`` and callers mutate it in place.

          Args:
              token: The token to set the state for.
              state: The state to set.
          """
          if self.redis is None:
              return
          self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  def _convert_mutable_datatypes(
      field_value: Any, reassign_field: Callable, field_name: str
  ) -> Any:
      """Recursively wrap lists and dicts in Pc data types.

      Nested items are converted in place first, and then the container itself
      is wrapped, so that mutations at any depth call back into the parent
      state. Values that are neither lists nor dicts are returned unchanged.

      Args:
          field_value: The value to convert.
          reassign_field:
              The function to reassign the field in the parent state.
          field_name: the name of the field in the parent state

      Returns:
          The converted field_value
      """

      def convert(item: Any) -> Any:
          # Same callback and field name at every nesting level.
          return _convert_mutable_datatypes(item, reassign_field, field_name)

      if isinstance(field_value, list):
          for position, item in enumerate(field_value):
              field_value[position] = convert(item)
          field_value = PCList(
              field_value, reassign_field=reassign_field, field_name=field_name
          )

      if isinstance(field_value, dict):
          for key, nested in field_value.items():
              field_value[key] = convert(nested)
          field_value = PCDict(
              field_value, reassign_field=reassign_field, field_name=field_name
          )

      return field_value