# state.py
"""Define the pynecone state specification."""
from __future__ import annotations

import asyncio
import copy
import functools
import traceback
from abc import ABC
from collections import defaultdict
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Type,
    Union,
)

import cloudpickle
import pydantic
from redis import Redis

from pynecone import constants
from pynecone.base import Base
from pynecone.event import Event, EventHandler, fix_events, window_alert
from pynecone.utils import format, prerequisites, types
from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
  28. Delta = Dict[str, Any]
  29. class State(Base, ABC, extra=pydantic.Extra.allow):
  30. """The state of the app."""
  31. # A map from the var name to the var.
  32. vars: ClassVar[Dict[str, Var]] = {}
  33. # The base vars of the class.
  34. base_vars: ClassVar[Dict[str, BaseVar]] = {}
  35. # The computed vars of the class.
  36. computed_vars: ClassVar[Dict[str, ComputedVar]] = {}
  37. # Vars inherited by the parent state.
  38. inherited_vars: ClassVar[Dict[str, Var]] = {}
  39. # Backend vars that are never sent to the client.
  40. backend_vars: ClassVar[Dict[str, Any]] = {}
  41. # Backend vars inherited
  42. inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
  43. # The event handlers.
  44. event_handlers: ClassVar[Dict[str, EventHandler]] = {}
  45. # The parent state.
  46. parent_state: Optional[State] = None
  47. # The substates of the state.
  48. substates: Dict[str, State] = {}
  49. # The set of dirty vars.
  50. dirty_vars: Set[str] = set()
  51. # The set of dirty substates.
  52. dirty_substates: Set[str] = set()
  53. # The routing path that triggered the state
  54. router_data: Dict[str, Any] = {}
  55. # Mapping of var name to set of computed variables that depend on it
  56. computed_var_dependencies: Dict[str, Set[str]] = {}
  57. def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
  58. """Initialize the state.
  59. Args:
  60. *args: The args to pass to the Pydantic init method.
  61. parent_state: The parent state.
  62. **kwargs: The kwargs to pass to the Pydantic init method.
  63. """
  64. kwargs["parent_state"] = parent_state
  65. super().__init__(*args, **kwargs)
  66. # Setup the substates.
  67. for substate in self.get_substates():
  68. self.substates[substate.get_name()] = substate(parent_state=self)
  69. # Convert the event handlers to functions.
  70. for name, event_handler in self.event_handlers.items():
  71. fn = functools.partial(event_handler.fn, self)
  72. fn.__module__ = event_handler.fn.__module__ # type: ignore
  73. fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
  74. setattr(self, name, fn)
  75. # Initialize computed vars dependencies.
  76. self.computed_var_dependencies = defaultdict(set)
  77. for cvar_name, cvar in self.computed_vars.items():
  78. # Add the dependencies.
  79. for var in cvar.deps():
  80. self.computed_var_dependencies[var].add(cvar_name)
  81. # Initialize the mutable fields.
  82. self._init_mutable_fields()
  83. def _init_mutable_fields(self):
  84. """Initialize mutable fields.
  85. So that mutation to them can be detected by the app:
  86. * list
  87. """
  88. for field in self.base_vars.values():
  89. value = getattr(self, field.name)
  90. value_in_pc_data = _convert_mutable_datatypes(
  91. value, self._reassign_field, field.name
  92. )
  93. if types._issubclass(field.type_, Union[List, Dict]):
  94. setattr(self, field.name, value_in_pc_data)
  95. self.clean()
  96. def _reassign_field(self, field_name: str):
  97. """Reassign the given field.
  98. Primarily for mutation in fields of mutable data types.
  99. Args:
  100. field_name: The name of the field we want to reassign
  101. """
  102. setattr(
  103. self,
  104. field_name,
  105. getattr(self, field_name),
  106. )
  107. def __repr__(self) -> str:
  108. """Get the string representation of the state.
  109. Returns:
  110. The string representation of the state.
  111. """
  112. return f"{self.__class__.__name__}({self.dict()})"
  113. @classmethod
  114. def __init_subclass__(cls, **kwargs):
  115. """Do some magic for the subclass initialization.
  116. Args:
  117. **kwargs: The kwargs to pass to the pydantic init_subclass method.
  118. """
  119. super().__init_subclass__(**kwargs)
  120. # Get the parent vars.
  121. parent_state = cls.get_parent_state()
  122. if parent_state is not None:
  123. cls.inherited_vars = parent_state.vars
  124. cls.inherited_backend_vars = parent_state.backend_vars
  125. cls.new_backend_vars = {
  126. name: value
  127. for name, value in cls.__dict__.items()
  128. if types.is_backend_variable(name)
  129. and name not in cls.inherited_backend_vars
  130. }
  131. cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
  132. # Set the base and computed vars.
  133. cls.base_vars = {
  134. f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
  135. for f in cls.get_fields().values()
  136. if f.name not in cls.get_skip_vars()
  137. }
  138. cls.computed_vars = {
  139. v.name: v.set_state(cls)
  140. for v in cls.__dict__.values()
  141. if isinstance(v, ComputedVar)
  142. }
  143. cls.vars = {
  144. **cls.inherited_vars,
  145. **cls.base_vars,
  146. **cls.computed_vars,
  147. }
  148. cls.event_handlers = {}
  149. # Setup the base vars at the class level.
  150. for prop in cls.base_vars.values():
  151. cls._init_var(prop)
  152. # Set up the event handlers.
  153. events = {
  154. name: fn
  155. for name, fn in cls.__dict__.items()
  156. if not name.startswith("_")
  157. and isinstance(fn, Callable)
  158. and not isinstance(fn, EventHandler)
  159. }
  160. for name, fn in events.items():
  161. handler = EventHandler(fn=fn)
  162. cls.event_handlers[name] = handler
  163. setattr(cls, name, handler)
  164. @classmethod
  165. def get_skip_vars(cls) -> Set[str]:
  166. """Get the vars to skip when serializing.
  167. Returns:
  168. The vars to skip when serializing.
  169. """
  170. return set(cls.inherited_vars) | {
  171. "parent_state",
  172. "substates",
  173. "dirty_vars",
  174. "dirty_substates",
  175. "router_data",
  176. "computed_var_dependencies",
  177. }
  178. @classmethod
  179. @functools.lru_cache()
  180. def get_parent_state(cls) -> Optional[Type[State]]:
  181. """Get the parent state.
  182. Returns:
  183. The parent state.
  184. """
  185. parent_states = [
  186. base
  187. for base in cls.__bases__
  188. if types._issubclass(base, State) and base is not State
  189. ]
  190. assert len(parent_states) < 2, "Only one parent state is allowed."
  191. return parent_states[0] if len(parent_states) == 1 else None # type: ignore
  192. @classmethod
  193. @functools.lru_cache()
  194. def get_substates(cls) -> Set[Type[State]]:
  195. """Get the substates of the state.
  196. Returns:
  197. The substates of the state.
  198. """
  199. return set(cls.__subclasses__())
  200. @classmethod
  201. @functools.lru_cache()
  202. def get_name(cls) -> str:
  203. """Get the name of the state.
  204. Returns:
  205. The name of the state.
  206. """
  207. return format.to_snake_case(cls.__name__)
  208. @classmethod
  209. @functools.lru_cache()
  210. def get_full_name(cls) -> str:
  211. """Get the full name of the state.
  212. Returns:
  213. The full name of the state.
  214. """
  215. name = cls.get_name()
  216. parent_state = cls.get_parent_state()
  217. if parent_state is not None:
  218. name = ".".join((parent_state.get_full_name(), name))
  219. return name
  220. @classmethod
  221. @functools.lru_cache()
  222. def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
  223. """Get the class substate.
  224. Args:
  225. path: The path to the substate.
  226. Returns:
  227. The class substate.
  228. Raises:
  229. ValueError: If the substate is not found.
  230. """
  231. if len(path) == 0:
  232. return cls
  233. if path[0] == cls.get_name():
  234. if len(path) == 1:
  235. return cls
  236. path = path[1:]
  237. for substate in cls.get_substates():
  238. if path[0] == substate.get_name():
  239. return substate.get_class_substate(path[1:])
  240. raise ValueError(f"Invalid path: {path}")
  241. @classmethod
  242. def get_class_var(cls, path: Sequence[str]) -> Any:
  243. """Get the class var.
  244. Args:
  245. path: The path to the var.
  246. Returns:
  247. The class var.
  248. Raises:
  249. ValueError: If the path is invalid.
  250. """
  251. path, name = path[:-1], path[-1]
  252. substate = cls.get_class_substate(tuple(path))
  253. if not hasattr(substate, name):
  254. raise ValueError(f"Invalid path: {path}")
  255. return getattr(substate, name)
  256. @classmethod
  257. def _init_var(cls, prop: BaseVar):
  258. """Initialize a variable.
  259. Args:
  260. prop (BaseVar): The variable to initialize
  261. Raises:
  262. TypeError: if the variable has an incorrect type
  263. """
  264. if not types.is_valid_var_type(prop.type_):
  265. raise TypeError(
  266. "State vars must be primitive Python types, "
  267. "Plotly figures, Pandas dataframes, "
  268. "or subclasses of pc.Base. "
  269. f'Found var "{prop.name}" with type {prop.type_}.'
  270. )
  271. cls._set_var(prop)
  272. cls._create_setter(prop)
  273. cls._set_default_value(prop)
  274. @classmethod
  275. def add_var(cls, name: str, type_: Any, default_value: Any = None):
  276. """Add dynamically a variable to the State.
  277. The variable added this way can be used in the same way as a variable
  278. defined statically in the model.
  279. Args:
  280. name: The name of the variable
  281. type_: The type of the variable
  282. default_value: The default value of the variable
  283. Raises:
  284. NameError: if a variable of this name already exists
  285. """
  286. if name in cls.__fields__:
  287. raise NameError(
  288. f"The variable '{name}' already exist. Use a different name"
  289. )
  290. # create the variable based on name and type
  291. var = BaseVar(name=name, type_=type_)
  292. var.set_state(cls)
  293. # add the pydantic field dynamically (must be done before _init_var)
  294. cls.add_field(var, default_value)
  295. cls._init_var(var)
  296. # update the internal dicts so the new variable is correctly handled
  297. cls.base_vars.update({name: var})
  298. cls.vars.update({name: var})
  299. @classmethod
  300. def _set_var(cls, prop: BaseVar):
  301. """Set the var as a class member.
  302. Args:
  303. prop: The var instance to set.
  304. """
  305. setattr(cls, prop.name, prop)
  306. @classmethod
  307. def _create_setter(cls, prop: BaseVar):
  308. """Create a setter for the var.
  309. Args:
  310. prop: The var to create a setter for.
  311. """
  312. setter_name = prop.get_setter_name(include_state=False)
  313. if setter_name not in cls.__dict__:
  314. event_handler = EventHandler(fn=prop.get_setter())
  315. cls.event_handlers[setter_name] = event_handler
  316. setattr(cls, setter_name, event_handler)
  317. @classmethod
  318. def _set_default_value(cls, prop: BaseVar):
  319. """Set the default value for the var.
  320. Args:
  321. prop: The var to set the default value for.
  322. """
  323. # Get the pydantic field for the var.
  324. field = cls.get_fields()[prop.name]
  325. default_value = prop.get_default_value()
  326. if field.required and default_value is not None:
  327. field.required = False
  328. field.default = default_value
  329. def get_token(self) -> str:
  330. """Return the token of the client associated with this state.
  331. Returns:
  332. The token of the client.
  333. """
  334. return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
  335. def get_sid(self) -> str:
  336. """Return the session ID of the client associated with this state.
  337. Returns:
  338. The session ID of the client.
  339. """
  340. return self.router_data.get(constants.RouteVar.SESSION_ID, "")
  341. def get_headers(self) -> Dict:
  342. """Return the headers of the client associated with this state.
  343. Returns:
  344. The headers of the client.
  345. """
  346. return self.router_data.get(constants.RouteVar.HEADERS, {})
  347. def get_client_ip(self) -> str:
  348. """Return the IP of the client associated with this state.
  349. Returns:
  350. The IP of the client.
  351. """
  352. return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
  353. def get_current_page(self) -> str:
  354. """Obtain the path of current page from the router data.
  355. Returns:
  356. The current page.
  357. """
  358. return self.router_data.get(constants.RouteVar.PATH, "")
  359. def get_query_params(self) -> Dict[str, str]:
  360. """Obtain the query parameters for the queried page.
  361. The query object contains both the URI parameters and the GET parameters.
  362. Returns:
  363. The dict of query parameters.
  364. """
  365. return self.router_data.get(constants.RouteVar.QUERY, {})
  366. @classmethod
  367. def setup_dynamic_args(cls, args: dict[str, str]):
  368. """Set up args for easy access in renderer.
  369. Args:
  370. args: a dict of args
  371. """
  372. def argsingle_factory(param):
  373. @ComputedVar
  374. def inner_func(self) -> str:
  375. return self.get_query_params().get(param, "")
  376. return inner_func
  377. def arglist_factory(param):
  378. @ComputedVar
  379. def inner_func(self) -> List:
  380. return self.get_query_params().get(param, [])
  381. return inner_func
  382. for param, value in args.items():
  383. if value == constants.RouteArgType.SINGLE:
  384. func = argsingle_factory(param)
  385. elif value == constants.RouteArgType.LIST:
  386. func = arglist_factory(param)
  387. else:
  388. continue
  389. cls.computed_vars[param] = func.set_state(cls) # type: ignore
  390. setattr(cls, param, func)
  391. def __getattribute__(self, name: str) -> Any:
  392. """Get the state var.
  393. If the var is inherited, get the var from the parent state.
  394. Args:
  395. name: The name of the var.
  396. Returns:
  397. The value of the var.
  398. """
  399. # If the state hasn't been initialized yet, return the default value.
  400. if not super().__getattribute__("__dict__"):
  401. return super().__getattribute__(name)
  402. inherited_vars = {
  403. **super().__getattribute__("inherited_vars"),
  404. **super().__getattribute__("inherited_backend_vars"),
  405. }
  406. if name in inherited_vars:
  407. return getattr(super().__getattribute__("parent_state"), name)
  408. elif name in super().__getattribute__("backend_vars"):
  409. return super().__getattribute__("backend_vars").__getitem__(name)
  410. return super().__getattribute__(name)
  411. def __setattr__(self, name: str, value: Any):
  412. """Set the attribute.
  413. If the attribute is inherited, set the attribute on the parent state.
  414. Args:
  415. name: The name of the attribute.
  416. value: The value of the attribute.
  417. """
  418. # Set the var on the parent state.
  419. inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
  420. if name in inherited_vars:
  421. setattr(self.parent_state, name, value)
  422. return
  423. if types.is_backend_variable(name):
  424. self.backend_vars.__setitem__(name, value)
  425. self.dirty_vars.add(name)
  426. self.mark_dirty()
  427. return
  428. # Set the attribute.
  429. super().__setattr__(name, value)
  430. # Add the var to the dirty list.
  431. if name in self.vars:
  432. self.dirty_vars.add(name)
  433. self.mark_dirty()
  434. def reset(self):
  435. """Reset all the base vars to their default values."""
  436. # Reset the base vars.
  437. fields = self.get_fields()
  438. for prop_name in self.base_vars:
  439. setattr(self, prop_name, fields[prop_name].default)
  440. # Recursively reset the substates.
  441. for substate in self.substates.values():
  442. substate.reset()
  443. # Clean the state.
  444. self.clean()
  445. def get_substate(self, path: Sequence[str]) -> Optional[State]:
  446. """Get the substate.
  447. Args:
  448. path: The path to the substate.
  449. Returns:
  450. The substate.
  451. Raises:
  452. ValueError: If the substate is not found.
  453. """
  454. if len(path) == 0:
  455. return self
  456. if path[0] == self.get_name():
  457. if len(path) == 1:
  458. return self
  459. path = path[1:]
  460. if path[0] not in self.substates:
  461. raise ValueError(f"Invalid path: {path}")
  462. return self.substates[path[0]].get_substate(path[1:])
  463. async def _process(self, event: Event) -> StateUpdate:
  464. """Obtain event info and process event.
  465. Args:
  466. event: The event to process.
  467. Returns:
  468. The state update after processing the event.
  469. Raises:
  470. ValueError: If the state value is None.
  471. """
  472. # Get the event handler.
  473. path = event.name.split(".")
  474. path, name = path[:-1], path[-1]
  475. substate = self.get_substate(path)
  476. handler = substate.event_handlers[name] # type: ignore
  477. if not substate:
  478. raise ValueError(
  479. "The value of state cannot be None when processing an event."
  480. )
  481. return await self._process_event(
  482. handler=handler,
  483. state=substate,
  484. payload=event.payload,
  485. token=event.token,
  486. )
  487. async def _process_event(
  488. self, handler: EventHandler, state: State, payload: Dict, token: str
  489. ) -> StateUpdate:
  490. """Process event.
  491. Args:
  492. handler: Eventhandler to process.
  493. state: State to process the handler.
  494. payload: The event payload.
  495. token: Client token.
  496. Returns:
  497. The state update after processing the event.
  498. """
  499. fn = functools.partial(handler.fn, state)
  500. try:
  501. if asyncio.iscoroutinefunction(fn.func):
  502. events = await fn(**payload)
  503. else:
  504. events = fn(**payload)
  505. except Exception:
  506. error = traceback.format_exc()
  507. print(error)
  508. events = fix_events(
  509. [window_alert("An error occurred. See logs for details.")], token
  510. )
  511. return StateUpdate(events=events)
  512. # Fix the returned events.
  513. events = fix_events(events, token)
  514. # Get the delta after processing the event.
  515. delta = self.get_delta()
  516. # Reset the dirty vars.
  517. self.clean()
  518. # Return the state update.
  519. return StateUpdate(delta=delta, events=events)
  520. def _mark_dirty_computed_vars(self) -> None:
  521. """Mark ComputedVars that need to be recalculated based on dirty_vars."""
  522. dirty_vars = self.dirty_vars
  523. while dirty_vars:
  524. calc_vars, dirty_vars = dirty_vars, set()
  525. for cvar in self._dirty_computed_vars(from_vars=calc_vars):
  526. self.dirty_vars.add(cvar)
  527. dirty_vars.add(cvar)
  528. actual_var = self.computed_vars.get(cvar)
  529. if actual_var:
  530. actual_var.mark_dirty(instance=self)
  531. def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
  532. """Determine ComputedVars that need to be recalculated based on the given vars.
  533. Args:
  534. from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
  535. Returns:
  536. Set of computed vars to include in the delta.
  537. """
  538. return set(
  539. cvar
  540. for dirty_var in from_vars or self.dirty_vars
  541. for cvar in self.computed_var_dependencies[dirty_var]
  542. )
  543. def get_delta(self) -> Delta:
  544. """Get the delta for the state.
  545. Returns:
  546. The delta for the state.
  547. """
  548. delta = {}
  549. # Recursively find the substate deltas.
  550. substates = self.substates
  551. for substate in self.dirty_substates:
  552. delta.update(substates[substate].get_delta())
  553. # Return the dirty vars and dependent computed vars
  554. delta_vars = self.dirty_vars.intersection(self.base_vars).union(
  555. self._dirty_computed_vars()
  556. )
  557. subdelta = {
  558. prop: getattr(self, prop)
  559. for prop in delta_vars
  560. if not types.is_backend_variable(prop)
  561. }
  562. if len(subdelta) > 0:
  563. delta[self.get_full_name()] = subdelta
  564. # Format the delta.
  565. delta = format.format_state(delta)
  566. # Return the delta.
  567. return delta
  568. def mark_dirty(self):
  569. """Mark the substate and all parent states as dirty."""
  570. if self.parent_state is not None:
  571. self.parent_state.dirty_substates.add(self.get_name())
  572. self.parent_state.mark_dirty()
  573. # have to mark computed vars dirty to allow access to newly computed
  574. # values within the same ComputedVar function
  575. self._mark_dirty_computed_vars()
  576. def clean(self):
  577. """Reset the dirty vars."""
  578. # Recursively clean the substates.
  579. for substate in self.dirty_substates:
  580. self.substates[substate].clean()
  581. # Clean this state.
  582. self.dirty_vars = set()
  583. self.dirty_substates = set()
  584. def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
  585. """Convert the object to a dictionary.
  586. Args:
  587. include_computed: Whether to include computed vars.
  588. **kwargs: Kwargs to pass to the pydantic dict method.
  589. Returns:
  590. The object as a dictionary.
  591. """
  592. base_vars = {
  593. prop_name: self.get_value(getattr(self, prop_name))
  594. for prop_name in self.base_vars
  595. }
  596. computed_vars = (
  597. {
  598. # Include the computed vars.
  599. prop_name: self.get_value(getattr(self, prop_name))
  600. for prop_name in self.computed_vars
  601. }
  602. if include_computed
  603. else {}
  604. )
  605. substate_vars = {
  606. k: v.dict(include_computed=include_computed, **kwargs)
  607. for k, v in self.substates.items()
  608. }
  609. variables = {**base_vars, **computed_vars, **substate_vars}
  610. return {k: variables[k] for k in sorted(variables)}
  611. class DefaultState(State):
  612. """The default empty state."""
  613. pass
  614. class StateUpdate(Base):
  615. """A state update sent to the frontend."""
  616. # The state delta.
  617. delta: Delta = {}
  618. # Events to be added to the event queue.
  619. events: List[Event] = []
  620. class StateManager(Base):
  621. """A class to manage many client states."""
  622. # The state class to use.
  623. state: Type[State] = DefaultState
  624. # The mapping of client ids to states.
  625. states: Dict[str, State] = {}
  626. # The token expiration time (s).
  627. token_expiration: int = constants.TOKEN_EXPIRATION
  628. # The redis client to use.
  629. redis: Optional[Redis] = None
  630. def setup(self, state: Type[State]):
  631. """Set up the state manager.
  632. Args:
  633. state: The state class to use.
  634. """
  635. self.state = state
  636. self.redis = prerequisites.get_redis()
  637. def get_state(self, token: str) -> State:
  638. """Get the state for a token.
  639. Args:
  640. token: The token to get the state for.
  641. Returns:
  642. The state for the token.
  643. """
  644. if self.redis is not None:
  645. redis_state = self.redis.get(token)
  646. if redis_state is None:
  647. self.set_state(token, self.state())
  648. return self.get_state(token)
  649. return cloudpickle.loads(redis_state)
  650. if token not in self.states:
  651. self.states[token] = self.state()
  652. return self.states[token]
  653. def set_state(self, token: str, state: State):
  654. """Set the state for a token.
  655. Args:
  656. token: The token to set the state for.
  657. state: The state to set.
  658. """
  659. if self.redis is None:
  660. return
  661. self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  662. def _convert_mutable_datatypes(
  663. field_value: Any, reassign_field: Callable, field_name: str
  664. ) -> Any:
  665. """Recursively convert mutable data to the Pc data types.
  666. Note: right now only list & dict would be handled recursively.
  667. Args:
  668. field_value: The target field_value.
  669. reassign_field:
  670. The function to reassign the field in the parent state.
  671. field_name: the name of the field in the parent state
  672. Returns:
  673. The converted field_value
  674. """
  675. if isinstance(field_value, list):
  676. for index in range(len(field_value)):
  677. field_value[index] = _convert_mutable_datatypes(
  678. field_value[index], reassign_field, field_name
  679. )
  680. field_value = PCList(
  681. field_value, reassign_field=reassign_field, field_name=field_name
  682. )
  683. if isinstance(field_value, dict):
  684. for key, value in field_value.items():
  685. field_value[key] = _convert_mutable_datatypes(
  686. value, reassign_field, field_name
  687. )
  688. field_value = PCDict(
  689. field_value, reassign_field=reassign_field, field_name=field_name
  690. )
  691. return field_value