progress.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. """A module that provides a progress bar for the terminal."""
  2. import dataclasses
  3. import time
  4. from typing import Callable, Sequence
  5. from reflex.utils.console import Reprinter, _get_terminal_width
# Module-level Reprinter instance. Nothing in the code visible here uses it
# (each ProgressBar creates its own `_printer`), so it is presumably kept for
# modules that import it from here. NOTE(review): confirm before removing.
reprinter: Reprinter = Reprinter()
  7. @dataclasses.dataclass(kw_only=True)
  8. class ProgressBarComponent:
  9. """A protocol for progress bar components."""
  10. colorer: Callable[[str], str] = lambda x: x
  11. def minimum_width(self, current: int, steps: int) -> int:
  12. """Return the minimum width of the component.
  13. Args:
  14. current: The current step.
  15. steps: The total number of steps.
  16. """
  17. ...
  18. def requested_width(self, current: int, steps: int) -> int:
  19. """Return the requested width of the component.
  20. Args:
  21. current: The current step.
  22. steps: The total number of steps.
  23. """
  24. ...
  25. def initialize(self, steps: int) -> None:
  26. """Initialize the component.
  27. Args:
  28. steps: The total number of steps.
  29. """
  30. ...
  31. def get_message(self, current: int, steps: int, max_width: int) -> str:
  32. """Return the message to display.
  33. Args:
  34. current: The current step.
  35. steps: The total number of steps.
  36. max_width: The maximum width of the component.
  37. """
  38. ...
  39. @dataclasses.dataclass
  40. class MessageComponent(ProgressBarComponent):
  41. """A simple component that displays a message."""
  42. message: str = ""
  43. def minimum_width(self, current: int, steps: int) -> int:
  44. """Return the minimum width of the component.
  45. Args:
  46. current: The current step.
  47. steps: The total number of steps.
  48. Returns:
  49. The minimum width of the component.
  50. """
  51. return len(self.message)
  52. def requested_width(self, current: int, steps: int) -> int:
  53. """Return the requested width of the component.
  54. Args:
  55. current: The current step.
  56. steps: The total number of steps.
  57. Returns:
  58. The requested width of the component.
  59. """
  60. return len(self.message)
  61. def initialize(self, steps: int) -> None:
  62. """Initialize the component.
  63. Args:
  64. steps: The total number of steps.
  65. """
  66. def get_message(self, current: int, steps: int, max_width: int) -> str:
  67. """Return the message to display.
  68. Args:
  69. current: The current step.
  70. steps: The total number of steps.
  71. max_width: The maximum width of the component.
  72. Returns:
  73. The message to display.
  74. """
  75. return self.message
  76. @dataclasses.dataclass
  77. class PercentageComponent(ProgressBarComponent):
  78. """A component that displays the percentage of completion."""
  79. def minimum_width(self, current: int, steps: int) -> int:
  80. """Return the minimum width of the component.
  81. Args:
  82. current: The current step.
  83. steps: The total number of steps.
  84. Returns:
  85. The minimum width of the component.
  86. """
  87. return 4
  88. def requested_width(self, current: int, steps: int) -> int:
  89. """Return the requested width of the component.
  90. Args:
  91. current: The current step.
  92. steps: The total number of steps.
  93. Returns:
  94. The requested width of the component.
  95. """
  96. return 4
  97. def initialize(self, steps: int) -> None:
  98. """Initialize the component.
  99. Args:
  100. steps: The total number of steps.
  101. """
  102. def get_message(self, current: int, steps: int, max_width: int) -> str:
  103. """Return the message to display.
  104. Args:
  105. current: The current step.
  106. steps: The total number of steps.
  107. max_width: The maximum width of the component.
  108. Returns:
  109. The message to display.
  110. """
  111. return f"{int(current / steps * 100):3}%"
  112. @dataclasses.dataclass
  113. class TimeComponent(ProgressBarComponent):
  114. """A component that displays the time elapsed."""
  115. initial_time: float | None = None
  116. _cached_time: float | None = dataclasses.field(default=None, init=False)
  117. def _minimum_and_requested_string(
  118. self, current: int, steps: int
  119. ) -> tuple[str, str]:
  120. """Return the minimum and requested string length of the component.
  121. Args:
  122. current: The current step.
  123. steps: The total number of steps.
  124. Returns:
  125. The minimum and requested string length of the component.
  126. Raises:
  127. ValueError: If the component is not initialized.
  128. """
  129. if self.initial_time is None or self._cached_time is None:
  130. raise ValueError("TimeComponent not initialized")
  131. return (
  132. f"{int(self._cached_time - self.initial_time)!s}s",
  133. f"{int((self._cached_time - self.initial_time) * 1000)!s}ms",
  134. )
  135. def minimum_width(self, current: int, steps: int) -> int:
  136. """Return the minimum width of the component.
  137. Args:
  138. current: The current step.
  139. steps: The total number of steps.
  140. Returns:
  141. The minimum width of the component.
  142. Raises:
  143. ValueError: If the component is not initialized.
  144. """
  145. if self.initial_time is None:
  146. raise ValueError("TimeComponent not initialized")
  147. self._cached_time = time.monotonic()
  148. _min, _ = self._minimum_and_requested_string(current, steps)
  149. return len(_min)
  150. def requested_width(self, current: int, steps: int) -> int:
  151. """Return the requested width of the component.
  152. Args:
  153. current: The current step.
  154. steps: The total number of steps.
  155. Returns:
  156. The requested width of the component.
  157. Raises:
  158. ValueError: If the component is not initialized.
  159. """
  160. if self.initial_time is None:
  161. raise ValueError("TimeComponent not initialized")
  162. _, _req = self._minimum_and_requested_string(current, steps)
  163. return len(_req)
  164. def initialize(self, steps: int) -> None:
  165. """Initialize the component.
  166. Args:
  167. steps: The total number of steps.
  168. """
  169. self.initial_time = time.monotonic()
  170. def get_message(self, current: int, steps: int, max_width: int) -> str:
  171. """Return the message to display.
  172. Args:
  173. current: The current step.
  174. steps: The total number of steps.
  175. max_width: The maximum width of the component.
  176. Returns:
  177. The message to display.
  178. Raises:
  179. ValueError: If the component is not initialized.
  180. """
  181. if self.initial_time is None:
  182. raise ValueError("TimeComponent not initialized")
  183. _min, _req = self._minimum_and_requested_string(current, steps)
  184. if len(_req) <= max_width:
  185. return _req
  186. return _min
  187. @dataclasses.dataclass
  188. class CounterComponent(ProgressBarComponent):
  189. """A component that displays the current step and total steps."""
  190. def minimum_width(self, current: int, steps: int) -> int:
  191. """Return the minimum width of the component.
  192. Args:
  193. current: The current step.
  194. steps: The total number of steps.
  195. Returns:
  196. The minimum width of the component.
  197. """
  198. return 1 + 2 * len(str(steps))
  199. def requested_width(self, current: int, steps: int) -> int:
  200. """Return the requested width of the component.
  201. Args:
  202. current: The current step.
  203. steps: The total number of steps.
  204. Returns:
  205. The requested width of the component.
  206. """
  207. return 1 + 2 * len(str(steps))
  208. def initialize(self, steps: int) -> None:
  209. """Initialize the component.
  210. Args:
  211. steps: The total number of steps.
  212. """
  213. def get_message(self, current: int, steps: int, max_width: int) -> str:
  214. """Return the message to display.
  215. Args:
  216. current: The current step.
  217. steps: The total number of steps.
  218. max_width: The maximum width of the component.
  219. Returns:
  220. The message to display.
  221. """
  222. return current.__format__(f"{len(str(steps))}") + "/" + str(steps)
  223. @dataclasses.dataclass
  224. class SimpleProgressComponent(ProgressBarComponent):
  225. """A component that displays a not so fun guy."""
  226. starting_str: str = ""
  227. ending_str: str = ""
  228. complete_str: str = "█"
  229. incomplete_str: str = "░"
  230. def minimum_width(self, current: int, steps: int) -> int:
  231. """Return the minimum width of the component.
  232. Args:
  233. current: The current step.
  234. steps: The total number of steps.
  235. Returns:
  236. The minimum width of the component.
  237. """
  238. return (
  239. len(self.starting_str)
  240. + 2 * len(self.incomplete_str)
  241. + 2 * len(self.complete_str)
  242. + len(self.ending_str)
  243. )
  244. def requested_width(self, current: int, steps: int) -> int:
  245. """Return the requested width of the component.
  246. Args:
  247. current: The current step.
  248. steps: The total number of steps.
  249. Returns:
  250. The requested width of the component.
  251. """
  252. return (
  253. len(self.starting_str)
  254. + steps * max(len(self.incomplete_str), len(self.complete_str))
  255. + len(self.ending_str)
  256. )
  257. def initialize(self, steps: int) -> None:
  258. """Initialize the component.
  259. Args:
  260. steps: The total number of steps.
  261. """
  262. def get_message(self, current: int, steps: int, max_width: int) -> str:
  263. """Return the message to display.
  264. Args:
  265. current: The current step.
  266. steps: The total number of steps.
  267. max_width: The maximum width of the component.
  268. Returns:
  269. The message to display.
  270. """
  271. progress = int(
  272. current
  273. / steps
  274. * (max_width - len(self.starting_str) - len(self.ending_str))
  275. )
  276. complete_part = self.complete_str * (progress // len(self.complete_str))
  277. incomplete_part = self.incomplete_str * (
  278. (
  279. max_width
  280. - len(self.starting_str)
  281. - len(self.ending_str)
  282. - len(complete_part)
  283. )
  284. // len(self.incomplete_str)
  285. )
  286. return self.starting_str + complete_part + incomplete_part + self.ending_str
  287. @dataclasses.dataclass
  288. class ProgressBar:
  289. """A progress bar that displays the progress of a task."""
  290. steps: int
  291. max_width: int = 80
  292. separator: str = " "
  293. components: Sequence[tuple[ProgressBarComponent, int]] = dataclasses.field(
  294. default_factory=lambda: [
  295. (SimpleProgressComponent(), 2),
  296. (CounterComponent(), 3),
  297. (PercentageComponent(), 0),
  298. (TimeComponent(), 1),
  299. ]
  300. )
  301. _printer: Reprinter = dataclasses.field(default_factory=Reprinter, init=False)
  302. _current: int = dataclasses.field(default=0, init=False)
  303. def __post_init__(self):
  304. """Initialize the progress bar."""
  305. for component, _ in self.components:
  306. component.initialize(self.steps)
  307. def print(self):
  308. """Print the current progress bar state."""
  309. current_terminal_width = _get_terminal_width()
  310. components_by_priority = [
  311. (index, component)
  312. for index, (component, _) in sorted(
  313. enumerate(self.components), key=lambda x: x[1][1], reverse=True
  314. )
  315. ]
  316. possible_width = min(current_terminal_width, self.max_width)
  317. sum_of_minimum_widths = sum(
  318. component.minimum_width(self._current, self.steps)
  319. for _, component in components_by_priority
  320. )
  321. if sum_of_minimum_widths > possible_width:
  322. used_width = 0
  323. visible_components: list[tuple[int, ProgressBarComponent, int]] = []
  324. for index, component in components_by_priority:
  325. if (
  326. used_width
  327. + component.minimum_width(self._current, self.steps)
  328. + len(self.separator)
  329. > possible_width
  330. ):
  331. continue
  332. used_width += component.minimum_width(self._current, self.steps)
  333. visible_components.append(
  334. (
  335. index,
  336. component,
  337. component.requested_width(self._current, self.steps),
  338. )
  339. )
  340. else:
  341. components = [
  342. (
  343. priority,
  344. component,
  345. component.minimum_width(self._current, self.steps),
  346. )
  347. for (component, priority) in self.components
  348. ]
  349. while True:
  350. sum_of_assigned_width = sum(width for _, _, width in components)
  351. extra_width = (
  352. possible_width
  353. - sum_of_assigned_width
  354. - (len(self.separator) * (len(components) - 1))
  355. )
  356. possible_extra_width_to_take = [
  357. (
  358. max(
  359. 0,
  360. component.requested_width(self._current, self.steps)
  361. - width,
  362. ),
  363. priority,
  364. )
  365. for priority, component, width in components
  366. ]
  367. sum_of_possible_extra_width = sum(
  368. width for width, _ in possible_extra_width_to_take
  369. )
  370. if sum_of_possible_extra_width <= 0 or extra_width <= 0:
  371. break
  372. min_width, max_prioririty = min(
  373. filter(lambda x: x[0] > 0, possible_extra_width_to_take),
  374. key=lambda x: x[0] / x[1],
  375. )
  376. maximum_prioririty_repeats = min_width / max_prioririty
  377. give_width = [
  378. min(width, maximum_prioririty_repeats * priority)
  379. for width, priority in possible_extra_width_to_take
  380. ]
  381. sum_of_give_width = sum(give_width)
  382. normalized_give_width = [
  383. width / sum_of_give_width * min(extra_width, sum_of_give_width)
  384. for width in give_width
  385. ]
  386. components = [
  387. (index, component, int(width + give))
  388. for (index, component, width), give in zip(
  389. components, normalized_give_width, strict=True
  390. )
  391. ]
  392. if sum(width for _, _, width in components) == sum_of_minimum_widths:
  393. break
  394. visible_components = [
  395. (index, component, width)
  396. for index, (_, component, width) in enumerate(components)
  397. if width > 0
  398. ]
  399. messages = [
  400. self.get_message(component, width)
  401. for _, component, width in sorted(visible_components, key=lambda x: x[0])
  402. ]
  403. self._printer.reprint(self.separator.join(messages))
  404. def get_message(self, component: ProgressBarComponent, width: int):
  405. """Get the message for a given component.
  406. Args:
  407. component: The component to get the message for.
  408. width: The width of the component.
  409. Returns:
  410. The message for the component
  411. """
  412. message = component.get_message(self._current, self.steps, width)
  413. return component.colorer(message[:width])
  414. def update(self, step: int):
  415. """Update the progress bar by a given step.
  416. Args:
  417. step: The step to update the progress bar by.
  418. """
  419. self._current += step
  420. self.print()
  421. def finish(self):
  422. """Finish the progress bar."""
  423. self._current = self.steps
  424. self.print()
  425. self._printer.finish()