state.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. """Define the pynecone state specification."""
  2. from __future__ import annotations
  3. import asyncio
  4. import functools
  5. import traceback
  6. from abc import ABC
  7. from typing import (
  8. Any,
  9. Callable,
  10. ClassVar,
  11. Dict,
  12. List,
  13. Optional,
  14. Sequence,
  15. Set,
  16. Type,
  17. Union,
  18. )
  19. import cloudpickle
  20. from redis import Redis
  21. from pynecone import constants, utils
  22. from pynecone.base import Base
  23. from pynecone.event import Event, EventHandler, window_alert
  24. from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var
# Type alias for a state delta: the string-keyed mapping returned by
# State.get_delta and carried to the frontend in StateUpdate.delta.
Delta = Dict[str, Any]
class State(Base, ABC):
    """The state of the app.

    Class-level registries (the ``ClassVar`` dicts) describe the vars of each
    state class and are rebuilt for every subclass in ``__init_subclass__``.
    The remaining fields are per-instance bookkeeping: the link to the parent
    state, the substate instances, and the dirty tracking used to build deltas.
    """

    # A map from the var name to the var (inherited, base and computed vars).
    vars: ClassVar[Dict[str, Var]] = {}

    # The base vars of the class (one per field declared on this class).
    base_vars: ClassVar[Dict[str, BaseVar]] = {}

    # The computed vars of the class.
    computed_vars: ClassVar[Dict[str, ComputedVar]] = {}

    # Vars inherited from the parent state (the parent's ``vars``).
    inherited_vars: ClassVar[Dict[str, Var]] = {}

    # Backend vars that are never sent to the client.
    backend_vars: ClassVar[Dict[str, Any]] = {}

    # Backend vars inherited from the parent state.
    inherited_backend_vars: ClassVar[Dict[str, Any]] = {}

    # The parent state instance, or None for the root state.
    parent_state: Optional[State] = None

    # The substate instances, keyed by substate name.
    substates: Dict[str, State] = {}

    # The names of the vars assigned since the last call to clean().
    dirty_vars: Set[str] = set()

    # The names of the substates marked dirty since the last call to clean().
    dirty_substates: Set[str] = set()

    # The routing data that triggered the state; read by the get_* accessors
    # (client token, session id, headers, client ip, path, query params).
    router_data: Dict[str, Any] = {}
  50. def __init__(self, *args, **kwargs):
  51. """Initialize the state.
  52. Args:
  53. *args: The args to pass to the Pydantic init method.
  54. **kwargs: The kwargs to pass to the Pydantic init method.
  55. """
  56. super().__init__(*args, **kwargs)
  57. # Setup the substates.
  58. for substate in self.get_substates():
  59. self.substates[substate.get_name()] = substate().set(parent_state=self)
  60. self._init_mutable_fields()
  61. def _init_mutable_fields(self):
  62. """Initialize mutable fields.
  63. So that mutation to them can be detected by the app:
  64. * list
  65. """
  66. for field in self.base_vars.values():
  67. value = getattr(self, field.name)
  68. value_in_pc_data = _convert_mutable_datatypes(
  69. value, self._reassign_field, field.name
  70. )
  71. if utils._issubclass(field.type_, Union[List, Dict]):
  72. setattr(self, field.name, value_in_pc_data)
  73. self.clean()
  74. def _reassign_field(self, field_name: str):
  75. """Reassign the given field.
  76. Primarily for mutation in fields of mutable data types.
  77. Args:
  78. field_name: The name of the field we want to reassign
  79. """
  80. setattr(
  81. self,
  82. field_name,
  83. getattr(self, field_name),
  84. )
  85. def __repr__(self) -> str:
  86. """Get the string representation of the state.
  87. Returns:
  88. The string representation of the state.
  89. """
  90. return f"{self.__class__.__name__}({self.dict()})"
    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Build the var registries and event handlers of a new state class.

        Runs once per subclass definition. In order, it: records what is
        inherited from the parent state, collects the backend vars, builds the
        base/computed/all var maps, installs each base var (class attribute,
        setter, default value), and wraps public callables as event handlers.

        Args:
            **kwargs: The kwargs to pass to the pydantic init_subclass method.
        """
        super().__init_subclass__(**kwargs)

        # Get the parent vars.
        parent_state = cls.get_parent_state()
        if parent_state is not None:
            cls.inherited_vars = parent_state.vars
            cls.inherited_backend_vars = parent_state.backend_vars

        # Backend vars declared directly on this class (not on the parent).
        cls.new_backend_vars = {
            name: value
            for name, value in cls.__dict__.items()
            if utils.is_backend_variable(name)
            and name not in cls.inherited_backend_vars
        }

        cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}

        # Set the base and computed vars.
        # Inherited vars and the bookkeeping fields of State are not base vars.
        skip_vars = set(cls.inherited_vars) | {
            "parent_state",
            "substates",
            "dirty_vars",
            "dirty_substates",
            "router_data",
        }
        cls.base_vars = {
            f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
            for f in cls.get_fields().values()
            if f.name not in skip_vars
        }
        cls.computed_vars = {
            v.name: v.set_state(cls)
            for v in cls.__dict__.values()
            if isinstance(v, ComputedVar)
        }
        # Later entries win on a name clash: computed over base over inherited.
        cls.vars = {
            **cls.inherited_vars,
            **cls.base_vars,
            **cls.computed_vars,
        }

        # Setup the base vars at the class level.
        for prop in cls.base_vars.values():
            cls._init_var(prop)

        # Set up the event handlers.
        # Every public callable defined on this class is wrapped; the dict is
        # built first so the class namespace is not modified while iterating.
        events = {
            name: fn
            for name, fn in cls.__dict__.items()
            if not name.startswith("_") and isinstance(fn, Callable)
        }
        for name, fn in events.items():
            event_handler = EventHandler(fn=fn)
            setattr(cls, name, event_handler)
  145. @classmethod
  146. @functools.lru_cache()
  147. def get_parent_state(cls) -> Optional[Type[State]]:
  148. """Get the parent state.
  149. Returns:
  150. The parent state.
  151. """
  152. parent_states = [
  153. base
  154. for base in cls.__bases__
  155. if utils._issubclass(base, State) and base is not State
  156. ]
  157. assert len(parent_states) < 2, "Only one parent state is allowed."
  158. return parent_states[0] if len(parent_states) == 1 else None # type: ignore
  159. @classmethod
  160. @functools.lru_cache()
  161. def get_substates(cls) -> Set[Type[State]]:
  162. """Get the substates of the state.
  163. Returns:
  164. The substates of the state.
  165. """
  166. return set(cls.__subclasses__())
  167. @classmethod
  168. @functools.lru_cache()
  169. def get_name(cls) -> str:
  170. """Get the name of the state.
  171. Returns:
  172. The name of the state.
  173. """
  174. return utils.to_snake_case(cls.__name__)
  175. @classmethod
  176. @functools.lru_cache()
  177. def get_full_name(cls) -> str:
  178. """Get the full name of the state.
  179. Returns:
  180. The full name of the state.
  181. """
  182. name = cls.get_name()
  183. parent_state = cls.get_parent_state()
  184. if parent_state is not None:
  185. name = ".".join((parent_state.get_full_name(), name))
  186. return name
  187. @classmethod
  188. @functools.lru_cache()
  189. def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
  190. """Get the class substate.
  191. Args:
  192. path: The path to the substate.
  193. Returns:
  194. The class substate.
  195. Raises:
  196. ValueError: If the substate is not found.
  197. """
  198. if len(path) == 0:
  199. return cls
  200. if path[0] == cls.get_name():
  201. if len(path) == 1:
  202. return cls
  203. path = path[1:]
  204. for substate in cls.get_substates():
  205. if path[0] == substate.get_name():
  206. return substate.get_class_substate(path[1:])
  207. raise ValueError(f"Invalid path: {path}")
  208. @classmethod
  209. def get_class_var(cls, path: Sequence[str]) -> Any:
  210. """Get the class var.
  211. Args:
  212. path: The path to the var.
  213. Returns:
  214. The class var.
  215. Raises:
  216. ValueError: If the path is invalid.
  217. """
  218. path, name = path[:-1], path[-1]
  219. substate = cls.get_class_substate(tuple(path))
  220. if not hasattr(substate, name):
  221. raise ValueError(f"Invalid path: {path}")
  222. return getattr(substate, name)
  223. @classmethod
  224. def _init_var(cls, prop: BaseVar):
  225. """Initialize a variable.
  226. Args:
  227. prop (BaseVar): The variable to initialize
  228. Raises:
  229. TypeError: if the variable has an incorrect type
  230. """
  231. if not utils.is_valid_var_type(prop.type_):
  232. raise TypeError(
  233. "State vars must be primitive Python types, "
  234. "Plotly figures, Pandas dataframes, "
  235. "or subclasses of pc.Base. "
  236. f'Found var "{prop.name}" with type {prop.type_}.'
  237. )
  238. cls._set_var(prop)
  239. cls._create_setter(prop)
  240. cls._set_default_value(prop)
  241. @classmethod
  242. def add_var(cls, name: str, type_: Any, default_value: Any = None):
  243. """Add dynamically a variable to the State.
  244. The variable added this way can be used in the same way as a variable
  245. defined statically in the model.
  246. Args:
  247. name: The name of the variable
  248. type_: The type of the variable
  249. default_value: The default value of the variable
  250. Raises:
  251. NameError: if a variable of this name already exists
  252. """
  253. if name in cls.__fields__:
  254. raise NameError(
  255. f"The variable '{name}' already exist. Use a different name"
  256. )
  257. # create the variable based on name and type
  258. var = BaseVar(name=name, type_=type_)
  259. var.set_state(cls)
  260. # add the pydantic field dynamically (must be done before _init_var)
  261. cls.add_field(var, default_value)
  262. cls._init_var(var)
  263. # update the internal dicts so the new variable is correctly handled
  264. cls.base_vars.update({name: var})
  265. cls.vars.update({name: var})
  266. @classmethod
  267. def _set_var(cls, prop: BaseVar):
  268. """Set the var as a class member.
  269. Args:
  270. prop: The var instance to set.
  271. """
  272. setattr(cls, prop.name, prop)
  273. @classmethod
  274. def _create_setter(cls, prop: BaseVar):
  275. """Create a setter for the var.
  276. Args:
  277. prop: The var to create a setter for.
  278. """
  279. setter_name = prop.get_setter_name(include_state=False)
  280. if setter_name not in cls.__dict__:
  281. setattr(cls, setter_name, prop.get_setter())
  282. @classmethod
  283. def _set_default_value(cls, prop: BaseVar):
  284. """Set the default value for the var.
  285. Args:
  286. prop: The var to set the default value for.
  287. """
  288. # Get the pydantic field for the var.
  289. field = cls.get_fields()[prop.name]
  290. default_value = prop.get_default_value()
  291. if field.required and default_value is not None:
  292. field.required = False
  293. field.default = default_value
  294. def get_token(self) -> str:
  295. """Return the token of the client associated with this state.
  296. Returns:
  297. The token of the client.
  298. """
  299. return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
  300. def get_sid(self) -> str:
  301. """Return the session ID of the client associated with this state.
  302. Returns:
  303. The session ID of the client.
  304. """
  305. return self.router_data.get(constants.RouteVar.SESSION_ID, "")
  306. def get_headers(self) -> Dict:
  307. """Return the headers of the client associated with this state.
  308. Returns:
  309. The headers of the client.
  310. """
  311. return self.router_data.get(constants.RouteVar.HEADERS, {})
  312. def get_client_ip(self) -> str:
  313. """Return the IP of the client associated with this state.
  314. Returns:
  315. The IP of the client.
  316. """
  317. return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
  318. def get_current_page(self) -> str:
  319. """Obtain the path of current page from the router data.
  320. Returns:
  321. The current page.
  322. """
  323. return self.router_data.get(constants.RouteVar.PATH, "")
  324. def get_query_params(self) -> Dict[str, str]:
  325. """Obtain the query parameters for the queried page.
  326. The query object contains both the URI parameters and the GET parameters.
  327. Returns:
  328. The dict of query parameters.
  329. """
  330. return self.router_data.get(constants.RouteVar.QUERY, {})
    @classmethod
    def setup_dynamic_args(cls, args: dict[str, str]):
        """Set up args for easy access in renderer.

        For every arg whose kind is SINGLE or LIST, a computed var that reads
        the arg from the query params is registered in ``computed_vars`` and
        set as a class attribute under the arg name. Args of any other kind
        are skipped.

        Args:
            args: a dict mapping each arg name to its kind
                (compared against ``constants.RouteArgType``).
        """

        # Each factory closes over ``param`` so that every computed var reads
        # its own query parameter.
        def argsingle_factory(param):
            @ComputedVar
            def inner_func(self) -> str:
                # A missing single-valued arg yields an empty string.
                return self.get_query_params().get(param, "")

            return inner_func

        def arglist_factory(param):
            @ComputedVar
            def inner_func(self) -> List:
                # A missing list-valued arg yields an empty list.
                return self.get_query_params().get(param, [])

            return inner_func

        for param, value in args.items():
            if value == constants.RouteArgType.SINGLE:
                func = argsingle_factory(param)
            elif value == constants.RouteArgType.LIST:
                func = arglist_factory(param)
            else:
                continue
            # NOTE(review): the var is added to ``computed_vars`` but not to
            # ``vars`` — confirm that is intended.
            cls.computed_vars[param] = func.set_state(cls)  # type: ignore
            setattr(cls, param, func)
  356. def __getattribute__(self, name: str) -> Any:
  357. """Get the state var.
  358. If the var is inherited, get the var from the parent state.
  359. Args:
  360. name: The name of the var.
  361. Returns:
  362. The value of the var.
  363. """
  364. inherited_vars = {
  365. **super().__getattribute__("inherited_vars"),
  366. **super().__getattribute__("inherited_backend_vars"),
  367. }
  368. if name in inherited_vars:
  369. return getattr(super().__getattribute__("parent_state"), name)
  370. elif name in super().__getattribute__("backend_vars"):
  371. return super().__getattribute__("backend_vars").__getitem__(name)
  372. return super().__getattribute__(name)
  373. def __setattr__(self, name: str, value: Any):
  374. """Set the attribute.
  375. If the attribute is inherited, set the attribute on the parent state.
  376. Args:
  377. name: The name of the attribute.
  378. value: The value of the attribute.
  379. """
  380. # Set the var on the parent state.
  381. inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
  382. if name in inherited_vars:
  383. setattr(self.parent_state, name, value)
  384. return
  385. if utils.is_backend_variable(name):
  386. self.backend_vars.__setitem__(name, value)
  387. self.mark_dirty()
  388. return
  389. # Set the attribute.
  390. super().__setattr__(name, value)
  391. # Add the var to the dirty list.
  392. if name in self.vars:
  393. self.dirty_vars.add(name)
  394. self.mark_dirty()
  395. def reset(self):
  396. """Reset all the base vars to their default values."""
  397. # Reset the base vars.
  398. fields = self.get_fields()
  399. for prop_name in self.base_vars:
  400. setattr(self, prop_name, fields[prop_name].default)
  401. # Recursively reset the substates.
  402. for substate in self.substates.values():
  403. substate.reset()
  404. # Clean the state.
  405. self.clean()
  406. def get_substate(self, path: Sequence[str]) -> Optional[State]:
  407. """Get the substate.
  408. Args:
  409. path: The path to the substate.
  410. Returns:
  411. The substate.
  412. Raises:
  413. ValueError: If the substate is not found.
  414. """
  415. if len(path) == 0:
  416. return self
  417. if path[0] == self.get_name():
  418. if len(path) == 1:
  419. return self
  420. path = path[1:]
  421. if path[0] not in self.substates:
  422. raise ValueError(f"Invalid path: {path}")
  423. return self.substates[path[0]].get_substate(path[1:])
  424. async def process(self, event: Event) -> StateUpdate:
  425. """Process an event.
  426. Args:
  427. event: The event to process.
  428. Returns:
  429. The state update after processing the event.
  430. """
  431. # Get the event handler.
  432. path = event.name.split(".")
  433. path, name = path[:-1], path[-1]
  434. substate = self.get_substate(path)
  435. handler = getattr(substate, name)
  436. # Process the event.
  437. fn = functools.partial(handler.fn, substate)
  438. try:
  439. if asyncio.iscoroutinefunction(fn.func):
  440. events = await fn(**event.payload)
  441. else:
  442. events = fn(**event.payload)
  443. except Exception:
  444. error = traceback.format_exc()
  445. print(error)
  446. events = utils.fix_events(
  447. [window_alert("An error occurred. See logs for details.")], event.token
  448. )
  449. return StateUpdate(events=events)
  450. # Fix the returned events.
  451. events = utils.fix_events(events, event.token)
  452. # Get the delta after processing the event.
  453. delta = self.get_delta()
  454. # Reset the dirty vars.
  455. self.clean()
  456. # Return the state update.
  457. return StateUpdate(delta=delta, events=events)
  458. def get_delta(self) -> Delta:
  459. """Get the delta for the state.
  460. Returns:
  461. The delta for the state.
  462. """
  463. delta = {}
  464. # Return the dirty vars, as well as all computed vars.
  465. subdelta = {
  466. prop: getattr(self, prop)
  467. for prop in self.dirty_vars | self.computed_vars.keys()
  468. }
  469. if len(subdelta) > 0:
  470. delta[self.get_full_name()] = subdelta
  471. # Recursively find the substate deltas.
  472. substates = self.substates
  473. for substate in self.dirty_substates:
  474. delta.update(substates[substate].get_delta())
  475. # Format the delta.
  476. delta = utils.format_state(delta)
  477. # Return the delta.
  478. return delta
  479. def mark_dirty(self):
  480. """Mark the substate and all parent states as dirty."""
  481. if self.parent_state is not None:
  482. self.parent_state.dirty_substates.add(self.get_name())
  483. self.parent_state.mark_dirty()
  484. def clean(self):
  485. """Reset the dirty vars."""
  486. # Recursively clean the substates.
  487. for substate in self.dirty_substates:
  488. self.substates[substate].clean()
  489. # Clean this state.
  490. self.dirty_vars = set()
  491. self.dirty_substates = set()
  492. def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
  493. """Convert the object to a dictionary.
  494. Args:
  495. include_computed: Whether to include computed vars.
  496. **kwargs: Kwargs to pass to the pydantic dict method.
  497. Returns:
  498. The object as a dictionary.
  499. """
  500. base_vars = {
  501. prop_name: self.get_value(getattr(self, prop_name))
  502. for prop_name in self.base_vars
  503. }
  504. computed_vars = (
  505. {
  506. # Include the computed vars.
  507. prop_name: self.get_value(getattr(self, prop_name))
  508. for prop_name in self.computed_vars
  509. }
  510. if include_computed
  511. else {}
  512. )
  513. substate_vars = {
  514. k: v.dict(include_computed=include_computed, **kwargs)
  515. for k, v in self.substates.items()
  516. }
  517. variables = {**base_vars, **computed_vars, **substate_vars}
  518. return {k: variables[k] for k in sorted(variables)}
  519. class DefaultState(State):
  520. """The default empty state."""
  521. pass
class StateUpdate(Base):
    """A state update sent to the frontend.

    Built by State.process after an event handler has run.
    """

    # The state delta (empty when the event handler raised).
    delta: Delta = {}

    # Events to be added to the event queue.
    events: List[Event] = []
  528. class StateManager(Base):
  529. """A class to manage many client states."""
  530. # The state class to use.
  531. state: Type[State] = DefaultState
  532. # The mapping of client ids to states.
  533. states: Dict[str, State] = {}
  534. # The token expiration time (s).
  535. token_expiration: int = constants.TOKEN_EXPIRATION
  536. # The redis client to use.
  537. redis: Optional[Redis] = None
  538. def setup(self, state: Type[State]):
  539. """Set up the state manager.
  540. Args:
  541. state: The state class to use.
  542. """
  543. self.state = state
  544. self.redis = utils.get_redis()
  545. def get_state(self, token: str) -> State:
  546. """Get the state for a token.
  547. Args:
  548. token: The token to get the state for.
  549. Returns:
  550. The state for the token.
  551. """
  552. if self.redis is not None:
  553. redis_state = self.redis.get(token)
  554. if redis_state is None:
  555. self.set_state(token, self.state())
  556. return self.get_state(token)
  557. return cloudpickle.loads(redis_state)
  558. if token not in self.states:
  559. self.states[token] = self.state()
  560. return self.states[token]
  561. def set_state(self, token: str, state: State):
  562. """Set the state for a token.
  563. Args:
  564. token: The token to set the state for.
  565. state: The state to set.
  566. """
  567. if self.redis is None:
  568. return
  569. self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
  570. def _convert_mutable_datatypes(
  571. field_value: Any, reassign_field: Callable, field_name: str
  572. ) -> Any:
  573. """Recursively convert mutable data to the Pc data types.
  574. Note: right now only list & dict would be handled recursively.
  575. Args:
  576. field_value: The target field_value.
  577. reassign_field:
  578. The function to reassign the field in the parent state.
  579. field_name: the name of the field in the parent state
  580. Returns:
  581. The converted field_value
  582. """
  583. if isinstance(field_value, list):
  584. for index in range(len(field_value)):
  585. field_value[index] = _convert_mutable_datatypes(
  586. field_value[index], reassign_field, field_name
  587. )
  588. field_value = PCList(
  589. field_value, reassign_field=reassign_field, field_name=field_name
  590. )
  591. if isinstance(field_value, dict):
  592. for key, value in field_value.items():
  593. field_value[key] = _convert_mutable_datatypes(
  594. value, reassign_field, field_name
  595. )
  596. field_value = PCDict(
  597. field_value, reassign_field=reassign_field, field_name=field_name
  598. )
  599. return field_value