Forráskód Böngészése

fix mult thread write

AE86 5 éve
szülő
commit
47980ca1b0

+ 6 - 5
dbsyncer-common/src/main/java/org/dbsyncer/common/task/Result.java

@@ -10,7 +10,7 @@ public class Result {
 
     private AtomicLong fail;
 
-    private String error;
+    private StringBuffer error;
 
     public Result() {
         init();
@@ -21,13 +21,14 @@ public class Result {
         init();
     }
 
-    public Result(String error) {
+    public Result(StringBuffer error) {
+        this.fail = new AtomicLong(0);
         this.error = error;
-        init();
     }
 
     private void init(){
         this.fail = new AtomicLong(0);
+        this.error = new StringBuffer();
     }
 
     public List<Map<String, Object>> getData() {
@@ -46,11 +47,11 @@ public class Result {
         this.fail = fail;
     }
 
-    public String getError() {
+    public StringBuffer getError() {
         return error;
     }
 
-    public void setError(String error) {
+    public void setError(StringBuffer error) {
         this.error = error;
     }
 }

+ 2 - 2
dbsyncer-connector/src/main/java/org/dbsyncer/connector/database/AbstractDatabaseConnector.java

@@ -158,7 +158,7 @@ public abstract class AbstractDatabaseConnector implements Database {
         }
         if (CollectionUtils.isEmpty(data)) {
             logger.error("writer data can not be empty.");
-            return new Result("writer data can not be empty.");
+            return new Result(new StringBuffer("writer data can not be empty."));
         }
         final int size = data.size();
         final int fSize = fields.size();
@@ -193,7 +193,7 @@ public abstract class AbstractDatabaseConnector implements Database {
             }
         } catch (Exception e) {
             logger.error(e.getMessage());
-            result.setError(e.getMessage());
+            result.getError().append(e.getMessage());
             result.getFail().set(size);
         } finally {
             // 释放连接

+ 66 - 20
dbsyncer-parser/src/main/java/org/dbsyncer/parser/ParserFactory.java

@@ -13,7 +13,6 @@ import org.dbsyncer.connector.config.*;
 import org.dbsyncer.connector.enums.ConnectorEnum;
 import org.dbsyncer.connector.enums.FilterEnum;
 import org.dbsyncer.connector.enums.OperationEnum;
-import org.dbsyncer.parser.model.Convert;
 import org.dbsyncer.parser.enums.ConvertEnum;
 import org.dbsyncer.parser.enums.ParserEnum;
 import org.dbsyncer.parser.model.*;
@@ -33,11 +32,11 @@ import org.springframework.stereotype.Component;
 import org.springframework.util.Assert;
 
 import java.time.LocalDateTime;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.*;
+import java.util.*;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
 
 /**
  * @author AE86
@@ -275,31 +274,77 @@ public class ParserFactory implements Parser {
     /**
      * 批量写入
      *
-     * @param tConfig
+     * @param config
      * @param command
-     * @param targetFields
+     * @param fields
      * @param target
      * @param threadSize
      * @param batchSize
      * @return
      */
-    private Result executeBatch(ConnectorConfig tConfig, Map<String, String> command, List<Field> targetFields, List<Map<String, Object>> target, int threadSize, int batchSize) {
-        // TODO 拆分任务
+    private Result executeBatch(ConnectorConfig config, Map<String, String> command, List<Field> fields, List<Map<String, Object>> target, int threadSize, int batchSize) {
         // 总数
         int total = target.size();
-        int taskSize = 1;
+        // 单次任务
+        if (total <= batchSize) {
+            return connectorFactory.writer(config, command, fields, target);
+        }
 
+        // 批量任务, 拆分
+        int taskSize = total % batchSize == 0 ? total / batchSize : total / batchSize + 1;
+        threadSize = taskSize <= threadSize ? taskSize : threadSize;
 
-        // 单次任务
-        if(taskSize <= 1){
-            return connectorFactory.writer(tConfig, command, targetFields, target);
+        // 转换为消息队列,根据batchSize获取数据,并发写入
+        Queue<Map<String, Object>> queue = new ConcurrentLinkedQueue<>(target);
+
+        // 创建线程池
+        final ThreadPoolTaskExecutor executor = getThreadPoolTaskExecutor(threadSize);
+        final Result result = new Result();
+        for (; ; ) {
+            if (taskSize <= 0) {
+                break;
+            }
+            final CountDownLatch latch = new CountDownLatch(threadSize);
+            for (int i = 0; i < threadSize; i++) {
+                executor.execute(() -> {
+                    try {
+                        Result w = parallelTask(batchSize, queue, config, command, fields);
+                        // CAS
+                        result.getFail().getAndAdd(w.getFail().get());
+                        result.getError().append(w.getError()).append("\r\n");
+                    } catch (Exception e) {
+                        result.getError().append(e.getMessage()).append("\r\n");
+                        logger.error(e.getMessage());
+                    } finally {
+                        latch.countDown();
+                    }
+                });
+            }
+            try {
+                latch.await();
+            } catch (InterruptedException e) {
+                logger.error(e.getMessage());
+            }
+
+            taskSize -= threadSize;
         }
+        executor.shutdown();
+        return result;
+    }
 
-        // TODO 批量任务走线程池
-        return connectorFactory.writer(tConfig, command, targetFields, target);
+    private Result parallelTask(int batchSize, Queue<Map<String, Object>> queue, ConnectorConfig config, Map<String, String> command, List<Field> fields) {
+        List<Map<String, Object>> data = new ArrayList<>();
+        for (int j = 0; j < batchSize; j++) {
+            Map<String, Object> poll = queue.poll();
+            if (null == poll) {
+                break;
+            }
+            data.add(poll);
+        }
+        return connectorFactory.writer(config, command, fields, data);
     }
 
-    private ThreadPoolTaskExecutor getThreadPoolTaskExecutor(int threadSize){
+    private ThreadPoolTaskExecutor getThreadPoolTaskExecutor(int threadSize) {
         ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
         executor.setCorePoolSize(threadSize);
         executor.setMaxPoolSize(threadSize * 2);
@@ -309,6 +354,7 @@ public class ParserFactory implements Parser {
         executor.setThreadNamePrefix("ParserExecutor");
         executor.setWaitForTasksToCompleteOnShutdown(true);
         executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
+        executor.initialize();
         return executor;
     }
 
@@ -317,7 +363,7 @@ public class ParserFactory implements Parser {
 
         ParserFactory factory = new ParserFactory();
 
-        ThreadPoolTaskExecutor executor = factory.getThreadPoolTaskExecutor(threadSize);
+        final ThreadPoolTaskExecutor executor = factory.getThreadPoolTaskExecutor(threadSize);
         CountDownLatch latch = new CountDownLatch(threadSize);
         for (int i = 0; i < threadSize; i++) {
             executor.execute(() -> {
@@ -331,9 +377,9 @@ public class ParserFactory implements Parser {
                 }
             });
         }
-        executor.shutdown();
         try {
-            latch.wait();
+            latch.await();
+            executor.shutdown();
         } catch (InterruptedException e) {
             e.printStackTrace();
         }