flask.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. """
  2. Flask backend
  3. .. note::
  4. 在 AsyncBasedSession 会话中,若在协程任务函数内调用 asyncio 中的协程函数,需要使用 asyncio_coroutine
  5. .. attention::
  6. PyWebIO 的会话状态保存在进程内,所以不支持多进程部署的Flask。
  7. 比如使用 ``uWSGI`` 部署Flask,并使用 ``--processes n`` 选项设置了多进程;
  8. 或者使用 ``nginx`` 等反向代理将流量负载到多个 Flask 副本上。
  9. A note on run Flask with uWSGI:
  10. If you start uWSGI without threads, the Python GIL will not be enabled,
  11. so threads generated by your application will never run. `uWSGI doc <https://uwsgi-docs.readthedocs.io/en/latest/WSGIquickstart.html#a-note-on-python-threads>`_
  12. 在Flask backend中,PyWebIO使用单独一个线程来运行事件循环。如果程序中没有使用到asyncio中的协程函数,
  13. 可以在 start_flask_server 参数中设置 ``disable_asyncio=False`` 来关闭对asyncio协程函数的支持。
  14. 如果您需要使用asyncio协程函数,那么需要在在uWSGI中使用 ``--enable-thread`` 选项开启线程支持。
  15. """
  16. import asyncio
  17. import fnmatch
  18. import logging
  19. import threading
  20. import time
  21. from functools import partial
  22. from typing import Dict
  23. from flask import Flask, request, jsonify, send_from_directory, Response
  24. from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
  25. register_session_implement_for_target
  26. from ..utils import STATIC_PATH
  27. from ..utils import random_str, LRUDict
  28. logger = logging.getLogger(__name__)
  29. # todo: use lock to avoid thread race condition
  30. # type: Dict[str, AbstractSession]
  31. _webio_sessions = {} # WebIOSessionID -> WebIOSession()
  32. _webio_expire = LRUDict() # WebIOSessionID -> last active timestamp。按照最后活跃时间递增排列
  33. DEFAULT_SESSION_EXPIRE_SECONDS = 60 # 超过60s会话不活跃则视为会话过期
  34. SESSIONS_CLEANUP_INTERVAL = 20 # 清理过期会话间隔(秒)
  35. WAIT_MS_ON_POST = 100 # 在处理完POST请求时,等待WAIT_MS_ON_POST毫秒再读取返回数据。Task的command可以立即返回
  36. _event_loop = None
  37. def _make_response(webio_session: AbstractSession):
  38. return jsonify(webio_session.get_task_commands())
  39. def _remove_expired_sessions(session_expire_seconds):
  40. logger.debug("removing expired sessions")
  41. """清除当前会话列表中的过期会话"""
  42. while _webio_expire:
  43. sid, active_ts = _webio_expire.popitem(last=False)
  44. if time.time() - active_ts < session_expire_seconds:
  45. # 当前session未过期
  46. _webio_expire[sid] = active_ts
  47. _webio_expire.move_to_end(sid, last=False)
  48. break
  49. # 清理session
  50. logger.debug("session %s expired" % sid)
  51. session = _webio_sessions.get(sid)
  52. if session:
  53. session.close()
  54. del _webio_sessions[sid]
  55. _last_check_session_expire_ts = 0 # 上次检查session有效期的时间戳
  56. def _remove_webio_session(sid):
  57. _webio_sessions.pop(sid, None)
  58. _webio_expire.pop(sid, None)
  59. def cors_headers(origin, check_origin, headers=None):
  60. if headers is None:
  61. headers = {}
  62. if check_origin(origin):
  63. headers['Access-Control-Allow-Origin'] = origin
  64. headers['Access-Control-Allow-Methods'] = 'GET, POST'
  65. headers['Access-Control-Allow-Headers'] = 'content-type, webio-session-id'
  66. headers['Access-Control-Expose-Headers'] = 'webio-session-id'
  67. headers['Access-Control-Max-Age'] = 1440 * 60
  68. return headers
  69. def _webio_view(target, session_cls, session_expire_seconds, session_cleanup_interval, check_origin):
  70. """
  71. :param target: 任务函数
  72. :param session_cls: 会话实现类
  73. :param session_expire_seconds: 会话不活跃过期时间。
  74. :param session_cleanup_interval: 会话清理间隔。
  75. :param callable check_origin: callback(origin) -> bool
  76. :return:
  77. """
  78. global _last_check_session_expire_ts, _event_loop
  79. if _event_loop:
  80. asyncio.set_event_loop(_event_loop)
  81. if request.method == 'OPTIONS': # preflight request for CORS
  82. headers = cors_headers(request.headers.get('Origin', ''), check_origin)
  83. return Response('', headers=headers, status=204)
  84. headers = {}
  85. if request.headers.get('Origin'): # set headers for CORS request
  86. headers = cors_headers(request.headers.get('Origin'), check_origin, headers=headers)
  87. if request.args.get('test'): # 测试接口,当会话使用给予http的backend时,返回 ok
  88. return Response('ok', headers=headers)
  89. webio_session_id = None
  90. # webio-session-id 的请求头为空时,创建新 Session
  91. if 'webio-session-id' not in request.headers or not request.headers['webio-session-id']: # start new WebIOSession
  92. webio_session_id = random_str(24)
  93. headers['webio-session-id'] = webio_session_id
  94. webio_session = session_cls(target)
  95. _webio_sessions[webio_session_id] = webio_session
  96. elif request.headers['webio-session-id'] not in _webio_sessions: # WebIOSession deleted
  97. return jsonify([dict(command='close_session')])
  98. else:
  99. webio_session_id = request.headers['webio-session-id']
  100. webio_session = _webio_sessions[webio_session_id]
  101. if request.method == 'POST': # client push event
  102. if request.json is not None:
  103. webio_session.send_client_event(request.json)
  104. time.sleep(WAIT_MS_ON_POST / 1000.0)
  105. elif request.method == 'GET': # client pull messages
  106. pass
  107. _webio_expire[webio_session_id] = time.time()
  108. # clean up at intervals
  109. if time.time() - _last_check_session_expire_ts > session_cleanup_interval:
  110. _last_check_session_expire_ts = time.time()
  111. _remove_expired_sessions(session_expire_seconds)
  112. response = _make_response(webio_session)
  113. if webio_session.closed():
  114. _remove_webio_session(webio_session_id)
  115. # set header to response
  116. for k, v in headers.items():
  117. response.headers[k] = v
  118. return response
  119. def webio_view(target,
  120. session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS,
  121. session_cleanup_interval=SESSIONS_CLEANUP_INTERVAL,
  122. allowed_origins=None, check_origin=None):
  123. """获取用于与Flask进行整合的view函数
  124. :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
  125. :param int session_expire_seconds: 会话不活跃过期时间。
  126. :param int session_cleanup_interval: 会话清理间隔。
  127. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  128. 来源包含协议和域名和端口部分,允许使用 Unix shell 风格的匹配模式:
  129. - ``*`` 为通配符
  130. - ``?`` 匹配单个字符
  131. - ``[seq]`` 匹配seq内的字符
  132. - ``[!seq]`` 匹配不在seq内的字符
  133. 比如 ``https://*.example.com`` 、 ``*://*.example.com``
  134. :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
  135. 返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
  136. :return: Flask视图函数
  137. """
  138. session_cls = register_session_implement_for_target(target)
  139. if check_origin is None:
  140. check_origin = lambda origin: any(
  141. fnmatch.fnmatch(origin, patten)
  142. for patten in allowed_origins or []
  143. )
  144. view_func = partial(_webio_view, target=target, session_cls=session_cls,
  145. session_expire_seconds=session_expire_seconds,
  146. session_cleanup_interval=session_cleanup_interval,
  147. check_origin=check_origin)
  148. view_func.__name__ = 'webio_view'
  149. return view_func
  150. def run_event_loop(debug=False):
  151. """运行事件循环
  152. 基于协程的会话在启动Flask服务器之前需要启动一个单独的线程来运行事件循环。
  153. :param debug: Set the debug mode of the event loop.
  154. See also: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode
  155. """
  156. global _event_loop
  157. _event_loop = asyncio.new_event_loop()
  158. _event_loop.set_debug(debug)
  159. asyncio.set_event_loop(_event_loop)
  160. _event_loop.run_forever()
  161. def start_server(target, port=8080, host='localhost',
  162. allowed_origins=None, check_origin=None,
  163. disable_asyncio=False,
  164. session_cleanup_interval=SESSIONS_CLEANUP_INTERVAL,
  165. session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS,
  166. debug=False, **flask_options):
  167. """启动一个 Flask server 来运行PyWebIO的 ``target`` 服务
  168. :param target: task function. It's a coroutine function is use CoroutineBasedSession or
  169. a simple function is use ThreadBasedSession.
  170. :param port: server bind port. set ``0`` to find a free port number to use
  171. :param host: server bind host. ``host`` may be either an IP address or hostname. If it's a hostname,
  172. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  173. 来源包含协议和域名和端口部分,允许使用 Unix shell 风格的匹配模式:
  174. - ``*`` 为通配符
  175. - ``?`` 匹配单个字符
  176. - ``[seq]`` 匹配seq内的字符
  177. - ``[!seq]`` 匹配不在seq内的字符
  178. 比如 ``https://*.example.com`` 、 ``*://*.example.com``
  179. :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
  180. 返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
  181. :param bool disable_asyncio: 禁用 asyncio 函数。仅在当 ``session_type=COROUTINE_BASED`` 时有效。
  182. 在Flask backend中使用asyncio需要单独开启一个线程来运行事件循环,
  183. 若程序中没有使用到asyncio中的异步函数,可以开启此选项来避免不必要的资源浪费
  184. :param int session_expire_seconds: 会话过期时间。若 session_expire_seconds 秒内没有收到客户端的请求,则认为会话过期。
  185. :param int session_cleanup_interval: 会话清理间隔。
  186. :param bool debug: Flask debug mode
  187. :param flask_options: Additional keyword arguments passed to the constructor of ``flask.Flask.run``.
  188. ref: https://flask.palletsprojects.com/en/1.1.x/api/?highlight=flask%20run#flask.Flask.run
  189. """
  190. app = Flask(__name__)
  191. app.route('/io', methods=['GET', 'POST', 'OPTIONS'])(
  192. webio_view(target, session_expire_seconds=session_expire_seconds,
  193. session_cleanup_interval=session_cleanup_interval,
  194. allowed_origins=allowed_origins,
  195. check_origin=check_origin)
  196. )
  197. @app.route('/')
  198. @app.route('/<path:static_file>')
  199. def serve_static_file(static_file='index.html'):
  200. return send_from_directory(STATIC_PATH, static_file)
  201. if not disable_asyncio and get_session_implement() is CoroutineBasedSession:
  202. threading.Thread(target=run_event_loop, daemon=True).start()
  203. app.run(host=host, port=port, debug=debug, **flask_options)