
add efficient_chat_with_ai (#2800)

* add efficient_chat_with_ai

* transfer ideas from "Efficient Chat with AI" to "Chat with AI"

* show errors and fix reading chunks

* fix_stream

* code review

* cleanup

---------

Co-authored-by: unknown <chenmeilin@hikvision.com>
Co-authored-by: Falko Schindler <falko@zauberzeug.com>
Co-authored-by: Rodja Trappe <rodja@zauberzeug.com>
Merlin, 1 year ago
parent
commit
f66f995cf7
2 changed files with 19 additions and 30 deletions
1. examples/chat_with_ai/main.py (+18, -29)
2. examples/chat_with_ai/requirements.txt (+1, -1)

examples/chat_with_ai/main.py (+18, -29)

@@ -1,43 +1,33 @@
 #!/usr/bin/env python3
-from typing import List, Tuple
-
-from langchain.chains import ConversationChain
-from langchain.chat_models import ChatOpenAI
+from langchain_openai import ChatOpenAI
 from log_callback_handler import NiceGuiLogElementCallbackHandler
 
-from nicegui import context, ui
+from nicegui import ui
 
 OPENAI_API_KEY = 'not-set'  # TODO: set your OpenAI API key here
 
 
 @ui.page('/')
 def main():
-    llm = ConversationChain(llm=ChatOpenAI(model_name='gpt-3.5-turbo', openai_api_key=OPENAI_API_KEY))
-
-    messages: List[Tuple[str, str]] = []
-    thinking: bool = False
-
-    @ui.refreshable
-    def chat_messages() -> None:
-        for name, text in messages:
-            ui.chat_message(text=text, name=name, sent=name == 'You')
-        if thinking:
-            ui.spinner(size='3rem').classes('self-center')
-        if context.get_client().has_socket_connection:
-            ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')
+    llm = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, openai_api_key=OPENAI_API_KEY)
 
     async def send() -> None:
-        nonlocal thinking
-        message = text.value
-        messages.append(('You', text.value))
-        thinking = True
+        question = text.value
         text.value = ''
-        chat_messages.refresh()
 
-        response = await llm.arun(message, callbacks=[NiceGuiLogElementCallbackHandler(log)])
-        messages.append(('Bot', response))
-        thinking = False
-        chat_messages.refresh()
+        with message_container:
+            ui.chat_message(text=question, name='You', sent=True)
+            response_message = ui.chat_message(name='Bot', sent=False)
+            spinner = ui.spinner(type='dots')
+
+        response = ''
+        async for chunk in llm.astream(question, config={'callbacks': [NiceGuiLogElementCallbackHandler(log)]}):
+            response += chunk.content
+            response_message.clear()
+            with response_message:
+                ui.html(response)
+            ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')
+        message_container.remove(spinner)
 
     ui.add_css(r'a:link, a:visited {color: inherit !important; text-decoration: none; font-weight: 500}')
 
@@ -49,8 +39,7 @@ def main():
         chat_tab = ui.tab('Chat')
         logs_tab = ui.tab('Logs')
     with ui.tab_panels(tabs, value=chat_tab).classes('w-full max-w-2xl mx-auto flex-grow items-stretch'):
-        with ui.tab_panel(chat_tab).classes('items-stretch'):
-            chat_messages()
+        message_container = ui.tab_panel(chat_tab).classes('items-stretch')
         with ui.tab_panel(logs_tab):
             log = ui.log().classes('w-full h-full')
 

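For context, here is a minimal standalone sketch of the streaming pattern this commit introduces: the blocking `ConversationChain.arun` call is replaced by iterating over `ChatOpenAI.astream`, rebuilding the bot's `ui.chat_message` as each chunk arrives. Everything not shown in the diff (the input element, the helper name `send_and_stream`, the omission of the log callback) is illustrative only, not part of the commit:

```python
#!/usr/bin/env python3
from langchain_openai import ChatOpenAI  # provided by the new langchain_openai dependency

from nicegui import ui

OPENAI_API_KEY = 'not-set'  # placeholder, as in the example


@ui.page('/')
def main():
    llm = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, openai_api_key=OPENAI_API_KEY)

    async def send_and_stream() -> None:
        question = text.value
        text.value = ''
        with messages:
            ui.chat_message(text=question, name='You', sent=True)
            answer = ui.chat_message(name='Bot', sent=False)
        response = ''
        async for chunk in llm.astream(question):  # chunks arrive as the model generates tokens
            response += chunk.content
            answer.clear()  # rebuild the bot message with the text received so far
            with answer:
                ui.html(response)

    messages = ui.column().classes('w-full max-w-2xl mx-auto')
    text = ui.input(placeholder='message').classes('w-full').on('keydown.enter', send_and_stream)


ui.run()
```

Clearing the bot message and re-entering its element context on every chunk is the same approach the diff takes inside `message_container`, with the spinner and the `NiceGuiLogElementCallbackHandler` callback added on top.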
examples/chat_with_ai/requirements.txt (+1, -1)

@@ -1,3 +1,3 @@
 langchain>=0.0.142
+langchain_openai
 nicegui
-openai
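The dependency change mirrors the import change in main.py: `ChatOpenAI` now comes from the standalone `langchain_openai` package rather than `langchain.chat_models`, and the explicit `openai` pin is dropped since the client library is pulled in transitively. A quick smoke test after installing the updated requirements (illustrative, not part of the commit):

```python
# verify the new dependency resolves after `pip install -r requirements.txt`
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, openai_api_key='not-set')
print(llm.model_name)   # -> gpt-3.5-turbo
print(llm.streaming)    # -> True
```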