utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import asyncio
  2. import functools
  3. import inspect
  4. import queue
  5. import random
  6. import socket
  7. import string
  8. from collections import OrderedDict
  9. from contextlib import closing
  10. from os.path import abspath, dirname
  11. import time
  12. project_dir = dirname(abspath(__file__))
  13. STATIC_PATH = '%s/html' % project_dir
  14. class Setter:
  15. """
  16. 可以在对象属性上保存数据。
  17. 访问数据对象不存在的属性时会返回None而不是抛出异常。
  18. """
  19. def __getattribute__(self, name):
  20. try:
  21. return super().__getattribute__(name)
  22. except AttributeError:
  23. return None
  24. class ObjectDict(dict):
  25. """
  26. Object like dict, every dict[key] can visite by dict.key
  27. If dict[key] is `Get`, calculate it's value.
  28. """
  29. def __getattr__(self, name):
  30. ret = self.__getitem__(name)
  31. if hasattr(ret, '__get__'):
  32. return ret.__get__(self, ObjectDict)
  33. return ret
  34. def catch_exp_call(func, logger):
  35. """运行函数,将捕获异常记录到日志
  36. :param func: 函数
  37. :param logger: 日志
  38. :return: ``func`` 返回值
  39. """
  40. try:
  41. return func()
  42. except Exception:
  43. logger.exception("Error when invoke `%s`" % func)
  44. def iscoroutinefunction(object):
  45. while isinstance(object, functools.partial):
  46. object = object.func
  47. return asyncio.iscoroutinefunction(object)
  48. def isgeneratorfunction(object):
  49. while isinstance(object, functools.partial):
  50. object = object.func
  51. return inspect.isgeneratorfunction(object)
  52. def get_function_name(func, default=None):
  53. while isinstance(func, functools.partial):
  54. func = func.func
  55. return getattr(func, '__name__', default)
  56. def get_function_doc(func):
  57. while isinstance(func, functools.partial):
  58. func = func.func
  59. return inspect.getdoc(func) or ''
  60. class LimitedSizeQueue(queue.Queue):
  61. """
  62. 有限大小的队列
  63. `get()` 返回全部数据
  64. 队列满时,再 `put()` 会阻塞
  65. """
  66. def get(self):
  67. """获取队列全部数据"""
  68. try:
  69. return super().get(block=False)
  70. except queue.Empty:
  71. return []
  72. def wait_empty(self, timeout=None):
  73. """等待队列内的数据被取走"""
  74. with self.not_full:
  75. if self._qsize() == 0:
  76. return
  77. if timeout is None:
  78. self.not_full.wait()
  79. elif timeout < 0:
  80. raise ValueError("'timeout' must be a non-negative number")
  81. else:
  82. self.not_full.wait(timeout)
  83. def _init(self, maxsize):
  84. self.queue = []
  85. def _qsize(self):
  86. return len(self.queue)
  87. # Put a new item in the queue
  88. def _put(self, item):
  89. self.queue.append(item)
  90. # Get an item from the queue
  91. def _get(self):
  92. all_data = self.queue
  93. self.queue = []
  94. return all_data
  95. async def wait_host_port(host, port, duration=10, delay=2):
  96. """Repeatedly try if a port on a host is open until duration seconds passed
  97. from: https://gist.github.com/betrcode/0248f0fda894013382d7#gistcomment-3161499
  98. :param str host: host ip address or hostname
  99. :param int port: port number
  100. :param int/float duration: Optional. Total duration in seconds to wait, by default 10
  101. :param int/float delay: Optional. Delay in seconds between each try, by default 2
  102. :return: awaitable bool
  103. """
  104. tmax = time.time() + duration
  105. while time.time() < tmax:
  106. try:
  107. _, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=5)
  108. writer.close()
  109. # asyncio.StreamWriter.wait_closed is introduced in py 3.7
  110. # See https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.wait_closed
  111. if hasattr(writer, 'wait_closed'):
  112. await writer.wait_closed()
  113. return True
  114. except Exception:
  115. if delay:
  116. await asyncio.sleep(delay)
  117. return False
  118. def get_free_port():
  119. """
  120. pick a free port number
  121. :return int: port number
  122. """
  123. with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
  124. s.bind(('', 0))
  125. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  126. return s.getsockname()[1]
  127. def random_str(length=16):
  128. """生成字母和数组组成的随机字符串
  129. :param int length: 字符串长度
  130. """
  131. candidates = string.ascii_letters + string.digits
  132. return ''.join(random.SystemRandom().choice(candidates) for _ in range(length))
  133. def run_as_function(gen):
  134. res = None
  135. while 1:
  136. try:
  137. res = gen.send(res)
  138. except StopIteration as e:
  139. if len(e.args) == 1:
  140. return e.args[0]
  141. return
  142. async def to_coroutine(gen):
  143. res = None
  144. while 1:
  145. try:
  146. c = gen.send(res)
  147. res = await c
  148. except StopIteration as e:
  149. if len(e.args) == 1:
  150. return e.args[0]
  151. return
  152. class LRUDict(OrderedDict):
  153. """
  154. Store items in the order the keys were last recent updated.
  155. The last recent updated item was in end.
  156. The last furthest updated item was in front.
  157. """
  158. def __setitem__(self, key, value):
  159. OrderedDict.__setitem__(self, key, value)
  160. self.move_to_end(key)
  161. _html_value_chars = set(string.ascii_letters + string.digits + '_-')
  162. def is_html_safe_value(val):
  163. """检查是字符串是否可以作为html属性值"""
  164. return all(i in _html_value_chars for i in val)