state.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872
  1. """Define the pynecone state specification."""
from __future__ import annotations

import asyncio
import copy
import functools
import inspect
import traceback
from abc import ABC
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import cloudpickle
from redis import Redis

from pynecone import constants
from pynecone.base import Base
from pynecone.event import Event, EventHandler, fix_events, window_alert
from pynecone.utils import format, prerequisites, types
from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
# A state delta sent to the client. State.get_delta builds it keyed by each
# dirty state's full name, with a dict of that state's changed var values.
Delta = Dict[str, Any]
  29. class State(Base, ABC):
  30. """The state of the app."""
  31. # A map from the var name to the var.
  32. vars: ClassVar[Dict[str, Var]] = {}
  33. # The base vars of the class.
  34. base_vars: ClassVar[Dict[str, BaseVar]] = {}
  35. # The computed vars of the class.
  36. computed_vars: ClassVar[Dict[str, ComputedVar]] = {}
  37. # Vars inherited by the parent state.
  38. inherited_vars: ClassVar[Dict[str, Var]] = {}
  39. # Backend vars that are never sent to the client.
  40. backend_vars: ClassVar[Dict[str, Any]] = {}
  41. # Backend vars inherited
  42. inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
  43. # Mapping of var name to set of computed variables that depend on it
  44. computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
  45. # The event handlers.
  46. event_handlers: ClassVar[Dict[str, EventHandler]] = {}
  47. # The parent state.
  48. parent_state: Optional[State] = None
  49. # The substates of the state.
  50. substates: Dict[str, State] = {}
  51. # The set of dirty vars.
  52. dirty_vars: Set[str] = set()
  53. # The set of dirty substates.
  54. dirty_substates: Set[str] = set()
  55. # The routing path that triggered the state
  56. router_data: Dict[str, Any] = {}
  57. def __init__(self, *args, **kwargs):
  58. """Initialize the state.
  59. Args:
  60. *args: The args to pass to the Pydantic init method.
  61. **kwargs: The kwargs to pass to the Pydantic init method.
  62. """
  63. super().__init__(*args, **kwargs)
  64. # Setup the substates.
  65. for substate in self.get_substates():
  66. self.substates[substate.get_name()] = substate().set(parent_state=self)
  67. self._init_mutable_fields()
  68. def _init_mutable_fields(self):
  69. """Initialize mutable fields.
  70. So that mutation to them can be detected by the app:
  71. * list
  72. """
  73. for field in self.base_vars.values():
  74. value = getattr(self, field.name)
  75. value_in_pc_data = _convert_mutable_datatypes(
  76. value, self._reassign_field, field.name
  77. )
  78. if types._issubclass(field.type_, Union[List, Dict]):
  79. setattr(self, field.name, value_in_pc_data)
  80. self.clean()
  81. def _reassign_field(self, field_name: str):
  82. """Reassign the given field.
  83. Primarily for mutation in fields of mutable data types.
  84. Args:
  85. field_name: The name of the field we want to reassign
  86. """
  87. setattr(
  88. self,
  89. field_name,
  90. getattr(self, field_name),
  91. )
  92. def __repr__(self) -> str:
  93. """Get the string representation of the state.
  94. Returns:
  95. The string representation of the state.
  96. """
  97. return f"{self.__class__.__name__}({self.dict()})"
  98. @classmethod
  99. def __init_subclass__(cls, **kwargs):
  100. """Do some magic for the subclass initialization.
  101. Args:
  102. **kwargs: The kwargs to pass to the pydantic init_subclass method.
  103. """
  104. super().__init_subclass__(**kwargs)
  105. # Get the parent vars.
  106. parent_state = cls.get_parent_state()
  107. if parent_state is not None:
  108. cls.inherited_vars = parent_state.vars
  109. cls.inherited_backend_vars = parent_state.backend_vars
  110. cls.new_backend_vars = {
  111. name: value
  112. for name, value in cls.__dict__.items()
  113. if types.is_backend_variable(name)
  114. and name not in cls.inherited_backend_vars
  115. }
  116. cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
  117. # Set the base and computed vars.
  118. skip_vars = set(cls.inherited_vars) | {
  119. "parent_state",
  120. "substates",
  121. "dirty_vars",
  122. "dirty_substates",
  123. "router_data",
  124. }
  125. cls.base_vars = {
  126. f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
  127. for f in cls.get_fields().values()
  128. if f.name not in skip_vars
  129. }
  130. cls.computed_vars = {
  131. v.name: v.set_state(cls)
  132. for v in cls.__dict__.values()
  133. if isinstance(v, ComputedVar)
  134. }
  135. cls.vars = {
  136. **cls.inherited_vars,
  137. **cls.base_vars,
  138. **cls.computed_vars,
  139. }
  140. cls.computed_var_dependencies = {}
  141. # Setup the base vars at the class level.
  142. for prop in cls.base_vars.values():
  143. cls._init_var(prop)
  144. # Set up the event handlers.
  145. events = {
  146. name: fn
  147. for name, fn in cls.__dict__.items()
  148. if not name.startswith("_") and isinstance(fn, Callable)
  149. }
  150. cls.event_handlers = {name: EventHandler(fn=fn) for name, fn in events.items()}
  151. cls.set_handlers()
  152. @classmethod
  153. def convert_handlers_to_fns(cls):
  154. """Convert the event handlers to functions.
  155. This is done so the state functions can be called as normal functions during runtime.
  156. """
  157. for name, event_handler in cls.event_handlers.items():
  158. setattr(cls, name, event_handler.fn)
  159. for substate in cls.get_substates():
  160. substate.convert_handlers_to_fns()
  161. @classmethod
  162. def set_handlers(cls):
  163. """Set the state class handlers."""
  164. for name, event_handler in cls.event_handlers.items():
  165. setattr(cls, name, event_handler)
  166. @classmethod
  167. @functools.lru_cache()
  168. def get_parent_state(cls) -> Optional[Type[State]]:
  169. """Get the parent state.
  170. Returns:
  171. The parent state.
  172. """
  173. parent_states = [
  174. base
  175. for base in cls.__bases__
  176. if types._issubclass(base, State) and base is not State
  177. ]
  178. assert len(parent_states) < 2, "Only one parent state is allowed."
  179. return parent_states[0] if len(parent_states) == 1 else None # type: ignore
  180. @classmethod
  181. @functools.lru_cache()
  182. def get_substates(cls) -> Set[Type[State]]:
  183. """Get the substates of the state.
  184. Returns:
  185. The substates of the state.
  186. """
  187. return set(cls.__subclasses__())
  188. @classmethod
  189. @functools.lru_cache()
  190. def get_name(cls) -> str:
  191. """Get the name of the state.
  192. Returns:
  193. The name of the state.
  194. """
  195. return format.to_snake_case(cls.__name__)
  196. @classmethod
  197. @functools.lru_cache()
  198. def get_full_name(cls) -> str:
  199. """Get the full name of the state.
  200. Returns:
  201. The full name of the state.
  202. """
  203. name = cls.get_name()
  204. parent_state = cls.get_parent_state()
  205. if parent_state is not None:
  206. name = ".".join((parent_state.get_full_name(), name))
  207. return name
  208. @classmethod
  209. @functools.lru_cache()
  210. def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
  211. """Get the class substate.
  212. Args:
  213. path: The path to the substate.
  214. Returns:
  215. The class substate.
  216. Raises:
  217. ValueError: If the substate is not found.
  218. """
  219. if len(path) == 0:
  220. return cls
  221. if path[0] == cls.get_name():
  222. if len(path) == 1:
  223. return cls
  224. path = path[1:]
  225. for substate in cls.get_substates():
  226. if path[0] == substate.get_name():
  227. return substate.get_class_substate(path[1:])
  228. raise ValueError(f"Invalid path: {path}")
  229. @classmethod
  230. def get_class_var(cls, path: Sequence[str]) -> Any:
  231. """Get the class var.
  232. Args:
  233. path: The path to the var.
  234. Returns:
  235. The class var.
  236. Raises:
  237. ValueError: If the path is invalid.
  238. """
  239. path, name = path[:-1], path[-1]
  240. substate = cls.get_class_substate(tuple(path))
  241. if not hasattr(substate, name):
  242. raise ValueError(f"Invalid path: {path}")
  243. return getattr(substate, name)
  244. @classmethod
  245. def _init_var(cls, prop: BaseVar):
  246. """Initialize a variable.
  247. Args:
  248. prop (BaseVar): The variable to initialize
  249. Raises:
  250. TypeError: if the variable has an incorrect type
  251. """
  252. if not types.is_valid_var_type(prop.type_):
  253. raise TypeError(
  254. "State vars must be primitive Python types, "
  255. "Plotly figures, Pandas dataframes, "
  256. "or subclasses of pc.Base. "
  257. f'Found var "{prop.name}" with type {prop.type_}.'
  258. )
  259. cls._set_var(prop)
  260. cls._create_setter(prop)
  261. cls._set_default_value(prop)
  262. @classmethod
  263. def add_var(cls, name: str, type_: Any, default_value: Any = None):
  264. """Add dynamically a variable to the State.
  265. The variable added this way can be used in the same way as a variable
  266. defined statically in the model.
  267. Args:
  268. name: The name of the variable
  269. type_: The type of the variable
  270. default_value: The default value of the variable
  271. Raises:
  272. NameError: if a variable of this name already exists
  273. """
  274. if name in cls.__fields__:
  275. raise NameError(
  276. f"The variable '{name}' already exist. Use a different name"
  277. )
  278. # create the variable based on name and type
  279. var = BaseVar(name=name, type_=type_)
  280. var.set_state(cls)
  281. # add the pydantic field dynamically (must be done before _init_var)
  282. cls.add_field(var, default_value)
  283. cls._init_var(var)
  284. # update the internal dicts so the new variable is correctly handled
  285. cls.base_vars.update({name: var})
  286. cls.vars.update({name: var})
  287. @classmethod
  288. def _set_var(cls, prop: BaseVar):
  289. """Set the var as a class member.
  290. Args:
  291. prop: The var instance to set.
  292. """
  293. setattr(cls, prop.name, prop)
  294. @classmethod
  295. def _create_setter(cls, prop: BaseVar):
  296. """Create a setter for the var.
  297. Args:
  298. prop: The var to create a setter for.
  299. """
  300. setter_name = prop.get_setter_name(include_state=False)
  301. if setter_name not in cls.__dict__:
  302. setattr(cls, setter_name, prop.get_setter())
  303. @classmethod
  304. def _set_default_value(cls, prop: BaseVar):
  305. """Set the default value for the var.
  306. Args:
  307. prop: The var to set the default value for.
  308. """
  309. # Get the pydantic field for the var.
  310. field = cls.get_fields()[prop.name]
  311. default_value = prop.get_default_value()
  312. if field.required and default_value is not None:
  313. field.required = False
  314. field.default = default_value
  315. def get_token(self) -> str:
  316. """Return the token of the client associated with this state.
  317. Returns:
  318. The token of the client.
  319. """
  320. return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
  321. def get_sid(self) -> str:
  322. """Return the session ID of the client associated with this state.
  323. Returns:
  324. The session ID of the client.
  325. """
  326. return self.router_data.get(constants.RouteVar.SESSION_ID, "")
  327. def get_headers(self) -> Dict:
  328. """Return the headers of the client associated with this state.
  329. Returns:
  330. The headers of the client.
  331. """
  332. return self.router_data.get(constants.RouteVar.HEADERS, {})
  333. def get_client_ip(self) -> str:
  334. """Return the IP of the client associated with this state.
  335. Returns:
  336. The IP of the client.
  337. """
  338. return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
  339. def get_current_page(self) -> str:
  340. """Obtain the path of current page from the router data.
  341. Returns:
  342. The current page.
  343. """
  344. return self.router_data.get(constants.RouteVar.PATH, "")
  345. def get_query_params(self) -> Dict[str, str]:
  346. """Obtain the query parameters for the queried page.
  347. The query object contains both the URI parameters and the GET parameters.
  348. Returns:
  349. The dict of query parameters.
  350. """
  351. return self.router_data.get(constants.RouteVar.QUERY, {})
  352. @classmethod
  353. def setup_dynamic_args(cls, args: dict[str, str]):
  354. """Set up args for easy access in renderer.
  355. Args:
  356. args: a dict of args
  357. """
  358. def argsingle_factory(param):
  359. @ComputedVar
  360. def inner_func(self) -> str:
  361. return self.get_query_params().get(param, "")
  362. return inner_func
  363. def arglist_factory(param):
  364. @ComputedVar
  365. def inner_func(self) -> List:
  366. return self.get_query_params().get(param, [])
  367. return inner_func
  368. for param, value in args.items():
  369. if value == constants.RouteArgType.SINGLE:
  370. func = argsingle_factory(param)
  371. elif value == constants.RouteArgType.LIST:
  372. func = arglist_factory(param)
  373. else:
  374. continue
  375. cls.computed_vars[param] = func.set_state(cls) # type: ignore
  376. setattr(cls, param, func)
  377. def __getattribute__(self, name: str) -> Any:
  378. """Get the state var.
  379. If the var is inherited, get the var from the parent state.
  380. If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
  381. Args:
  382. name: The name of the var.
  383. Returns:
  384. The value of the var.
  385. """
  386. vars = {
  387. **super().__getattribute__("vars"),
  388. **super().__getattribute__("backend_vars"),
  389. }
  390. if name in vars:
  391. parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
  392. if parent_frame is not None:
  393. computed_vars = super().__getattribute__("computed_vars")
  394. requesting_attribute_name = parent_frame_locals.get("name")
  395. if requesting_attribute_name in computed_vars:
  396. # Keep track of any ComputedVar that depends on this Var
  397. super().__getattribute__("computed_var_dependencies").setdefault(
  398. name, set()
  399. ).add(requesting_attribute_name)
  400. inherited_vars = {
  401. **super().__getattribute__("inherited_vars"),
  402. **super().__getattribute__("inherited_backend_vars"),
  403. }
  404. if name in inherited_vars:
  405. return getattr(super().__getattribute__("parent_state"), name)
  406. elif name in super().__getattribute__("backend_vars"):
  407. return super().__getattribute__("backend_vars").__getitem__(name)
  408. return super().__getattribute__(name)
  409. def __setattr__(self, name: str, value: Any):
  410. """Set the attribute.
  411. If the attribute is inherited, set the attribute on the parent state.
  412. Args:
  413. name: The name of the attribute.
  414. value: The value of the attribute.
  415. """
  416. # Set the var on the parent state.
  417. inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
  418. if name in inherited_vars:
  419. setattr(self.parent_state, name, value)
  420. return
  421. if types.is_backend_variable(name):
  422. self.backend_vars.__setitem__(name, value)
  423. self.dirty_vars.add(name)
  424. self.mark_dirty()
  425. return
  426. # Set the attribute.
  427. super().__setattr__(name, value)
  428. # Add the var to the dirty list.
  429. if name in self.vars:
  430. self.dirty_vars.add(name)
  431. self.mark_dirty()
  432. def reset(self):
  433. """Reset all the base vars to their default values."""
  434. # Reset the base vars.
  435. fields = self.get_fields()
  436. for prop_name in self.base_vars:
  437. setattr(self, prop_name, fields[prop_name].default)
  438. # Recursively reset the substates.
  439. for substate in self.substates.values():
  440. substate.reset()
  441. # Clean the state.
  442. self.clean()
  443. def get_substate(self, path: Sequence[str]) -> Optional[State]:
  444. """Get the substate.
  445. Args:
  446. path: The path to the substate.
  447. Returns:
  448. The substate.
  449. Raises:
  450. ValueError: If the substate is not found.
  451. """
  452. if len(path) == 0:
  453. return self
  454. if path[0] == self.get_name():
  455. if len(path) == 1:
  456. return self
  457. path = path[1:]
  458. if path[0] not in self.substates:
  459. raise ValueError(f"Invalid path: {path}")
  460. return self.substates[path[0]].get_substate(path[1:])
  461. async def _process(self, event: Event) -> StateUpdate:
  462. """Obtain event info and process event.
  463. Args:
  464. event: The event to process.
  465. Returns:
  466. The state update after processing the event.
  467. Raises:
  468. ValueError: If the state value is None.
  469. """
  470. # Get the event handler.
  471. path = event.name.split(".")
  472. path, name = path[:-1], path[-1]
  473. substate = self.get_substate(path)
  474. handler = substate.event_handlers[name] # type: ignore
  475. if not substate:
  476. raise ValueError(
  477. "The value of state cannot be None when processing an event."
  478. )
  479. return await self._process_event(
  480. handler=handler,
  481. state=substate,
  482. payload=event.payload,
  483. token=event.token,
  484. )
  485. async def _process_event(
  486. self, handler: EventHandler, state: State, payload: Dict, token: str
  487. ) -> StateUpdate:
  488. """Process event.
  489. Args:
  490. handler: Eventhandler to process.
  491. state: State to process the handler.
  492. payload: The event payload.
  493. token: Client token.
  494. Returns:
  495. The state update after processing the event.
  496. """
  497. fn = functools.partial(handler.fn, state)
  498. try:
  499. if asyncio.iscoroutinefunction(fn.func):
  500. events = await fn(**payload)
  501. else:
  502. events = fn(**payload)
  503. except Exception:
  504. error = traceback.format_exc()
  505. print(error)
  506. events = fix_events(
  507. [window_alert("An error occurred. See logs for details.")], token
  508. )
  509. return StateUpdate(events=events)
  510. # Fix the returned events.
  511. events = fix_events(events, token)
  512. # Get the delta after processing the event.
  513. delta = self.get_delta()
  514. # Reset the dirty vars.
  515. self.clean()
  516. # Return the state update.
  517. return StateUpdate(delta=delta, events=events)
  518. def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
  519. """Get ComputedVars that need to be recomputed based on dirty_vars.
  520. Args:
  521. from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
  522. Returns:
  523. Set of computed vars to include in the delta.
  524. """
  525. dirty_computed_vars = set(
  526. cvar
  527. for dirty_var in from_vars or self.dirty_vars
  528. for cvar in self.computed_vars
  529. if cvar in self.computed_var_dependencies.get(dirty_var, set())
  530. )
  531. if dirty_computed_vars:
  532. # recursive call to catch computed vars that depend on computed vars
  533. return dirty_computed_vars | self._dirty_computed_vars(
  534. from_vars=dirty_computed_vars
  535. )
  536. return dirty_computed_vars
  537. def get_delta(self) -> Delta:
  538. """Get the delta for the state.
  539. Returns:
  540. The delta for the state.
  541. """
  542. delta = {}
  543. # Return the dirty vars, as well as computed vars depending on dirty vars.
  544. subdelta = {
  545. prop: getattr(self, prop)
  546. for prop in self.dirty_vars | self._dirty_computed_vars()
  547. if not types.is_backend_variable(prop)
  548. }
  549. if len(subdelta) > 0:
  550. delta[self.get_full_name()] = subdelta
  551. # Recursively find the substate deltas.
  552. substates = self.substates
  553. for substate in self.dirty_substates:
  554. delta.update(substates[substate].get_delta())
  555. # Format the delta.
  556. delta = format.format_state(delta)
  557. # Return the delta.
  558. return delta
  559. def mark_dirty(self):
  560. """Mark the substate and all parent states as dirty."""
  561. if self.parent_state is not None:
  562. self.parent_state.dirty_substates.add(self.get_name())
  563. self.parent_state.mark_dirty()
  564. def clean(self):
  565. """Reset the dirty vars."""
  566. # Recursively clean the substates.
  567. for substate in self.dirty_substates:
  568. self.substates[substate].clean()
  569. # Clean this state.
  570. self.dirty_vars = set()
  571. self.dirty_substates = set()
  572. def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
  573. """Convert the object to a dictionary.
  574. Args:
  575. include_computed: Whether to include computed vars.
  576. **kwargs: Kwargs to pass to the pydantic dict method.
  577. Returns:
  578. The object as a dictionary.
  579. """
  580. base_vars = {
  581. prop_name: self.get_value(getattr(self, prop_name))
  582. for prop_name in self.base_vars
  583. }
  584. computed_vars = (
  585. {
  586. # Include the computed vars.
  587. prop_name: self.get_value(getattr(self, prop_name))
  588. for prop_name in self.computed_vars
  589. }
  590. if include_computed
  591. else {}
  592. )
  593. substate_vars = {
  594. k: v.dict(include_computed=include_computed, **kwargs)
  595. for k, v in self.substates.items()
  596. }
  597. variables = {**base_vars, **computed_vars, **substate_vars}
  598. return {k: variables[k] for k in sorted(variables)}
  599. class DefaultState(State):
  600. """The default empty state."""
  601. pass
class StateUpdate(Base):
    """A state update sent to the frontend."""

    # The state delta (changed var values per state, as built by State.get_delta).
    # NOTE(review): the mutable default relies on pydantic copying field
    # defaults per instance — confirm for the pinned pydantic version.
    delta: Delta = {}

    # Events to be added to the event queue.
    events: List[Event] = []
  608. class StateManager(Base):
  609. """A class to manage many client states."""
  610. # The state class to use.
  611. state: Type[State] = DefaultState
  612. # The mapping of client ids to states.
  613. states: Dict[str, State] = {}
  614. # The token expiration time (s).
  615. token_expiration: int = constants.TOKEN_EXPIRATION
  616. # The redis client to use.
  617. redis: Optional[Redis] = None
  618. def setup(self, state: Type[State]):
  619. """Set up the state manager.
  620. Args:
  621. state: The state class to use.
  622. """
  623. self.state = state
  624. self.redis = prerequisites.get_redis()
  625. def get_state(self, token: str) -> State:
  626. """Get the state for a token.
  627. Args:
  628. token: The token to get the state for.
  629. Returns:
  630. The state for the token.
  631. """
  632. if self.redis is not None:
  633. redis_state = self.redis.get(token)
  634. if redis_state is None:
  635. self.set_state(token, self.state())
  636. return self.get_state(token)
  637. return cloudpickle.loads(redis_state)
  638. if token not in self.states:
  639. self.states[token] = self.state()
  640. return self.states[token]
  641. def set_state(self, token: str, state: State):
  642. """Set the state for a token.
  643. Args:
  644. token: The token to set the state for.
  645. state: The state to set.
  646. """
  647. if self.redis is None:
  648. return
  649. self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  650. def _convert_mutable_datatypes(
  651. field_value: Any, reassign_field: Callable, field_name: str
  652. ) -> Any:
  653. """Recursively convert mutable data to the Pc data types.
  654. Note: right now only list & dict would be handled recursively.
  655. Args:
  656. field_value: The target field_value.
  657. reassign_field:
  658. The function to reassign the field in the parent state.
  659. field_name: the name of the field in the parent state
  660. Returns:
  661. The converted field_value
  662. """
  663. if isinstance(field_value, list):
  664. for index in range(len(field_value)):
  665. field_value[index] = _convert_mutable_datatypes(
  666. field_value[index], reassign_field, field_name
  667. )
  668. field_value = PCList(
  669. field_value, reassign_field=reassign_field, field_name=field_name
  670. )
  671. if isinstance(field_value, dict):
  672. for key, value in field_value.items():
  673. field_value[key] = _convert_mutable_datatypes(
  674. value, reassign_field, field_name
  675. )
  676. field_value = PCDict(
  677. field_value, reassign_field=reassign_field, field_name=field_name
  678. )
  679. return field_value
  680. def _get_previous_recursive_frame_info() -> (
  681. Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
  682. ):
  683. """Find the previous frame of the same function that calls this helper.
  684. For example, if this function is called from `State.__getattribute__`
  685. (parent frame), then the returned frame will be the next earliest call
  686. of the same function.
  687. Returns:
  688. Tuple of (frame_info, local_vars)
  689. If no previous recursive frame is found up the stack, the frame info will be None.
  690. """
  691. _this_frame, parent_frame, *prev_frames = inspect.stack()
  692. for frame in prev_frames:
  693. if frame.frame.f_code == parent_frame.frame.f_code:
  694. return frame, frame.frame.f_locals
  695. return None, {}