# chatgpt.py
  1. import json
  2. import time
  3. from typing import Dict, List
  4. from openai import OpenAI, Stream
  5. from openai.types.chat import ChatCompletionChunk
  6. import pywebio_battery
  7. from pywebio.input import *
  8. from pywebio.output import *
  9. from pywebio.pin import *
  10. from pywebio.session import set_env, download
  11. class ChatGPTStreamResponse:
  12. """
  13. A wrapper to Stream[ChatCompletionChunk], add a `result()` method to get the final result.
  14. """
  15. def __init__(self, response: Stream[ChatCompletionChunk]):
  16. self.response = response
  17. self.yielded = []
  18. self.finish_reason = None
  19. def __next__(self):
  20. chunk = next(self.response)
  21. self.finish_reason = chunk.choices[0].finish_reason
  22. delta = chunk.choices[0].delta
  23. if delta.content:
  24. self.yielded.append(delta.content)
  25. return delta.content
  26. def __iter__(self):
  27. return self
  28. def result(self):
  29. return ''.join(self.yielded)
  30. class ChatGPT:
  31. def __init__(self, messages: List[Dict] = None, model: str = "gpt-3.5-turbo", client: OpenAI = None, **model_kwargs):
  32. """
  33. Create a chatgpt client
  34. :param messages: A list of messages comprising the conversation so far.
  35. Each message is a dict with keys "role" and "content".
  36. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages
  37. :param model: The model to use.
  38. :param OpenAI client: The openai client to use. If not provided, a new client will be created.
  39. :param model_kwargs: Other parameters to pass to model,
  40. See https://platform.openai.com/docs/api-reference/chat
  41. """
  42. self._client = client or OpenAI()
  43. self._messages = list(messages or [])
  44. self.model_kwargs = dict(model=model, **model_kwargs)
  45. self.pending_stream_reply: ChatGPTStreamResponse = None
  46. self.latest_nonstream_finish_reason = None
  47. def set_model(self, model: str):
  48. """Set the model to use"""
  49. self.model_kwargs['model'] = model
  50. def _ask(self, message: str, stream=True, **model_kwargs):
  51. if self.pending_stream_reply:
  52. self._messages.append({"role": "assistant", "content": self.pending_stream_reply.result()})
  53. self.pending_stream_reply = None
  54. self._messages.append({"role": "user", "content": message})
  55. resp = self._client.chat.completions.create(
  56. **self.model_kwargs,
  57. **model_kwargs,
  58. messages=self._messages,
  59. stream=stream,
  60. )
  61. return resp
  62. def ask(self, message: str, **model_kwargs) -> str:
  63. """
  64. Send a message to chatgpt and get the reply in string
  65. :param message: The message to send
  66. :param model_kwargs: Other parameters to pass to openai.ChatCompletion.create()
  67. :return: The reply from chatgpt
  68. """
  69. resp = self._ask(message, stream=False, **model_kwargs)
  70. reply = resp['choices'][0]
  71. reply_content = reply['message']['content']
  72. self._messages.append({"role": "assistant", "content": reply_content})
  73. self.latest_nonstream_finish_reason = reply['finish_reason']
  74. return reply_content
  75. def ask_stream(self, message: str, **model_kwargs) -> ChatGPTStreamResponse:
  76. """
  77. Send a message to chatgpt and get the reply in stream
  78. :param message: The message to send
  79. :param model_kwargs: Other parameters to pass to openai.ChatCompletion.create()
  80. :return: A iterator that yields the reply from chatgpt.
  81. The iterator will be exhausted when the reply is complete.
  82. """
  83. resp = self._ask(message, stream=True, **model_kwargs)
  84. self.pending_stream_reply = ChatGPTStreamResponse(resp)
  85. return self.pending_stream_reply
  86. def latest_finish_reason(self) -> str:
  87. """The finish reason for the latest reply of chatgpt.
  88. The possible values for finish_reason are:
  89. 'stop': API returned complete model output
  90. 'length': Incomplete model output due to max_tokens parameter or token limit
  91. 'content_filter': Omitted content due to a flag from our content filters
  92. 'null': API response still in progress or incomplete
  93. See: https://platform.openai.com/docs/guides/chat/response-format
  94. """
  95. if self.pending_stream_reply:
  96. return self.pending_stream_reply.finish_reason
  97. return self.latest_nonstream_finish_reason
  98. def messages(self) -> List[Dict]:
  99. """Get all messages of the conversation """
  100. if self.pending_stream_reply:
  101. self._messages.append({"role": "assistant", "content": self.pending_stream_reply.result()})
  102. self.pending_stream_reply = None
  103. return self._messages
  104. def get_openai_config():
  105. openai_config = json.loads(pywebio_battery.get_localstorage('openai_config') or '{}')
  106. if not openai_config:
  107. openai_config = input_group('OpenAI API Config', [
  108. input('API Key', name='api_key', type=TEXT, required=True,
  109. help_text='Get your API key from https://platform.openai.com/account/api-keys'),
  110. input('API Server', name='api_base', type=TEXT, value='https://api.openai.com', required=True),
  111. ])
  112. openai_config['api_base'] = openai_config['api_base'].removesuffix('/v1').strip('/') + '/v1'
  113. pywebio_battery.set_localstorage('openai_config', json.dumps(openai_config))
  114. put_button('Reset OpenAI API Key', reset_openai_config, link_style=True)
  115. return openai_config
  116. def reset_openai_config():
  117. pywebio_battery.set_localstorage('openai_config', json.dumps(None))
  118. toast("Please refresh the page to take effect")
  119. def main():
  120. """"""
  121. set_env(input_panel_fixed=False, output_animation=False)
  122. put_markdown("""
  123. # ChatGPT
  124. A ChatGPT client implemented with PyWebIO. [Source Code](https://github.com/pywebio/PyWebIO/blob/dev/demos/chatgpt.py)
  125. TIPS: refresh page to open a new chat.
  126. """)
  127. put_select('model', ['gpt-3.5-turbo', 'gpt-4'], label='Model')
  128. openai_config = get_openai_config()
  129. client = OpenAI(api_key=openai_config['api_key'], base_url=openai_config['api_base'])
  130. bot = ChatGPT(client=client, model=pin.model)
  131. pin_on_change('model', lambda v: bot.set_model(v))
  132. while True:
  133. form = input_group('', [
  134. input(name='msg', placeholder='Ask ChatGPT'),
  135. actions(name='cmd', buttons=['Send', 'Multi-line Input', 'Save Chat'])
  136. ])
  137. if form['cmd'] == 'Multi-line Input':
  138. form['msg'] = textarea(value=form['msg'])
  139. elif form['cmd'] == 'Save Chat':
  140. messages = [
  141. msg['content'] if msg['role'] == 'user' else f"> {msg['content']}"
  142. for msg in bot.messages()
  143. ]
  144. download(f"chatgpt_{time.strftime('%Y%m%d%H%M%S')}.md",
  145. '\n\n'.join(messages).encode('utf8'))
  146. continue
  147. user_msg = form['msg']
  148. if not user_msg:
  149. continue
  150. put_info(put_text(user_msg, inline=True))
  151. with use_scope(f'reply-{int(time.time())}'):
  152. put_loading('grow', 'info')
  153. try:
  154. reply_chunks = bot.ask_stream(user_msg)
  155. except Exception as e:
  156. popup('ChatGPT Error', put_error(e))
  157. continue
  158. finally:
  159. clear() # clear loading
  160. for chunk in reply_chunks:
  161. put_text(chunk, inline=True)
  162. clear() # clear above text
  163. put_markdown(reply_chunks.result())
  164. if bot.latest_finish_reason() == 'length':
  165. put_error('Incomplete model output due to max_tokens parameter or token limit.')
  166. elif bot.latest_finish_reason() == 'content_filter':
  167. put_warning("Omitted content due to a flag from OpanAI's content filters.")
if __name__ == '__main__':
    from pywebio import start_server
    # Run the demo as a standalone server on port 8080 (debug on, local
    # static assets instead of CDN).
    start_server(main, port=8080, debug=True, cdn=False)