Forráskód Böngészése

feat(puterai): add streaming

KernelDeimos 9 hónapja
szülő
commit
9d5963cdf5

+ 2 - 1
src/backend/src/modules/puterai/AIInterfaceService.js

@@ -33,8 +33,9 @@ class AIInterfaceService extends BaseService {
                     parameters: {
                         messages: { type: 'json' },
                         vision: { type: 'flag' },
+                        stream: { type: 'flag' },
                     },
-                    result: { type: 'json' }
+                    result: { type: 'json' },
                 }
             }
         });

+ 35 - 3
src/backend/src/modules/puterai/OpenAICompletionService.js

@@ -1,7 +1,10 @@
+const { PassThrough } = require('stream');
 const APIError = require('../../api/APIError');
 const BaseService = require('../../services/BaseService');
+const { TypedValue } = require('../../services/drivers/meta/Runtime');
 const { Context } = require('../../util/context');
 const SmolUtil = require('../../util/smolutil');
+const { nou } = require('../../util/langutil');
 
 class OpenAICompletionService extends BaseService {
     static MODULES = {
@@ -20,7 +23,7 @@ class OpenAICompletionService extends BaseService {
 
     static IMPLEMENTS = {
         ['puter-chat-completion']: {
-            async complete ({ messages, test_mode }) {
+            async complete ({ messages, test_mode, stream }) {
                 if ( test_mode ) {
                     const { LoremIpsum } = require('lorem-ipsum');
                     const li = new LoremIpsum({
@@ -50,6 +53,7 @@ class OpenAICompletionService extends BaseService {
                 return await this.complete(messages, {
                     model,
                     moderation: true,
+                    stream,
                 });
             }
         }
@@ -76,7 +80,7 @@ class OpenAICompletionService extends BaseService {
         };
     }
 
-    async complete (messages, { moderation, model }) {
+    async complete (messages, { stream, moderation, model }) {
         // Validate messages
         if ( ! Array.isArray(messages) ) {
             throw new Error('`messages` must be an array');
@@ -199,7 +203,35 @@ class OpenAICompletionService extends BaseService {
             messages: messages,
             model: model,
             max_tokens,
+            stream,
         });
+        
+        if ( stream ) {
+            const entire = [];
+            const stream = new PassThrough();
+            const retval = new TypedValue({
+                $: 'stream',
+                content_type: 'application/x-ndjson',
+                chunked: true,
+            }, stream);
+            (async () => {
+                for await ( const chunk of completion ) {
+                    entire.push(chunk);
+                    if ( chunk.choices.length < 1 ) continue;
+                    if ( chunk.choices[0].finish_reason ) {
+                        stream.end();
+                        break;
+                    }
+                    if ( nou(chunk.choices[0].delta.content) ) continue;
+                    const str = JSON.stringify({
+                        text: chunk.choices[0].delta.content
+                    });
+                    stream.write(str + '\n');
+                }
+            })();
+            return retval;
+        }
+
 
         this.log.info('how many choices?: ' + completion.choices.length);
 
@@ -244,7 +276,7 @@ class OpenAICompletionService extends BaseService {
                 throw new Error('message is not allowed');
             }
         }
-
+        
         return completion.choices[0];
     }
 }

+ 5 - 1
src/backend/src/routers/drivers/call.js

@@ -84,7 +84,7 @@ module.exports = eggspress('/drivers/call', {
     // consider the case where a driver method implements a
     // stream transformation, thus the stream from the request isn't
     // consumed until the response is being sent.
-
+    
     _respond(res, result);
 
     // What we _can_ do is await the request promise while responding
@@ -95,8 +95,12 @@ module.exports = eggspress('/drivers/call', {
 const _respond = (res, result) => {
     if ( result.result instanceof TypedValue ) {
         const tv = result.result;
+        debugger;
         if ( TypeSpec.adapt({ $: 'stream' }).equals(tv.type) ) {
             res.set('Content-Type', tv.type.raw.content_type);
+            if ( tv.type.raw.chunked ) {
+                res.set('Transfer-Encoding', 'chunked');
+            }
             tv.value.pipe(res);
             return;
         }

+ 1 - 1
src/backend/src/services/drivers/CoercionService.js

@@ -88,7 +88,7 @@ class CoercionService extends BaseService {
             return coerced;
         }
 
-        return undefined;
+        return typed_value;
     }
 }