utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import fnmatch
  2. import json
  3. import socket
  4. import urllib.parse
  5. from collections import defaultdict
  6. from ..__version__ import __version__ as version
  7. from ..exceptions import PyWebIOWarning
  8. def cdn_validation(cdn, level='warn', stacklevel=3):
  9. """CDN availability check
  10. :param bool/str cdn: cdn parameter
  11. :param level: warn or error
  12. :param stacklevel: stacklevel=3 to makes the warning refer to cdn_validation() caller’s caller
  13. """
  14. assert level in ('warn', 'error')
  15. if cdn is True and 'dev' in version:
  16. if level == 'warn':
  17. import warnings
  18. warnings.warn("Default CDN is not supported in dev version. Ignore the CDN setting", PyWebIOWarning,
  19. stacklevel=stacklevel)
  20. return False
  21. else:
  22. raise ValueError("Default CDN is not supported in dev version. Please host static files by yourself.")
  23. return cdn
  24. class OriginChecker:
  25. @classmethod
  26. def check_origin(cls, origin, allowed_origins, host):
  27. if cls.is_same_site(origin, host):
  28. return True
  29. return any(
  30. fnmatch.fnmatch(origin, pattern)
  31. for pattern in allowed_origins
  32. )
  33. @staticmethod
  34. def is_same_site(origin, host):
  35. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  36. parsed_origin = urllib.parse.urlparse(origin)
  37. origin = parsed_origin.netloc
  38. origin = origin.lower()
  39. # Check to see that origin matches host directly, including ports
  40. return origin == host
  41. def deserialize_binary_event(data: bytes):
  42. """
  43. Data format:
  44. | event | file_header | file_data | file_header | file_data | ...
  45. The 8 bytes at the beginning of each segment indicate the number of bytes remaining in the segment.
  46. event: {
  47. event: "from_submit",
  48. task_id: that.task_id,
  49. data: {
  50. input_name => input_data
  51. }
  52. }
  53. file_header: {
  54. 'filename': file name,
  55. 'size': file size,
  56. 'mime_type': file type,
  57. 'last_modified': last_modified timestamp,
  58. 'input_name': name of input field
  59. }
  60. Example:
  61. b'\x00\x00\x00\x00\x00\x00\x00E{"event":"from_submit","task_id":"main-4788341456","data":{"data":1}}\x00\x00\x00\x00\x00\x00\x00Y{"filename":"hello.txt","size":2,"mime_type":"text/plain","last_modified":1617119937.276}\x00\x00\x00\x00\x00\x00\x00\x02ss'
  62. """
  63. parts = []
  64. start_idx = 0
  65. while start_idx < len(data):
  66. size = int.from_bytes(data[start_idx:start_idx + 8], "big")
  67. start_idx += 8
  68. content = data[start_idx:start_idx + size]
  69. parts.append(content)
  70. start_idx += size
  71. event = json.loads(parts[0])
  72. files = defaultdict(list)
  73. for idx in range(1, len(parts), 2):
  74. f = json.loads(parts[idx])
  75. f['content'] = parts[idx + 1]
  76. input_name = f.pop('input_name')
  77. files[input_name].append(f)
  78. for input_name in list(event['data'].keys()):
  79. if input_name in files:
  80. event['data'][input_name] = files[input_name]
  81. return event
  82. def get_interface_ip(family: socket.AddressFamily) -> str:
  83. """Get the IP address of an external interface. Used when binding to
  84. 0.0.0.0 or :: to show a more useful URL.
  85. Copy from https://github.com/pallets/werkzeug/blob/df7492ab66aaced5eea964a58309caaadb1e8903/src/werkzeug/serving.py
  86. Under BSD-3-Clause License
  87. """
  88. # arbitrary private address
  89. host = "fd31:f903:5ab5:1::1" if family == socket.AF_INET6 else "10.253.155.219"
  90. with socket.socket(family, socket.SOCK_DGRAM) as s:
  91. try:
  92. s.connect((host, 58162))
  93. except OSError:
  94. return "::1" if family == socket.AF_INET6 else "127.0.0.1"
  95. return s.getsockname()[0] # type: ignore
  96. def print_listen_address(host, port):
  97. if not host:
  98. host = '0.0.0.0'
  99. all_address = False
  100. if host == "0.0.0.0":
  101. all_address = True
  102. host = get_interface_ip(socket.AF_INET)
  103. elif host == "::":
  104. all_address = True
  105. host = get_interface_ip(socket.AF_INET6)
  106. if ':' in host: # ipv6
  107. host = '[%s]' % host
  108. if all_address:
  109. print('Running on all addresses.')
  110. print('Use http://%s:%s/ to access the application' % (host, port))
  111. else:
  112. print('Running on http://%s:%s/' % (host, port))