utils.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import fnmatch
  2. import json
  3. import os
  4. import socket
  5. import urllib.parse
  6. from collections import defaultdict
  7. from ..__version__ import __version__ as version
  8. from ..exceptions import PyWebIOWarning
  9. def cdn_validation(cdn, level='warn', stacklevel=3):
  10. """CDN availability check
  11. :param bool/str cdn: cdn parameter
  12. :param level: warn or error
  13. :param stacklevel: stacklevel=3 to makes the warning refer to cdn_validation() caller’s caller
  14. """
  15. assert level in ('warn', 'error')
  16. if cdn is True and 'dev' in version:
  17. if level == 'warn':
  18. import warnings
  19. warnings.warn("Default CDN is not supported in dev version. Ignore the CDN setting", PyWebIOWarning,
  20. stacklevel=stacklevel)
  21. return False
  22. else:
  23. raise ValueError("Default CDN is not supported in dev version. Please host static files by yourself.")
  24. return cdn
  25. class OriginChecker:
  26. @classmethod
  27. def check_origin(cls, origin, allowed_origins, host):
  28. if cls.is_same_site(origin, host):
  29. return True
  30. return any(
  31. fnmatch.fnmatch(origin, pattern)
  32. for pattern in allowed_origins
  33. )
  34. @staticmethod
  35. def is_same_site(origin, host):
  36. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  37. parsed_origin = urllib.parse.urlparse(origin)
  38. origin = parsed_origin.netloc
  39. origin = origin.lower()
  40. # Check to see that origin matches host directly, including ports
  41. return origin == host
  42. def deserialize_binary_event(data: bytes):
  43. """
  44. Binary event message is used to submit form data with files upload to server.
  45. Data message format:
  46. | event | file_header | file_data | file_header | file_data | ...
  47. The 8 bytes at the beginning of each segment indicate the number of bytes remaining in the segment.
  48. event: {
  49. ...
  50. data: {
  51. input_name => input_data,
  52. ...
  53. }
  54. }
  55. file_header: {
  56. 'filename': file name,
  57. 'size': file size,
  58. 'mime_type': file type,
  59. 'last_modified': last_modified timestamp,
  60. 'input_name': name of input field
  61. }
  62. file_data is the file content in bytes.
  63. - When a form field is not a file input, the `event['data'][input_name]` will be the value of the form field.
  64. - When a form field is a single file, the `event['data'][input_name]` is None,
  65. and there will only be one file_header+file_data for the field.
  66. - When a form field is a multiple files, the `event['data'][input_name]` is [],
  67. and there may be multiple file_header+file_data for the field.
  68. Example:
  69. b'\x00\x00\x00\x00\x00\x00\x00E{"event":"from_submit","task_id":"main-4788341456","data":{"data":1}}\x00\x00\x00\x00\x00\x00\x00Y{"filename":"hello.txt","size":2,"mime_type":"text/plain","last_modified":1617119937.276}\x00\x00\x00\x00\x00\x00\x00\x02ss'
  70. """
  71. # split data into segments
  72. parts = []
  73. start_idx = 0
  74. while start_idx < len(data):
  75. size = int.from_bytes(data[start_idx:start_idx + 8], "big")
  76. start_idx += 8
  77. content = data[start_idx:start_idx + size]
  78. parts.append(content)
  79. start_idx += size
  80. event = json.loads(parts[0])
  81. # deserialize file data
  82. files = defaultdict(list)
  83. for idx in range(1, len(parts), 2):
  84. file_header = json.loads(parts[idx])
  85. file_header['content'] = parts[idx + 1]
  86. # Security fix: to avoid interpreting file name as path
  87. file_header['filename'] = os.path.basename(file_header['filename'])
  88. input_name = file_header.pop('input_name')
  89. files[input_name].append(file_header)
  90. # fill file data to event
  91. for input_name in list(event['data'].keys()):
  92. if input_name in files:
  93. init = event['data'][input_name]
  94. event['data'][input_name] = files[input_name]
  95. if init is None: # the file is not multiple
  96. event['data'][input_name] = files[input_name][0] if len(files[input_name]) else None
  97. return event
  98. def get_interface_ip(family: socket.AddressFamily) -> str:
  99. """Get the IP address of an external interface. Used when binding to
  100. 0.0.0.0 or :: to show a more useful URL.
  101. Copy from https://github.com/pallets/werkzeug/blob/df7492ab66aaced5eea964a58309caaadb1e8903/src/werkzeug/serving.py
  102. Under BSD-3-Clause License
  103. """
  104. # arbitrary private address
  105. host = "fd31:f903:5ab5:1::1" if family == socket.AF_INET6 else "10.253.155.219"
  106. with socket.socket(family, socket.SOCK_DGRAM) as s:
  107. try:
  108. s.connect((host, 58162))
  109. except OSError:
  110. return "::1" if family == socket.AF_INET6 else "127.0.0.1"
  111. return s.getsockname()[0] # type: ignore
  112. def print_listen_address(host, port):
  113. if not host:
  114. host = '0.0.0.0'
  115. all_address = False
  116. if host == "0.0.0.0":
  117. all_address = True
  118. host = get_interface_ip(socket.AF_INET)
  119. elif host == "::":
  120. all_address = True
  121. host = get_interface_ip(socket.AF_INET6)
  122. if ':' in host: # ipv6
  123. host = '[%s]' % host
  124. if all_address:
  125. print('Running on all addresses.')
  126. print('Use http://%s:%s/ to access the application' % (host, port))
  127. else:
  128. print('Running on http://%s:%s/' % (host, port))