slot.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from __future__ import annotations
  2. import asyncio
  3. from typing import TYPE_CHECKING, ClassVar, Dict, Iterator, List, Optional
  4. from typing_extensions import Self
  5. from .logging import log
  6. if TYPE_CHECKING:
  7. from .element import Element
  8. class Slot:
  9. stacks: ClassVar[Dict[int, List[Slot]]] = {}
  10. """Maps asyncio task IDs to slot stacks, which keep track of the current slot in each task."""
  11. def __init__(self, parent: Element, name: str, template: Optional[str] = None) -> None:
  12. self.name = name
  13. self.parent = parent
  14. self.template = template
  15. self.children: List[Element] = []
  16. def __enter__(self) -> Self:
  17. self.get_stack().append(self)
  18. return self
  19. def __exit__(self, *_) -> None:
  20. self.get_stack().pop()
  21. self.prune_stack()
  22. def __iter__(self) -> Iterator[Element]:
  23. return iter(self.children)
  24. @classmethod
  25. def get_stack(cls) -> List[Slot]:
  26. """Return the slot stack of the current asyncio task."""
  27. task_id = get_task_id()
  28. if task_id not in cls.stacks:
  29. cls.stacks[task_id] = []
  30. return cls.stacks[task_id]
  31. @classmethod
  32. def prune_stack(cls) -> None:
  33. """Remove the current slot stack if it is empty."""
  34. task_id = get_task_id()
  35. if not cls.stacks[task_id]:
  36. del cls.stacks[task_id]
  37. @classmethod
  38. async def prune_stacks(cls) -> None:
  39. """Remove stale slot stacks in an endless loop."""
  40. while True:
  41. try:
  42. running = [id(task) for task in asyncio.tasks.all_tasks() if not task.done() and not task.cancelled()]
  43. stale_ids = [task_id for task_id in cls.stacks if task_id not in running]
  44. for task_id in stale_ids:
  45. del cls.stacks[task_id]
  46. except Exception:
  47. # NOTE: make sure the loop doesn't crash
  48. log.exception('Error while pruning slot stacks')
  49. try:
  50. await asyncio.sleep(10)
  51. except asyncio.CancelledError:
  52. break
  53. def get_task_id() -> int:
  54. """Return the ID of the current asyncio task."""
  55. try:
  56. return id(asyncio.current_task())
  57. except RuntimeError:
  58. return 0