ws-server.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  import json
  import time
  from collections import defaultdict

  import tornado.httpserver
  import tornado.ioloop
  import tornado.web
  import tornado.websocket
  class Future:
      """Minimal single-assignment future usable with ``yield from``.

      A coroutine writes ``value = yield from fut``. ``__iter__`` yields the
      future itself to the driving Task, which registers a done-callback.
      Once :meth:`set_result` is called, the callback resumes the coroutine
      and ``yield from`` evaluates to the stored result.
      """

      def __init__(self):
          self.result = None
          self._done = False
          self._callbacks = []

      def done(self):
          """Return True once :meth:`set_result` has been called."""
          return self._done

      def add_done_callback(self, fn):
          """Arrange for ``fn(self)`` to be called when the result is set.

          If the result is already set, ``fn`` is called immediately.
          Otherwise a callback registered on a finished future would never
          fire.
          """
          if self._done:
              fn(self)
          else:
              self._callbacks.append(fn)

      def set_result(self, result):
          """Store ``result`` and run every registered callback exactly once.

          Raises:
              RuntimeError: if a result has already been set. Re-running the
                  callbacks would resume the waiting coroutine a second time.
          """
          if self._done:
              raise RuntimeError('Future result already set')
          self.result = result
          self._done = True
          # Detach the list first so the callbacks' references are released.
          # Callbacks added while these run see _done and fire immediately.
          callbacks, self._callbacks = self._callbacks, []
          for fn in callbacks:
              fn(self)

      def __iter__(self):
          # Hand this future to the driving Task. The value sent back in is
          # ignored; ``yield from`` evaluates to the stored result.
          yield self
          return self.result
  class Task:
      """Drive a generator-based coroutine from future callbacks.

      The coroutine is started immediately in ``__init__``. Every object it
      yields must provide ``add_done_callback(fn)`` and a ``result``
      attribute, as ``Future`` does. When that object completes, its result
      is sent back into the coroutine, which then runs until its next yield.

      Attributes:
          result: the coroutine's return value (None until it finishes).
          on_task_finish: optional ``callback(result)`` run on completion.
      """

      def __init__(self, coro, on_task_finish=None):
          self.coro = coro
          # Initialise state *before* the first step. A coroutine that
          # returns without yielding finishes inside _advance(), which reads
          # on_task_finish and writes result.
          self.result = None  # the coroutine's return value
          self.on_task_finish = on_task_finish  # completion callback(result)
          # A freshly created generator must be primed with None.
          self._advance(None)

      def step(self, future):
          """Done-callback: resume the coroutine with ``future.result``."""
          self._advance(future.result)

      def _advance(self, value):
          # Send runs the coroutine until its next yield. The yielded object
          # is the future to wait on.
          try:
              next_future = self.coro.send(value)
          except StopIteration as e:
              # e.value is the coroutine's ``return`` value (None if bare).
              self.result = e.value
              if self.on_task_finish:
                  self.on_task_finish(self.result)
              return
          next_future.add_done_callback(self.step)
  # Non-blocking coroutine helpers.
  def text_input_coro(prompt):
      """Ask the active client for text without blocking the IOLoop.

      Use as ``value = yield from text_input_coro(prompt)`` inside a
      coroutine driven by Task. The steps are:

      1. Register a Future under a fresh msg_id; its callback sets the
         Future's result.
      2. Send a ``text_input`` command carrying that msg_id to
         ``Global.active_ws``.
      3. Suspend on the Future until a reply with the same msg_id arrives.

      :param prompt: text shown to the user.
      :return: the value the Future was resolved with (the reply's data).
      """
      msg_id = Msg.gen_msg_id()
      msg = dict(command="text_input", data=dict(prompt=prompt, msg_id=msg_id))
      f = Future()
      Msg.add_callback(msg_id, f.set_result)
      try:
          Global.active_ws.write_message(json.dumps(msg))
          input_text = yield from f
      finally:
          # Unregister on every exit path, not just the happy one. This covers
          # write_message failing and the coroutine being closed
          # (GeneratorExit) while still waiting, so the class-level registry
          # keeps no stale callbacks.
          Msg.unregister_msg(msg_id)
      return input_text
  def text_print(text, *, ws=None):
      """Push ``text`` to a client as a ``text_print`` command.

      :param text: JSON-serialisable payload, stored under ``data``.
      :param ws: connection to write to. When falsy (the default),
          ``Global.active_ws`` is used instead.
      """
      target = ws if ws else Global.active_ws
      payload = json.dumps({"command": "text_print", "data": text})
      target.write_message(payload)
  # Business logic, written as a coroutine.
  def my_coro():
      """Demo conversation that EchoWebSocket starts for each new client.

      Prints a welcome line, then asks for a name and an age. Each
      ``yield from text_input_coro(...)`` suspends until the client replies,
      so the IOLoop is never blocked. The coroutine returns None.
      """
      text_print("Welcome to ws-repl")
      user_name = yield from text_input_coro('input your name:')
      text_print("go go go %s!" % user_name)
      # The age reply is awaited only to keep the dialogue going; it is unused.
      yield from text_input_coro('input your age:')
      text_print("So young!!")
  class Msg:
      """Class-level registry of reply callbacks, keyed by msg_id.

      All state lives on the class, so it is shared by every connection.
      """

      mid2callback = defaultdict(list)
      # Process-wide sequence number appended to every generated msg_id.
      _seq = 0

      @classmethod
      def gen_msg_id(cls):
          """Return a new msg_id of the form ``"<sid>-<unix seconds>-<seq>"``.

          ``sid`` comes from ``Global.active_ws``. The timestamp alone
          repeats when two prompts are issued within the same second, and
          sids may repeat across connections. The sequence number keeps every
          id unique within the process.
          """
          cls._seq += 1
          return '%s-%s-%s' % (Global.active_ws.sid, int(time.time()), cls._seq)

      @classmethod
      def add_callback(cls, msg_id, callback):
          """Register ``callback(data)`` to run when a reply for ``msg_id`` arrives."""
          cls.mid2callback[msg_id].append(callback)

      @classmethod
      def get_callbacks(cls, msg_id):
          """Return a snapshot list of the callbacks registered for ``msg_id``.

          ``get`` is used so that looking up an unknown id (a stale or forged
          reply) does not insert an empty entry into the defaultdict. A copy
          is returned so callers can iterate while callbacks (un)register.
          """
          return list(cls.mid2callback.get(msg_id, ()))

      @classmethod
      def unregister_msg(cls, msg_id):
          """Forget every callback for ``msg_id``; unknown ids are ignored."""
          cls.mid2callback.pop(msg_id, None)
  class Global:
      """Process-wide shared state for the demo server."""

      # Connection that text_print / text_input_coro write to by default.
      # EchoWebSocket.open() and on_message() reassign it. It starts as None so
      # the attribute exists before the first client connects.
      active_ws: "EchoWebSocket | None" = None
  class EchoWebSocket(tornado.websocket.WebSocketHandler):
      """WebSocket endpoint that runs one ``my_coro`` conversation per client.

      The server sends ``{"command": ..., "data": ...}`` frames. The client
      answers ``text_input`` prompts with ``{"msg_id": ..., "data": ...}``.
      """

      # Process-wide counter that keeps session ids unique (see open()).
      _sid_seq = 0

      def check_origin(self, origin):
          # NOTE(review): this accepts any Origin, including cross-site
          # WebSocket connections. Acceptable for a local demo; restrict it
          # before exposing the server.
          return True

      def get_compression_options(self):
          # Non-None enables compression with default options.
          return {}

      def open(self):
          """Start a new conversation coroutine for this connection."""
          print("WebSocket opened")
          self.set_nodelay(True)
          # Session id, used as the prefix of every msg_id issued for this
          # connection. Seconds alone collide when two clients connect within
          # the same second, so a sequence number is appended.
          EchoWebSocket._sid_seq += 1
          self.sid = '%s.%s' % (int(time.time()), EchoWebSocket._sid_seq)
          self.coro = my_coro()
          # text_print / text_input_coro talk to Global.active_ws.
          Global.active_ws = self
          self.task = Task(self.coro)
          self.task.on_task_finish = self.on_task_finish

      def on_task_finish(self, result):
          """Report the coroutine's return value and close the connection."""
          # Wrap in a 1-tuple so a tuple result is formatted, not unpacked.
          text_print('Task finish, return: %s\nBye, bye!!' % (result,), ws=self)
          self.close()

      def on_message(self, message):
          """Route a client reply ``{"msg_id": ..., "data": ...}`` to its callbacks."""
          print('on_message', message)
          try:
              data = json.loads(message)
              msg_id = data['msg_id']
              payload = data['data']
          except (ValueError, KeyError, TypeError) as exc:
              # Malformed frame: report it and ignore it instead of raising
              # out of the handler.
              print('on_message: ignoring malformed message: %r' % (exc,))
              return
          # Only resume prompts issued for this session. Msg ids start with
          # "<sid>-", so a client cannot answer (and take over) another
          # client's conversation by sending a foreign msg_id.
          if not (isinstance(msg_id, str) and msg_id.startswith('%s-' % self.sid)):
              print('on_message: ignoring foreign msg_id %r' % (msg_id,))
              return
          # The resumed coroutine sends its next prompt via Global.active_ws.
          Global.active_ws = self
          for callback in Msg.get_callbacks(msg_id):
              callback(payload)

      def on_close(self):
          """Stop the conversation and drop its pending reply callbacks."""
          print("WebSocket closed")
          coro = getattr(self, 'coro', None)
          if coro is not None:
              # Closing an already-finished generator is a no-op.
              coro.close()
          sid = getattr(self, 'sid', None)
          if sid is not None:
              # Drop callbacks still registered for this session so the
              # class-level registry does not grow with abandoned prompts.
              prefix = '%s-' % sid
              stale = [mid for mid in Msg.mid2callback if str(mid).startswith(prefix)]
              for mid in stale:
                  Msg.unregister_msg(mid)
  handlers = [(r"/test", EchoWebSocket)]
  # NOTE(review): debug=True enables autoreload and detailed error pages;
  # turn it off outside local development.
  app = tornado.web.Application(handlers=handlers, debug=True)

  if __name__ == "__main__":
      # Serve only when run as a script, so importing this module does not
      # bind port 8080 and block forever in the IOLoop.
      http_server = tornado.httpserver.HTTPServer(app)
      http_server.listen(8080)
      # IOLoop.instance() is a deprecated alias since Tornado 5; current()
      # returns the loop for this thread.
      tornado.ioloop.IOLoop.current().start()