@@ -1,43 +1,33 @@
 #!/usr/bin/env python3
-from typing import List, Tuple
-
-from langchain.chains import ConversationChain
-from langchain.chat_models import ChatOpenAI
+from langchain_openai import ChatOpenAI
 from log_callback_handler import NiceGuiLogElementCallbackHandler
 
-from nicegui import context, ui
+from nicegui import ui
 
 OPENAI_API_KEY = 'not-set'  # TODO: set your OpenAI API key here
 
 
 @ui.page('/')
 def main():
-    llm = ConversationChain(llm=ChatOpenAI(model_name='gpt-3.5-turbo', openai_api_key=OPENAI_API_KEY))
-
-    messages: List[Tuple[str, str]] = []
-    thinking: bool = False
-
-    @ui.refreshable
-    def chat_messages() -> None:
-        for name, text in messages:
-            ui.chat_message(text=text, name=name, sent=name == 'You')
-        if thinking:
-            ui.spinner(size='3rem').classes('self-center')
-        if context.get_client().has_socket_connection:
-            ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')
+    llm = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, openai_api_key=OPENAI_API_KEY)
 
     async def send() -> None:
-        nonlocal thinking
-        message = text.value
-        messages.append(('You', text.value))
-        thinking = True
+        question = text.value
         text.value = ''
-        chat_messages.refresh()
 
-        response = await llm.arun(message, callbacks=[NiceGuiLogElementCallbackHandler(log)])
-        messages.append(('Bot', response))
-        thinking = False
-        chat_messages.refresh()
+        with message_container:
+            ui.chat_message(text=question, name='You', sent=True)
+            response_message = ui.chat_message(name='Bot', sent=False)
+            spinner = ui.spinner(type='dots')
+
+        response = ''
+        async for chunk in llm.astream(question, config={'callbacks': [NiceGuiLogElementCallbackHandler(log)]}):
+            response += chunk.content
+            response_message.clear()
+            with response_message:
+                ui.html(response)
+            ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')
+        message_container.remove(spinner)
 
     ui.add_css(r'a:link, a:visited {color: inherit !important; text-decoration: none; font-weight: 500}')
 
@@ -49,8 +39,7 @@ def main():
         chat_tab = ui.tab('Chat')
         logs_tab = ui.tab('Logs')
     with ui.tab_panels(tabs, value=chat_tab).classes('w-full max-w-2xl mx-auto flex-grow items-stretch'):
-        with ui.tab_panel(chat_tab).classes('items-stretch'):
-            chat_messages()
+        message_container = ui.tab_panel(chat_tab).classes('items-stretch')
         with ui.tab_panel(logs_tab):
             log = ui.log().classes('w-full h-full')
 
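The substance of this change is replacing the blocking `ConversationChain.arun(...)` call and the `@ui.refreshable` message list with token-by-token streaming: `ChatOpenAI.astream(...)` yields chunks as the model produces them, and each chunk is rendered straight into the Bot's `ui.chat_message` element. Below is a minimal sketch of that streaming loop in isolation, assuming the `langchain-openai` package and a valid API key; printing to stdout stands in for the NiceGUI updates shown in the diff.

```python
# Sketch only, not part of the diff: the astream pattern used above.
import asyncio

from langchain_openai import ChatOpenAI


async def demo() -> None:
    # Assumes a real API key; 'not-set' mirrors the placeholder in the example.
    llm = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, openai_api_key='not-set')
    response = ''
    # astream yields message chunks as they arrive, so the caller can update
    # its output incrementally instead of waiting for the full reply.
    async for chunk in llm.astream('What is NiceGUI?'):
        response += chunk.content
        print(chunk.content, end='', flush=True)
    print()


if __name__ == '__main__':
    asyncio.run(demo())
```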