Selaa lähdekoodia

feat: add textract driver to puterai module

KernelDeimos 9 kuukautta sitten
vanhempi
säilyke
f924d48b02

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 3427 - 1273
package-lock.json


+ 2 - 0
src/backend/exports.js

@@ -26,6 +26,7 @@ const { testlaunch } = require("./src/index.js");
 const BaseService = require("./src/services/BaseService.js");
 const { Context } = require("./src/util/context.js");
 const { TestDriversModule } = require("./src/modules/test-drivers/TestDriversModule.js");
+const { PuterAIModule } = require("./src/modules/puterai/PuterAIModule.js");
 
 
 module.exports = {
@@ -48,4 +49,5 @@ module.exports = {
     LocalDiskStorageModule,
     SelfHostedModule,
     TestDriversModule,
+    PuterAIModule,
 };

+ 1 - 0
src/backend/package.json

@@ -7,6 +7,7 @@
     "test": "npx mocha"
   },
   "dependencies": {
+    "@aws-sdk/client-textract": "^3.621.0",
     "@heyputer/kv.js": "^0.1.3",
     "@heyputer/multest": "^0.0.2",
     "@heyputer/puter-js-common": "^1.0.0",

+ 32 - 0
src/backend/src/modules/puterai/AIInterfaceService.js

@@ -0,0 +1,32 @@
+const BaseService = require("../../services/BaseService");
+
+class AIInterfaceService extends BaseService {
+    async ['__on_driver.register.interfaces'] () {
+        const svc_registry = this.services.get('registry');
+        const col_interfaces = svc_registry.get('interfaces');
+        
+        col_interfaces.set('puter-ocr', {
+            description: 'Optical character recognition',
+            methods: {
+                recognize: {
+                    description: 'Recognize text in an image or document.',
+                    parameters: {
+                        source: {
+                            type: 'file',
+                        },
+                    },
+                    result: {
+                        type: {
+                            $: 'stream',
+                            content_type: 'image',
+                        }
+                    },
+                },
+            }
+        });
+    }
+}
+
+module.exports = {
+    AIInterfaceService
+};

+ 137 - 0
src/backend/src/modules/puterai/AWSTextractService.js

@@ -0,0 +1,137 @@
+const { TextractClient, AnalyzeDocumentCommand, InvalidS3ObjectException } = require("@aws-sdk/client-textract");
+
+const BaseService = require("../../services/BaseService");
+
+class AWSTextractService extends BaseService {
+    _construct () {
+        this.clients_ = {};
+    }
+
+    static IMPLEMENTS = {
+        ['puter-ocr']: {
+            async recognize ({ source, test_mode }) {
+                const resp = await this.analyze_document(source);
+
+                // Simplify the response for common interface
+                const puter_response = {
+                    blocks: []
+                };
+    
+                for ( const block of resp.Blocks ) {
+                    if ( block.BlockType === 'PAGE' ) continue;
+                    if ( block.BlockType === 'CELL' ) continue;
+                    if ( block.BlockType === 'TABLE' ) continue;
+                    if ( block.BlockType === 'MERGED_CELL' ) continue;
+                    if ( block.BlockType === 'LAYOUT_FIGURE' ) continue;
+                    if ( block.BlockType === 'LAYOUT_TEXT' ) continue;
+    
+                    const puter_block = {
+                        type: `text/textract:${block.BlockType}`,
+                        confidence: block.Confidence,
+                        text: block.Text,
+                    };
+                    puter_response.blocks.push(puter_block);
+                }
+    
+                return puter_response;
+            }
+        },
+    };
+
+    _create_aws_credentials () {
+        return {
+            accessKeyId: this.config.aws.access_key,
+            secretAccessKey: this.config.aws.secret_key,
+        };
+    }
+
+    _get_client (region) {
+        if ( ! region ) {
+            region = this.config.aws?.region ?? this.global_config.aws?.region
+                ?? 'us-west-2';
+        }
+        if ( this.clients_[region] ) return this.clients_[region];
+
+        this.clients_[region] = new TextractClient({
+            credentials: this._create_aws_credentials(),
+            region,
+        });
+
+        return this.clients_[region];
+    }
+
+    async analyze_document (file_facade) {
+        const {
+            client, document, using_s3
+        } = await this._get_client_and_document(file_facade);
+
+        const command = new AnalyzeDocumentCommand({
+            Document: document,
+            FeatureTypes: [
+                // 'TABLES',
+                // 'FORMS',
+                // 'SIGNATURES',
+                'LAYOUT'
+            ],
+        });
+
+        try {
+            return await client.send(command);
+        } catch (e) {
+            if ( using_s3 && e instanceof InvalidS3ObjectException ) {
+                const { client, document } =
+                    await this._get_client_and_document(file_facade, true);
+                const command = new AnalyzeDocumentCommand({
+                    Document: document,
+                    FeatureTypes: [
+                        'LAYOUT',
+                    ],
+                })
+                return await client.send(command);
+            }
+
+            throw e;
+        }
+
+        throw new Error('expected to be unreachable');
+    }
+
+    async _get_client_and_document (file_facade, force_buffer) {
+        const try_s3info = await file_facade.get('s3-info');
+        if ( try_s3info && ! force_buffer ) {
+            console.log('S3 INFO', try_s3info)
+            return {
+                using_s3: true,
+                client: this._get_client(try_s3info.bucket_region),
+                document: {
+                    S3Object: {
+                        Bucket: try_s3info.bucket,
+                        Name: try_s3info.key,
+                    },
+                },
+            };
+        }
+
+        const try_buffer = await file_facade.get('buffer');
+        if ( try_buffer ) {
+            const base64 = try_buffer.toString('base64');
+            return {
+                client: this._get_client(),
+                document: {
+                    Bytes: try_buffer,
+                },
+            };
+        }
+
+        const fsNode = await file_facade.get('fs-node');
+        if ( fsNode && ! await fsNode.exists() ) {
+            throw APIError.create('subject_does_not_exist');
+        }
+
+        throw new Error('No suitable input for Textract');
+    }
+}
+
+module.exports = {
+    AWSTextractService,
+};

+ 17 - 0
src/backend/src/modules/puterai/PuterAIModule.js

@@ -0,0 +1,17 @@
+const { AdvancedBase } = require("@heyputer/puter-js-common");
+
+class PuterAIModule extends AdvancedBase {
+    async install (context) {
+        const services = context.get('services');
+
+        const { AIInterfaceService } = require('./AIInterfaceService');
+        services.registerService('__ai-interfaces', AIInterfaceService);
+
+        const { AWSTextractService } = require('./AWSTextractService');
+        services.registerService('aws-textract', AWSTextractService);
+    }
+}
+
+module.exports = {
+    PuterAIModule,
+};

+ 17 - 0
src/backend/src/modules/puterai/doc/requests.md

@@ -0,0 +1,17 @@
+```javascript
+await (await fetch("http://api.puter.localhost:4100/drivers/call", {
+    "headers": {
+        "Content-Type": "application/json",
+        "Authorization": `Bearer ${puter.authToken}`,
+    },
+    "body": JSON.stringify({
+        interface: 'puter-ocr',
+        driver: 'aws-textract',
+        method: 'recognize',
+        args: {
+            source: '~/Desktop/testocr.png',
+        },
+    }),
+    "method": "POST",
+})).json();
+```

+ 2 - 0
tools/run-selfhosted.js

@@ -85,6 +85,7 @@ const main = async () => {
         LocalDiskStorageModule,
         SelfHostedModule,
         TestDriversModule,
+        PuterAIModule,
     } = (await import('@heyputer/backend')).default;
 
     const k = new Kernel({
@@ -95,6 +96,7 @@ const main = async () => {
     k.add_module(new LocalDiskStorageModule());
     k.add_module(new SelfHostedModule());
     k.add_module(new TestDriversModule());
+    k.add_module(new PuterAIModule());
     k.boot();
 };
 

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä