浏览代码

feat: add image generation driver to puterai module

KernelDeimos 10 月之前
父节点
当前提交
fb26fdbc56

+ 32 - 0
src/backend/src/modules/puterai/AIInterfaceService.js

@@ -38,6 +38,38 @@ class AIInterfaceService extends BaseService {
                 }
             }
         });
+
+        col_interfaces.set('puter-image-generation', {
+            description: 'AI Image Generation.',
+            methods: {
+                generate: {
+                    description: 'Generate an image from a prompt.',
+                    parameters: {
+                        prompt: { type: 'string' },
+                    },
+                    result_choices: [
+                        {
+                            names: ['image'],
+                            type: {
+                                $: 'stream',
+                                content_type: 'image',
+                            }
+                        },
+                        {
+                            names: ['url'],
+                            type: {
+                                $: 'string:url:web',
+                                content_type: 'image',
+                            }
+                        },
+                    ],
+                    result: {
+                        description: 'URL of the generated image.',
+                        type: 'string'
+                    }
+                }
+            }
+        });
     }
 }
 

+ 100 - 0
src/backend/src/modules/puterai/OpenAIImageGenerationService.js

@@ -0,0 +1,100 @@
+const BaseService = require("../../services/BaseService");
+const { TypedValue } = require("../../services/drivers/meta/Runtime");
+const { Context } = require("../../util/context");
+
+class OpenAIImageGenerationService extends BaseService {
+    static MODULES = {
+        openai: require('openai'),
+    }
+    async _init () {
+        const sk_key =
+            this.config?.openai?.secret_key ??
+            this.global_config.openai?.secret_key;
+
+        this.openai = new this.modules.openai.OpenAI({
+            apiKey: sk_key
+        });
+    }
+
+    static IMPLEMENTS = {
+        ['puter-image-generation']: {
+            async generate ({ prompt, test_mode }) {
+                const url = await this.generate(prompt, {
+                    ratio: this.constructor.RATIO_SQUARE,
+                });
+
+                if ( test_mode ) {
+                    return new TypedValue({
+                        $: 'string:url:web',
+                        content_type: 'image',
+                    }, 'https://puter-sample-data.puter.site/image_example.png');
+                }
+
+                const image = new TypedValue({
+                    $: 'string:url:web',
+                    content_type: 'image'
+                }, url);
+
+                return image;
+            }
+        }
+    };
+
+    static RATIO_SQUARE = { w: 1024, h: 1024 };
+    static RATIO_PORTRAIT = { w: 1024, h: 1792 };
+    static RATIO_LANDSCAPE = { w: 1792, h: 1024 };
+
+    async generate (prompt, {
+        ratio,
+        model,
+    }) {
+        if ( typeof prompt !== 'string' ) {
+            throw new Error('`prompt` must be a string');
+        }
+
+        if ( ! ratio || ! this._validate_ratio(ratio) ) {
+            throw new Error('`ratio` must be a valid ratio');
+        }
+
+        model = model ?? 'dall-e-3';
+
+        const user_private_uid = Context.get('actor')?.private_uid ?? 'UNKNOWN';
+        if ( user_private_uid === 'UNKNOWN' ) {
+            this.errors.report('chat-completion-service:unknown-user', {
+                message: 'failed to get a user ID for an OpenAI request',
+                alarm: true,
+                trace: true,
+            });
+        }
+
+        const result =
+            await this.openai.images.generate({
+                user: user_private_uid,
+                prompt,
+                size: `${ratio.w}x${ratio.h}`,
+            });
+
+        const spending_meta = {
+            model,
+            size: `${ratio.w}x${ratio.h}`,
+        };
+
+        const svc_spending = Context.get('services').get('spending');
+        svc_spending.record_spending('openai', 'image-generation', spending_meta);
+
+        const url = result.data?.[0]?.url;
+        return url;
+    }
+
+    _validate_ratio (ratio) {
+        return false
+            || ratio === this.constructor.RATIO_SQUARE
+            || ratio === this.constructor.RATIO_PORTRAIT
+            || ratio === this.constructor.RATIO_LANDSCAPE
+            ;
+    }
+}
+
+module.exports = {
+    OpenAIImageGenerationService,
+};

+ 3 - 0
src/backend/src/modules/puterai/PuterAIModule.js

@@ -12,6 +12,9 @@ class PuterAIModule extends AdvancedBase {
 
         const { OpenAICompletionService } = require('./OpenAICompletionService');
         services.registerService('openai-completion', OpenAICompletionService);
+
+        const { OpenAIImageGenerationService } = require('./OpenAIImageGenerationService');
+        services.registerService('openai-image-generation', OpenAIImageGenerationService);
     }
 }
 

+ 18 - 0
src/backend/src/modules/puterai/doc/requests.md

@@ -41,4 +41,22 @@ await (await fetch("http://api.puter.localhost:4100/drivers/call", {
     }),
     "method": "POST",
 })).json();
+```
+
+```javascript
+URL.createObjectURL(await (await fetch("http://api.puter.localhost:4100/drivers/call", {
+  "headers": {
+    "Content-Type": "application/json",
+    "Authorization": `Bearer ${puter.authToken}`,
+  },
+  "body": JSON.stringify({
+      interface: 'puter-image-generation',
+      driver: 'openai-image-generation',
+      method: 'generate',
+      args: {
+        prompt: 'photorealistic teapot made of swiss cheese',
+      }
+  }),
+  "method": "POST",
+})).blob());
 ```