
dev: add stream support to Gemini

KernelDeimos 3 months ago
parent
commit
5169d4bb40

+ 41 - 8
src/backend/src/modules/puterai/GeminiService.js

@@ -1,6 +1,8 @@
 const BaseService = require("../../services/BaseService");
 const { GoogleGenerativeAI } = require('@google/generative-ai');
 const GeminiSquareHole = require("./lib/GeminiSquareHole");
+const { TypedValue } = require("../../services/drivers/meta/Runtime");
+const putility = require("@heyputer/putility");
 
 class GeminiService extends BaseService {
     async _init () {
@@ -38,7 +40,6 @@ class GeminiService extends BaseService {
 
                 // History is separate, so the last message gets special treatment.
                 const last_message = messages.pop();
-                console.log('last message?', last_message)
                 const last_message_parts = last_message.parts.map(
                     part => typeof part === 'string' ? part : part.text
                 );
@@ -47,15 +48,36 @@ class GeminiService extends BaseService {
                     history: messages,
                 });
                 
-                const genResult = await chat.sendMessage(last_message_parts)
+                const usage_calculator = GeminiSquareHole.create_usage_calculator({
+                    model_details: (await this.models_()).find(m => m.id === model),
+                });
+                    
+                if ( stream ) {
+                    const genResult = await chat.sendMessageStream(last_message_parts)
+                    const stream = genResult.stream;
+
+                    const usage_promise = new putility.libs.promise.TeePromise();
+                    return new TypedValue({ $: 'ai-chat-intermediate' }, {
+                        stream: true,
+                        init_chat_stream:
+                            GeminiSquareHole.create_chat_stream_handler({
+                                stream, usage_promise,
+                            }),
+                        usage_promise: usage_promise.then(usageMetadata => {
+                            return usage_calculator({ usageMetadata });
+                        }),
+                    })
+                } else {
+                    const genResult = await chat.sendMessage(last_message_parts)
 
-                debugger;
-                const message = genResult.response.candidates[0];
-                message.content = message.content.parts;
-                message.role = 'assistant';
+                    const message = genResult.response.candidates[0];
+                    message.content = message.content.parts;
+                    message.role = 'assistant';
 
-                const result = { message };
-                return result;
+                    const result = { message };
+                    result.usage = usage_calculator(genResult.response);
+                    return result;
+                }
             }
         }
     }
@@ -73,6 +95,17 @@ class GeminiService extends BaseService {
                     output: 30,
                 },
             },
+            {
+                id: 'gemini-2.0-flash',
+                name: 'Gemini 2.0 Flash',
+                context: 131072,
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 10,
+                    output: 40,
+                },
+            },
         ];
     }
 }
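
For context on the new return shape: when stream is true, the chat code path now returns a TypedValue tagged 'ai-chat-intermediate' whose init_chat_stream callback drives the Gemini SDK stream and whose usage_promise resolves to the calculated token usage; otherwise it returns a plain { message, usage } object as before. Below is a minimal sketch of a consumer of the streaming result. It is not part of this commit: it assumes the driver layer has already unwrapped the TypedValue to its underlying value object, and that chatStream exposes the message()/contentBlock()/end() API expected by the handler.

// Hypothetical consumer of the streaming result; not Puter code.
// `intermediate` is assumed to be the unwrapped value object
// { stream, init_chat_stream, usage_promise } constructed above.
async function consumeIntermediate(intermediate, chatStream) {
    // Pump the Gemini SDK stream into the chat stream.
    await intermediate.init_chat_stream({ chatStream });

    // Resolves after the handler has seen the last chunk's usageMetadata,
    // already mapped through the usage calculator.
    const usage = await intermediate.usage_promise;
    return usage; // [{ type: 'prompt', ... }, { type: 'completion', ... }]
}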

+ 51 - 0
src/backend/src/modules/puterai/lib/GeminiSquareHole.js

@@ -18,4 +18,55 @@ module.exports = class GeminiSquareHole {
 
         return messages;
     }
+
+    static create_usage_calculator = ({ model_details }) => {
+        return ({ usageMetadata }) => {
+            const tokens = [];
+            
+            tokens.push({
+                type: 'prompt',
+                model: model_details.id,
+                amount: usageMetadata.promptTokenCount,
+                cost: model_details.cost.input * usageMetadata.promptTokenCount,
+            });
+
+            tokens.push({
+                type: 'completion',
+                model: model_details.id,
+                amount: usageMetadata.candidatesTokenCount,
+                cost: model_details.cost.output * usageMetadata.candidatesTokenCount,
+            });
+
+            return tokens;
+        };
+    };
+
+    static create_chat_stream_handler = ({
+        stream, // GenerateContentStreamResult:stream
+        usage_promise,
+    }) => async ({ chatStream }) => {
+        const message = chatStream.message();
+        let textblock = message.contentBlock({ type: 'text' });
+        let last_usage = null;
+        for await ( const chunk of stream ) {
+            // This is spread across several lines so that the stack trace
+            // is more helpful if we get an exception because of an
+            // inconsistent response from the model.
+            const candidate = chunk.candidates[0];
+            const content = candidate.content;
+            const parts = content.parts;
+            for ( const part of parts ) {
+                const text = part.text;
+                textblock.addText(text);
+            }
+
+            last_usage = chunk.usageMetadata;
+        }
+
+        usage_promise.resolve(last_usage);
+
+        textblock.end();
+        message.end();
+        chatStream.end();
+    }
 }
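
For context on how these helpers fit together: create_usage_calculator maps the Gemini SDK's usageMetadata onto two cost entries (prompt and completion), and create_chat_stream_handler resolves the usage_promise with the final chunk's metadata once the stream ends. A minimal worked example of the calculator, using the gemini-2.0-flash pricing added above; the require path and the token counts are illustrative.

// Worked example of the usage calculator; values are illustrative.
const GeminiSquareHole = require('./GeminiSquareHole'); // path assumed

const calculate_usage = GeminiSquareHole.create_usage_calculator({
    model_details: {
        id: 'gemini-2.0-flash',
        cost: { currency: 'usd-cents', tokens: 1_000_000, input: 10, output: 40 },
    },
});

const usage = calculate_usage({
    usageMetadata: { promptTokenCount: 2000, candidatesTokenCount: 500 },
});

console.log(usage);
// [
//   { type: 'prompt',     model: 'gemini-2.0-flash', amount: 2000, cost: 20000 },
//   { type: 'completion', model: 'gemini-2.0-flash', amount: 500,  cost: 20000 }
// ]
// cost here is the per-million-token rate times the raw token count, i.e.
// millionths of a cent (20000 -> 0.02 cents); dividing by the 1_000_000
// denominator presumably happens downstream of this calculator.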