refreshable.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field
  3. from typing import Any, Awaitable, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast
  4. from typing_extensions import Concatenate, ParamSpec, Self
  5. from .. import background_tasks, core
  6. from ..client import Client
  7. from ..dataclasses import KWONLY_SLOTS
  8. from ..element import Element
  9. from ..helpers import is_coroutine_function
# _S: type of the bound instance (``self``) of a refreshable method, see ``refreshable_method``
_S = TypeVar('_S')
# _T: result type of the decorated function (the awaited result for coroutine functions)
_T = TypeVar('_T')
# _P: parameter specification of the decorated function, so calls are type-checked against it
_P = ParamSpec('_P')
  13. @dataclass(**KWONLY_SLOTS)
  14. class RefreshableTarget:
  15. container: RefreshableContainer
  16. refreshable: refreshable
  17. instance: Any
  18. args: Tuple[Any, ...]
  19. kwargs: Dict[str, Any]
  20. current_target: ClassVar[Optional[RefreshableTarget]] = None
  21. locals: List[Any] = field(default_factory=list)
  22. next_index: int = 0
  23. def run(self, func: Callable[..., Union[_T, Awaitable[_T]]]) -> Union[_T, Awaitable[_T]]:
  24. """Run the function and return the result."""
  25. RefreshableTarget.current_target = self
  26. self.next_index = 0
  27. # pylint: disable=no-else-return
  28. if is_coroutine_function(func):
  29. async def wait_for_result() -> Any:
  30. with self.container:
  31. if self.instance is None:
  32. result = func(*self.args, **self.kwargs)
  33. else:
  34. result = func(self.instance, *self.args, **self.kwargs)
  35. assert isinstance(result, Awaitable)
  36. return await result
  37. return wait_for_result()
  38. else:
  39. with self.container:
  40. if self.instance is None:
  41. return func(*self.args, **self.kwargs)
  42. else:
  43. return func(self.instance, *self.args, **self.kwargs)
  44. class RefreshableContainer(Element, component='refreshable.js'):
  45. pass
class refreshable(Generic[_P, _T]):

    def __init__(self, func: Callable[_P, Union[_T, Awaitable[_T]]]) -> None:
        """Refreshable UI functions

        The ``@ui.refreshable`` decorator allows you to create functions that have a ``refresh`` method.
        This method will automatically delete all elements created by the function and recreate them.

        For decorating refreshable methods in classes, there is a ``@ui.refreshable_method`` decorator,
        which is equivalent but prevents static type checking errors.
        """
        self.func = func
        # Instance the decorator was last accessed through (see ``__get__``); None for plain functions.
        # NOTE: this lives on the decorator object itself, i.e. it is shared by all instances of the owning class.
        self.instance = None
        # One target per call of the decorated function; stale ones are removed by ``prune``.
        self.targets: List[RefreshableTarget] = []

    def __get__(self, instance: Any, _: Any) -> Self:
        # Descriptor hook for decorated methods: remember the instance so that ``__call__`` and ``refresh``
        # pass it as the first argument. Returns the decorator itself, not a separate bound object.
        self.instance = instance
        return self

    def __getattribute__(self, __name: str) -> Any:
        # Every attribute read on the decorator (including ``self.x`` inside its own methods) passes through here.
        attribute = object.__getattribute__(self, __name)
        if __name == 'refresh':
            # Wrap ``refresh`` so it remembers the instance bound at lookup time: a stored ``obj.method.refresh``
            # keeps refreshing ``obj`` even if the shared ``self.instance`` changes later.
            # Callers may also pass ``_instance=...`` explicitly.
            def refresh(*args: Any, _instance=self.instance, **kwargs: Any) -> None:
                self.instance = _instance
                attribute(*args, **kwargs)
            return refresh
        return attribute

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Union[_T, Awaitable[_T]]:
        # Each call renders into a fresh container and is recorded as a target so it can be refreshed later.
        self.prune()
        target = RefreshableTarget(container=RefreshableContainer(), refreshable=self, instance=self.instance,
                                   args=args, kwargs=kwargs)
        self.targets.append(target)
        return target.run(self.func)

    def refresh(self, *args: Any, **kwargs: Any) -> None:
        """Refresh the UI elements created by this function.

        This method accepts the same arguments as the function itself or a subset of them.
        It will combine the arguments passed to the function with the arguments passed to this method.
        """
        self.prune()
        for target in self.targets:
            # only refresh targets that were created for the currently bound instance
            if target.instance != self.instance:
                continue
            target.container.clear()
            # new positional args replace the old ones entirely (if any are given); keyword args are merged
            target.args = args or target.args
            target.kwargs.update(kwargs)
            try:
                result = target.run(self.func)
            except TypeError as e:
                # Rewrite Python's "... got multiple values for argument 'x'" error, which arises when an argument
                # was passed positionally on the first call and by keyword on refresh (or vice versa).
                # For coroutine functions ``run`` only creates the coroutine, so such errors are not caught here.
                if 'got multiple values for argument' in str(e):
                    function = str(e).split()[0].split('.')[-1]
                    parameter = str(e).split()[-1]
                    raise TypeError(f'{parameter} needs to be consistently passed to {function} '
                                    'either as positional or as keyword argument') from e
                raise
            if is_coroutine_function(self.func):
                assert isinstance(result, Awaitable)
                # hand the coroutine to a background task if the loop runs, otherwise defer it to app startup
                if core.loop and core.loop.is_running():
                    background_tasks.create(result)
                else:
                    core.app.on_startup(result)

    def prune(self) -> None:
        """Remove all targets that are no longer on a page with a client connection.

        This method is called automatically before each call and before each refresh.
        """
        # keep a target only if its client still exists and its container is still registered with that client
        self.targets = [
            target
            for target in self.targets
            if target.container.client.id in Client.instances and target.container.id in target.container.client.elements
        ]
# Same runtime behavior as ``refreshable``; only the static typing of ``func`` differs:
# its first parameter (``_S``) is the instance the method is bound to.
# NOTE: listing ``Generic[...]`` before the generic base relies on typing dropping the redundant ``Generic``
# entry when building the MRO, so the base order should not be changed casually.
class refreshable_method(Generic[_S, _P, _T], refreshable[_P, _T]):

    def __init__(self, func: Callable[Concatenate[_S, _P], Union[_T, Awaitable[_T]]]) -> None:
        """Refreshable UI methods

        The `@ui.refreshable_method` decorator allows you to create methods that have a `refresh` method.
        This method will automatically delete all elements created by the function and recreate them.
        """
        # the base class expects ``Callable[_P, ...]``; the extra ``self`` parameter is silenced for the type checker
        super().__init__(func)  # type: ignore
  117. def state(value: Any) -> Tuple[Any, Callable[[Any], None]]:
  118. """Create a state variable that automatically updates its refreshable UI container.
  119. :param value: The initial value of the state variable.
  120. :return: A tuple containing the current value and a function to update the value.
  121. """
  122. target = cast(RefreshableTarget, RefreshableTarget.current_target)
  123. if target.next_index >= len(target.locals):
  124. target.locals.append(value)
  125. else:
  126. value = target.locals[target.next_index]
  127. def set_value(new_value: Any, index=target.next_index) -> None:
  128. if target.locals[index] == new_value:
  129. return
  130. target.locals[index] = new_value
  131. target.refreshable.refresh()
  132. target.next_index += 1
  133. return value, set_value