
Merge branch '1.4' into globals

Falko Schindler · 1 year ago · commit a22d449987
3 changed files with 78 additions and 9 deletions:
  1. examples/chat_with_ai/log_callback_handler.py (+55 -0)
  2. examples/chat_with_ai/main.py (+20 -6)
  3. fly.toml (+3 -3)

examples/chat_with_ai/log_callback_handler.py (+55 -0)

@@ -0,0 +1,55 @@
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish
+from typing import Dict, Any, Optional
+from nicegui.element import Element
+
+
+class NiceGuiLogElementCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that writes to the log element of NicGui."""
+
+    def __init__(self, element: Element) -> None:
+        """Initialize callback handler."""
+        self.element = element
+
+    def print_text(self, message: str) -> None:
+        self.element.push(message)
+        self.element.update()
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        self.print_text(
+            f"\n\n> Entering new {serialized['id'][-1]} chain...",
+        )
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Print out that we finished a chain."""
+        self.print_text("\n> Finished chain.")
+        self.print_text(f"\nOutputs: {outputs}")
+
+    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+        """Run on agent action."""
+        self.print_text(action.log)
+
+    def on_tool_end(
+        self,
+        output: str,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        if observation_prefix is not None:
+            self.print_text(f"\n{observation_prefix}")
+        self.print_text(output)
+        if llm_prefix is not None:
+            self.print_text(f"\n{llm_prefix}")
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run when agent ends."""
+        self.print_text(text)
+
+    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run on agent end."""
+        self.print_text(finish.log)
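
For context, a minimal sketch of how this handler might be wired up on its own (the page and button are illustrative, not part of the commit; it assumes a LangChain version whose `Chain.arun` accepts a `callbacks` argument, and that the element passed in is a `ui.log`, since `push()` is what that element provides):

```python
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from nicegui import ui

from log_callback_handler import NiceGuiLogElementCallbackHandler


@ui.page('/')
def demo_page() -> None:
    log = ui.log()  # the handler push()es every callback message into this element

    async def ask() -> None:
        chain = ConversationChain(llm=ChatOpenAI())  # reads OPENAI_API_KEY from the environment
        # per-call callbacks, exactly as main.py passes them to llm.arun()
        await chain.arun('Hello!', callbacks=[NiceGuiLogElementCallbackHandler(log)])

    ui.button('Ask', on_click=ask)


ui.run()
```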

examples/chat_with_ai/main.py (+20 -6)

@@ -3,7 +3,9 @@ from typing import List, Tuple

 from langchain.chains import ConversationChain
 from langchain.chat_models import ChatOpenAI
+from log_callback_handler import NiceGuiLogElementCallbackHandler

+import nicegui.globals
 from nicegui import Client, ui

 OPENAI_API_KEY = 'not-set'  # TODO: set your OpenAI API key here
@@ -20,11 +22,13 @@ def chat_messages() -> None:
         ui.chat_message(text=text, name=name, sent=name == 'You')
     if thinking:
         ui.spinner(size='3rem').classes('self-center')
-    ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')
+    if nicegui.globals.get_client().has_socket_connection:
+        ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')


 @ui.page('/')
-async def main(client: Client):
+async def main():
+
     async def send() -> None:
         global thinking
         message = text.value
@@ -33,17 +37,27 @@ async def main(client: Client):
         text.value = ''
         chat_messages.refresh()

-        response = await llm.arun(message)
+        response = await llm.arun(message, callbacks=[NiceGuiLogElementCallbackHandler(log)])
         messages.append(('Bot', response))
         thinking = False
         chat_messages.refresh()

     anchor_style = r'a:link, a:visited {color: inherit !important; text-decoration: none; font-weight: 500}'
     ui.add_head_html(f'<style>{anchor_style}</style>')
-    await client.connected()

-    with ui.column().classes('w-full max-w-2xl mx-auto items-stretch'):
-        chat_messages()
+    # the queries below are used to expand the content down to the footer (content can then use flex-grow to expand)
+    ui.query('.q-page').classes('flex')
+    ui.query('.nicegui-content').classes('w-full')
+
+    with ui.tabs().classes('w-full') as tabs:
+        chat_tab = ui.tab('Chat')
+        logs_tab = ui.tab('Logs')
+    with ui.tab_panels(tabs, value=chat_tab).classes('w-full max-w-2xl mx-auto flex-grow items-stretch'):
+        with ui.tab_panel(chat_tab):
+            with ui.column().classes('w-full'):
+                chat_messages()
+        with ui.tab_panel(logs_tab):
+            log = ui.log().classes('w-full h-full')

     with ui.footer().classes('bg-white'), ui.column().classes('w-full max-w-3xl mx-auto my-6'):
         with ui.row().classes('w-full no-wrap items-center'):
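
Note how the first hunk trades the old `await client.connected()` barrier for a per-call guard: `ui.run_javascript` needs a live websocket, so `chat_messages.refresh()` can now run even while the page is still being built. A standalone sketch of that guard follows (the page and labels are illustrative, not part of the commit; it assumes NiceGUI 1.x, where `nicegui.globals.get_client()` returns the client of the current context):

```python
import nicegui.globals
from nicegui import ui


def scroll_to_bottom() -> None:
    # during the initial page build there is no websocket yet, so sending
    # JavaScript would fail; the guard simply skips the call in that case
    if nicegui.globals.get_client().has_socket_connection:
        ui.run_javascript('window.scrollTo(0, document.body.scrollHeight)')


@ui.page('/')
def page() -> None:
    for i in range(100):
        ui.label(f'line {i}')
    scroll_to_bottom()  # no-op here: the client has not connected yet
    ui.button('Scroll down', on_click=scroll_to_bottom)  # works: the socket is live


ui.run()
```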

fly.toml (+3 -3)

@@ -24,9 +24,9 @@ kill_timeout = "5s"
  protocol = "tcp"
  internal_port = 8080
  processes = ["app"]
-  auto_stop_machines = false
+  auto_stop_machines = true
  auto_start_machines = true
-  min_machines_running = 19
+  min_machines_running = 10

  [[services.ports]]
    port = 80
@@ -37,7 +37,7 @@ kill_timeout = "5s"
    port = 443
    handlers = ["tls", "http"]
  [services.concurrency]
-    type = "requests"
+    type = "connections"
    hard_limit = 50
    soft_limit = 20
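
For reference, the affected stanzas of fly.toml after this change would read roughly as follows (reconstructed from the two hunks above; the `[[services]]` header and surrounding keys are assumed, standard fly.toml structure not shown in the diff):

```toml
[[services]]
  protocol = "tcp"
  internal_port = 8080
  processes = ["app"]
  auto_stop_machines = true   # allow Fly to stop idle machines
  auto_start_machines = true
  min_machines_running = 10   # but keep a floor of warm machines

  [services.concurrency]
    type = "connections"  # count concurrent connections rather than HTTP requests
    hard_limit = 50       # (suits NiceGUI's long-lived websocket connections)
    soft_limit = 20
```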