1
0
Эх сурвалжийг харах

add cors support in flask backend

wangweimin 5 жил өмнө
parent
commit
1aea73f295
1 өөрчлөгдсөн 59 нэмэгдсэн , 14 устгасан
  1. 59 14
      pywebio/platform/flask.py

+ 59 - 14
pywebio/platform/flask.py

@@ -20,12 +20,13 @@ Flask backend
 
 """
 import asyncio
+import fnmatch
 import threading
 import time
 from functools import partial
 from typing import Dict
 
-from flask import Flask, request, jsonify, send_from_directory
+from flask import Flask, request, jsonify, send_from_directory, Response
 
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
     set_session_implement, get_session_implement_for_target
@@ -65,24 +66,46 @@ def _remove_webio_session(sid):
     del _webio_expire[sid]
 
 
-def _webio_view(target, session_expire_seconds):
+def cors_headers(origin, check_origin, headers=None):
+    if headers is None:
+        headers = {}
+
+    if check_origin(origin):
+        headers['Access-Control-Allow-Origin'] = origin
+        headers['Access-Control-Allow-Methods'] = 'GET, POST'
+        headers['Access-Control-Allow-Headers'] = 'content-type, webio-session-id'
+        headers['Access-Control-Expose-Headers'] = 'webio-session-id'
+        headers['Access-Control-Max-Age'] = 1440 * 60
+
+    return headers
+
+
+def _webio_view(target, session_expire_seconds, check_origin):
     """
-    :param coro_func:
+    :param target:
     :param session_expire_seconds:
     :return:
     """
-    if request.args.get('test'):  # 测试接口,当会话使用给予http的backend时,返回 ok
-        return 'ok'
-
     global _last_check_session_expire_ts, _event_loop
     if _event_loop:
         asyncio.set_event_loop(_event_loop)
 
+    if request.method == 'OPTIONS':  # preflight request for CORS
+        headers = cors_headers(request.headers.get('Origin', ''), check_origin)
+        return Response('', headers=headers, status=204)
+
+    headers = {}
+
+    if request.headers.get('Origin'):  # set headers for CORS request
+        headers = cors_headers(request.headers.get('Origin'), check_origin, headers=headers)
+
+    if request.args.get('test'):  # 测试接口,当会话使用给予http的backend时,返回 ok
+        return Response('ok', headers=headers)
+
     webio_session_id = None
-    set_header = False
     if 'webio-session-id' not in request.headers or not request.headers['webio-session-id']:  # start new WebIOSession
-        set_header = True
         webio_session_id = random_str(24)
+        headers['webio-session-id'] = webio_session_id
         Session = get_session_implement()
         webio_session = Session(target)
         _webio_sessions[webio_session_id] = webio_session
@@ -96,23 +119,27 @@ def _webio_view(target, session_expire_seconds):
     if request.method == 'POST':  # client push event
         webio_session.send_client_event(request.json)
         time.sleep(WAIT_MS_ON_POST / 1000.0)
-
     elif request.method == 'GET':  # client pull messages
         pass
 
+    # clean up at intervals
     if time.time() - _last_check_session_expire_ts > REMOVE_EXPIRED_SESSIONS_INTERVAL:
         _remove_expired_sessions(session_expire_seconds)
         _last_check_session_expire_ts = time.time()
 
     response = _make_response(webio_session)
+
     if webio_session.closed():
         _remove_webio_session(webio_session_id)
-    elif set_header:
-        response.headers['webio-session-id'] = webio_session_id
+
+    # set header to response
+    for k, v in headers.items():
+        response.headers[k] = v
+
     return response
 
 
-def webio_view(target, session_expire_seconds, session_type=None):
+def webio_view(target, session_expire_seconds, session_type=None, allowed_origins=None, check_origin=None):
     """获取Flask view"""
 
     if not session_type:
@@ -120,7 +147,15 @@ def webio_view(target, session_expire_seconds, session_type=None):
 
     set_session_implement(session_type)
 
-    view_func = partial(_webio_view, target=target, session_expire_seconds=session_expire_seconds)
+    if check_origin is None:
+        check_origin = lambda origin: any(
+            fnmatch.fnmatch(origin, patten)
+            for patten in allowed_origins
+        )
+
+    view_func = partial(_webio_view, target=target,
+                        session_expire_seconds=session_expire_seconds,
+                        check_origin=check_origin)
     view_func.__name__ = 'webio_view'
     return view_func
 
@@ -134,6 +169,7 @@ def _setup_event_loop():
 
 
 def start_server(target, port=8080, host='localhost',
+                 allowed_origins=None, check_origin=None,
                  session_type=None,
                  disable_asyncio=False,
                  session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS,
@@ -143,6 +179,10 @@ def start_server(target, port=8080, host='localhost',
         a simple function is use ThreadBasedSession.
     :param port: server bind port. set ``0`` to find a free port number to use
     :param host: server bind host. ``host`` may be either an IP address or hostname.  If it's a hostname,
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+        来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
+    :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽律
     :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
     :param disable_asyncio: 禁用 asyncio 函数。仅在当 ``session_type=COROUTINE_BASED`` 时有效。
@@ -156,7 +196,12 @@ def start_server(target, port=8080, host='localhost',
     """
 
     app = Flask(__name__)
-    app.route('/io', methods=['GET', 'POST'])(webio_view(target, session_expire_seconds, session_type=session_type))
+    app.route('/io', methods=['GET', 'POST', 'OPTIONS'])(
+        webio_view(target, session_expire_seconds,
+                   session_type=session_type,
+                   allowed_origins=allowed_origins,
+                   check_origin=check_origin)
+    )
 
     @app.route('/')
     @app.route('/<path:static_file>')