
update chatgpt demo

wangweimin 1 year ago
parent
commit
51ece4a846
1 file changed with 22 additions and 19 deletions
  1. demos/chatgpt.py +22 -19

+ 22 - 19  demos/chatgpt.py

@@ -2,7 +2,8 @@ import json
 import time
 from typing import Dict, List
 
-import openai
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletionChunk
 
 import pywebio_battery
 from pywebio.input import *
@@ -12,22 +13,21 @@ from pywebio.session import set_env, download
 
 
 class ChatGPTStreamResponse:
-    def __init__(self, response):
+    """
+    A wrapper around Stream[ChatCompletionChunk] that adds a `result()` method to get the final result.
+    """
+    def __init__(self, response: Stream[ChatCompletionChunk]):
         self.response = response
         self.yielded = []
         self.finish_reason = None
 
     def __next__(self):
-        # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb
         chunk = next(self.response)
-        self.finish_reason = chunk['choices'][0]['finish_reason']
-
-        # { "role": "assistant" } or { "content": "..."} or {}
-        delta = chunk['choices'][0]['delta']
-        content = delta.get('content', '')
-        if content:
-            self.yielded.append(content)
-        return content
+        self.finish_reason = chunk.choices[0].finish_reason
+        delta = chunk.choices[0].delta
+        if delta.content:
+            self.yielded.append(delta.content)
+        return delta.content or ''  # content can be None (e.g. role-only chunk); keep returning str
 
     def __iter__(self):
         return self
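Note: under openai>=1.0 the streamed chunks are typed ChatCompletionChunk objects rather than dicts, which is why the delta is now read through attribute access instead of `chunk['choices'][0]['delta']`. A minimal sketch of consuming such a stream directly, outside the wrapper (model name and prompt are placeholders; assumes OPENAI_API_KEY is set in the environment):

    from openai import OpenAI

    client = OpenAI()  # picks up OPENAI_API_KEY from the environment
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    parts = []
    for chunk in stream:
        choice = chunk.choices[0]
        if choice.delta.content:              # None on the role-only first chunk
            parts.append(choice.delta.content)
        finish_reason = choice.finish_reason  # e.g. "stop" on the last chunk
    print("".join(parts))

This mirrors what ChatGPTStreamResponse.__next__ does, minus the bookkeeping in self.yielded.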
@@ -38,7 +38,7 @@ class ChatGPTStreamResponse:
 
 class ChatGPT:
 
-    def __init__(self, messages: List[Dict] = None, model: str = "gpt-3.5-turbo", api_key=None, **model_kwargs):
+    def __init__(self, messages: List[Dict] = None, model: str = "gpt-3.5-turbo", client: OpenAI = None, **model_kwargs):
         """
         Create a chatgpt client
 
@@ -46,27 +46,28 @@ class ChatGPT:
           Each message is a dict with keys "role" and "content".
           See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages
         :param model: The model to use.
-        :param api_key: The openai api key.
-          Get your API key from https://platform.openai.com/account/api-keys
+        :param OpenAI client: The openai client to use. If not provided, a new client will be created.
         :param model_kwargs: Other parameters to pass to model,
           See https://platform.openai.com/docs/api-reference/chat
         """
+        self._client = client or OpenAI()
         self._messages = list(messages or [])
         self.model_kwargs = dict(model=model, **model_kwargs)
-        if api_key:
-            self.model_kwargs['api_key'] = api_key
 
         self.pending_stream_reply: ChatGPTStreamResponse = None
         self.latest_nonstream_finish_reason = None
 
+    def set_model(self, model: str):
+        """Set the model to use"""
+        self.model_kwargs['model'] = model
+
     def _ask(self, message: str, stream=True, **model_kwargs):
         if self.pending_stream_reply:
             self._messages.append({"role": "assistant", "content": self.pending_stream_reply.result()})
             self.pending_stream_reply = None
 
         self._messages.append({"role": "user", "content": message})
-
-        resp = openai.ChatCompletion.create(
+        resp = self._client.chat.completions.create(
             **self.model_kwargs,
             **model_kwargs,
             messages=self._messages,
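Note on the non-streaming path (the class also tracks latest_nonstream_finish_reason): without `stream=True`, the new client returns a typed ChatCompletion instead of a dict, so the reply is read as attributes as well. A hedged sketch with placeholder model and prompt:

    from openai import OpenAI

    client = OpenAI()
    resp = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    reply_text = resp.choices[0].message.content   # was resp['choices'][0]['message']['content']
    finish_reason = resp.choices[0].finish_reason  # e.g. "stop" or "length"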
@@ -158,8 +159,10 @@ def main():
     put_select('model', ['gpt-3.5-turbo', 'gpt-4'], label='Model')
 
     openai_config = get_openai_config()
+    client = OpenAI(api_key=openai_config['api_key'], base_url=openai_config['api_base'])
 
-    bot = ChatGPT(api_key=openai_config['api_key'], api_base=openai_config['api_base'], model=pin.model)
+    bot = ChatGPT(client=client, model=pin.model)
+    pin_on_change('model', lambda v: bot.set_model(v))
     while True:
         form = input_group('', [
             input(name='msg', placeholder='Ask ChatGPT'),
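With the API key and base URL now living on the client (openai.api_base in the old SDK corresponds to base_url on OpenAI()), pointing the demo at a different OpenAI-compatible endpoint only requires building the client up front, as main() now does. A minimal sketch of the same wiring outside PyWebIO, with placeholder key and URL:

    from openai import OpenAI

    client = OpenAI(
        api_key="sk-...",                      # placeholder; read from config/env in practice
        base_url="https://api.openai.com/v1",  # default; swap for a compatible endpoint
    )
    bot = ChatGPT(client=client, model="gpt-3.5-turbo")
    bot.set_model("gpt-4")                     # same switch that pin_on_change triggers above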