项目开发中,有些数据处理需要用到多线程,比如文件的批量上传、数据库的分批写入、大文件的分段下载等。通常可以直接使用 Spring 自带的线程池;但为了对线程进行定制化处理、获得更好的可控性,建议使用自定义的线程池。主要涉及以下几点:
1. 自定义线程工厂(ThreadFactoryBuilder),主要用于线程的命名,方便追踪
2. 自定义的线程池(ThreadPoolExecutorUtils),可以按功能优化配置参数
3. 一个抽象的多线程任务处理接口(OperationThreadService)和通用实现(OperationThread)
4. 统一的调度实现(MultiThreadOperationUtils)
核心思想:分治归并,每个线程计算出自己的结果,最后统一汇总。
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Builder for {@link ThreadFactory} instances that name threads from a format string,
 * set their daemon flag and priority, and log any uncaught exception.
 *
 * <p>{@link #build()} snapshots the builder's current settings, so later changes to the
 * builder do not affect factories that were already built.
 */
public class ThreadFactoryBuilder {
    private static final Logger logger = LoggerFactory.getLogger(ThreadFactoryBuilder.class);

    /** Shared by every built factory: logs the thread name and the exception that killed it. */
    private static final Thread.UncaughtExceptionHandler LOGGING_HANDLER = (t, e) ->
            logger.error("error occurred! threadName: {}, error msg: {}", t.getName(), e.getMessage(), e);

    private String nameFormat = null;
    private boolean daemon = false;
    private int priority = Thread.NORM_PRIORITY;

    /**
     * Sets the {@link String#format} pattern used to name threads. The pattern receives a single
     * {@code long} counter that starts at 0 for each built factory, e.g. {@code "worker-%d"}.
     *
     * @param nameFormat the thread-name pattern
     * @return this builder
     * @throws NullPointerException if {@code nameFormat} is null
     * @throws java.util.IllegalFormatException if the pattern cannot be formatted with one long
     */
    public ThreadFactoryBuilder setNameFormat(String nameFormat) {
        if (nameFormat == null) {
            throw new NullPointerException("nameFormat");
        }
        // Fail fast here, using the same argument type (long) the factory passes later.
        // Otherwise a bad pattern would only blow up inside newThread() on a pool worker.
        String.format(nameFormat, 0L);
        this.nameFormat = nameFormat;
        return this;
    }

    /**
     * Sets whether created threads are daemon threads (default {@code false}).
     *
     * @param daemon the daemon flag applied to every created thread
     * @return this builder
     */
    public ThreadFactoryBuilder setDaemon(boolean daemon) {
        this.daemon = daemon;
        return this;
    }

    /**
     * Sets the priority of created threads (default {@link Thread#NORM_PRIORITY}).
     *
     * @param priority a value in [{@link Thread#MIN_PRIORITY}, {@link Thread#MAX_PRIORITY}]
     * @return this builder
     * @throws IllegalArgumentException if {@code priority} is out of range
     */
    public ThreadFactoryBuilder setPriority(int priority) {
        if (priority < Thread.MIN_PRIORITY) {
            throw new IllegalArgumentException(String.format(
                    "Thread priority (%s) must be >= %s", priority, Thread.MIN_PRIORITY));
        }
        if (priority > Thread.MAX_PRIORITY) {
            throw new IllegalArgumentException(String.format(
                    "Thread priority (%s) must be <= %s", priority, Thread.MAX_PRIORITY));
        }
        this.priority = priority;
        return this;
    }

    /**
     * Returns a new factory configured with this builder's current settings.
     *
     * @return a thread factory; each factory keeps its own naming counter
     */
    public ThreadFactory build() {
        return build(this);
    }

    private static ThreadFactory build(ThreadFactoryBuilder builder) {
        // Copy into locals so the returned factory is unaffected by later builder changes.
        final String nameFormat = builder.nameFormat;
        final boolean daemon = builder.daemon;
        final int priority = builder.priority;
        final AtomicLong count = new AtomicLong(0);
        return (Runnable runnable) -> {
            Thread thread = new Thread(runnable);
            if (nameFormat != null) {
                thread.setName(String.format(nameFormat, count.getAndIncrement()));
            }
            thread.setDaemon(daemon);
            thread.setPriority(priority);
            thread.setUncaughtExceptionHandler(LOGGING_HANDLER);
            return thread;
        };
    }
}
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.*;
/**
 * Factory and shutdown helpers for bounded, named {@link ThreadPoolExecutor}s.
 *
 * <p>Each pool:
 * <ul>
 *   <li>uses an {@link ArrayBlockingQueue} of {@code defaultQueueSize};</li>
 *   <li>grows from {@code defaultCoreSize} to {@code defaultMaxSize} threads, with idle
 *       non-core threads expiring after 60 s;</li>
 *   <li>prestarts all core threads;</li>
 *   <li>logs task failures in {@code afterExecute}.</li>
 * </ul>
 *
 * <p><b>Caution:</b> the {@code getExecutorPool(int, ...)} overloads work by overwriting these
 * class-wide defaults. The new values persist for every later call, including
 * {@link #closeAfterComplete}. The fields are plain statics written without synchronization,
 * so concurrent callers passing different settings can observe each other's values.
 */
public class ThreadPoolExecutorUtils {
    private static final Logger logger = LoggerFactory.getLogger(ThreadPoolExecutorUtils.class);
    /** Core size for new pools (initially the CPU count); also read by callers to size batches. */
    public static int defaultCoreSize = Runtime.getRuntime().availableProcessors();
    /** Seconds {@link #closeAfterComplete} waits for termination before forcing shutdown. */
    private static int pollWaitingTime = 60;
    /** Capacity of the bounded work queue of new pools. */
    private static int defaultQueueSize = 10 * 1000;
    /** Maximum size of new pools; ThreadPoolExecutor rejects values below the core size. */
    private static int defaultMaxSize = 4 * defaultCoreSize;
    /** Thread-name prefix used by {@link #getExecutorPool()}. */
    private static String threadName = "custom-pool";

    /**
     * Overwrites all four shared defaults, then builds a {@code custom-pool} executor.
     * Note that the argument order differs from the three-argument overload.
     */
    public static ThreadPoolExecutor getExecutorPool(int waitingTime, int coreSize, int maxPoolSize, int queueSize) {
        pollWaitingTime = waitingTime;
        defaultCoreSize = coreSize;
        defaultMaxSize = maxPoolSize;
        defaultQueueSize = queueSize;
        return getExecutorPool();
    }

    /** Overwrites the shared wait time, queue size and max size, then builds a pool. */
    public static ThreadPoolExecutor getExecutorPool(int waitingTime, int queueSize, int maxPoolSize) {
        pollWaitingTime = waitingTime;
        defaultQueueSize = queueSize;
        defaultMaxSize = maxPoolSize;
        return getExecutorPool();
    }

    /** Overwrites the shared wait time and queue size, then builds a pool. */
    public static ThreadPoolExecutor getExecutorPool(int waitingTime, int queueSize) {
        pollWaitingTime = waitingTime;
        defaultQueueSize = queueSize;
        return getExecutorPool();
    }

    /** Overwrites the shared wait time used by {@link #closeAfterComplete}, then builds a pool. */
    public static ThreadPoolExecutor getExecutorPool(int waitingTime) {
        pollWaitingTime = waitingTime;
        return getExecutorPool();
    }

    /** Builds a pool with the current defaults and threads named {@code custom-pool-N}. */
    public static ThreadPoolExecutor getExecutorPool() {
        return getExecutorPool(threadName);
    }

    /**
     * Builds a pool with the current defaults whose threads are named {@code <threadName>-N}.
     *
     * @param threadName thread-name prefix
     * @return a new executor with all core threads already started
     */
    public static ThreadPoolExecutor getExecutorPool(String threadName) {
        ThreadFactory factory = new ThreadFactoryBuilder()
                .setNameFormat(threadName + "-%d")
                .build();
        BlockingQueue<Runnable> queue = new ArrayBlockingQueue<>(defaultQueueSize);
        ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(defaultCoreSize,
                defaultMaxSize, 60, TimeUnit.SECONDS, queue, factory,
                ThreadPoolExecutorUtils::rejectTask) {
            @Override
            protected void afterExecute(Runnable r, Throwable t) {
                super.afterExecute(r, t);
                // Tasks passed through submit() are FutureTasks that capture their own
                // exception; unwrap it so failures are logged rather than silently kept.
                if (t == null && r instanceof Future<?>) {
                    try {
                        Future<?> future = (Future<?>) r;
                        future.get();
                    } catch (CancellationException ce) {
                        t = ce;
                    } catch (ExecutionException ee) {
                        t = ee.getCause();
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                    }
                }
                if (t != null) {
                    logger.error("customThreadPool error msg: {}", t.getMessage(), t);
                }
            }
        };
        poolExecutor.prestartAllCoreThreads();
        return poolExecutor;
    }

    /**
     * Rejection policy: drops the task, but first cancels it if it is a {@link Future}.
     * Every task passed through {@code submit()} is a Future. Without the cancel, a dropped
     * FutureTask never completes and any thread waiting in {@link Future#get()} blocks forever;
     * with it, that thread gets a {@link CancellationException} instead.
     */
    private static void rejectTask(Runnable r, ThreadPoolExecutor executor) {
        if (r instanceof Future<?>) {
            ((Future<?>) r).cancel(false);
        }
        if (!executor.isShutdown()) {
            logger.warn("ThreadPoolExecutor is over working, please check the thread tasks! ");
        }
    }

    /**
     * Shuts the pool down and waits up to {@code pollWaitingTime} seconds for queued and running
     * tasks to finish; after that, remaining tasks are discarded via {@code shutdownNow()}.
     *
     * @param pool the executor to close
     */
    public static void closeAfterComplete(ThreadPoolExecutor pool) {
        pool.shutdown();
        try {
            if (!pool.awaitTermination(pollWaitingTime, TimeUnit.SECONDS)) {
                logger.warn("ThreadPool did not terminate within {}s, forcing shutdownNow", pollWaitingTime);
                pool.shutdownNow();
            }
        } catch (InterruptedException e) {
            logger.error("ThreadPool overtime: {}", e.getMessage(), e);
            // Discard all tasks that have not started and interrupt the running worker threads.
            pool.shutdownNow();
            // Preserve the caller's interrupt status.
            Thread.currentThread().interrupt();
        }
    }
}
import java.util.Arrays;
/**
 * Describes one slice of a partitioned batch job.
 *
 * <p>It holds:
 * <ul>
 *   <li>the partition number ({@code index});</li>
 *   <li>the batch size;</li>
 *   <li>the number of partitions and the total record count;</li>
 *   <li>the caller's arguments;</li>
 *   <li>a {@code data} slot for the partition's result.</li>
 * </ul>
 */
public class PartitionElements {
    /** Partition number; 0 unless set (the scheduler in this project passes 1-based indices). */
    private long index;
    /** Number of records per partition. */
    private long batchCounts;
    /** Number of partitions needed to cover {@code totalCounts}. */
    private long partitions;
    /** Total number of records in the job. */
    private long totalCounts;
    /** Arguments passed through from the caller; shared, not copied. */
    private Object[] args;
    /** Result produced for this partition; not copied by the per-partition constructor. */
    private Object data;

    public PartitionElements() {
    }

    /**
     * Creates a template describing the whole job and derives the partition count.
     *
     * @throws IllegalArgumentException if {@code batchCounts <= 0}
     */
    public PartitionElements(long batchCounts, long totalCounts, Object[] args) {
        this.batchCounts = batchCounts;
        this.totalCounts = totalCounts;
        // Uses the private static helper: calling the overridable aquirePartitions() from a
        // constructor would let a subclass run code on a partially initialised object.
        this.partitions = computePartitions(totalCounts, batchCounts);
        this.args = args;
    }

    /**
     * Creates the descriptor for partition {@code index} from a template.
     * The template's {@code data} is intentionally not copied.
     */
    public PartitionElements(long index, PartitionElements elements) {
        this.index = index;
        this.batchCounts = elements.getBatchCounts();
        this.partitions = elements.getPartitions();
        this.totalCounts = elements.getTotalCounts();
        this.args = elements.getArgs();
    }

    /**
     * Returns how many partitions of {@code batchCounts} records are needed to cover
     * {@code totalCounts} records (ceiling division). An empty job still gets one partition.
     *
     * @throws IllegalArgumentException if {@code batchCounts <= 0}
     */
    public long aquirePartitions(long totalCounts, long batchCounts) {
        return computePartitions(totalCounts, batchCounts);
    }

    /** Ceiling division of totalCounts by batchCounts, with a minimum of one partition. */
    private static long computePartitions(long totalCounts, long batchCounts) {
        if (batchCounts <= 0) {
            throw new IllegalArgumentException("batchCounts must be > 0, got " + batchCounts);
        }
        long partitions = totalCounts / batchCounts;
        if (totalCounts % batchCounts != 0) {
            partitions = partitions + 1;
        }
        // totalCounts == 0 yields 0 partitions; still schedule one so the job runs once.
        if (partitions == 0) {
            partitions = 1;
        }
        return partitions;
    }

    public long getIndex() {
        return index;
    }

    public void setIndex(long index) {
        this.index = index;
    }

    public long getBatchCounts() {
        return batchCounts;
    }

    public void setBatchCounts(long batchCounts) {
        this.batchCounts = batchCounts;
    }

    public long getPartitions() {
        return partitions;
    }

    public void setPartitions(long partitions) {
        this.partitions = partitions;
    }

    public long getTotalCounts() {
        return totalCounts;
    }

    public void setTotalCounts(long totalCounts) {
        this.totalCounts = totalCounts;
    }

    public Object[] getArgs() {
        return args;
    }

    public void setArgs(Object[] args) {
        this.args = args;
    }

    public Object getData() {
        return data;
    }

    public void setData(Object data) {
        this.data = data;
    }

    /** Describes every field except {@code data}. */
    @Override
    public String toString() {
        return "PartitionElements{" +
                "index=" + index +
                ", batchCounts=" + batchCounts +
                ", partitions=" + partitions +
                ", totalCounts=" + totalCounts +
                ", args=" + Arrays.toString(args) +
                '}';
    }
}
import cn.henry.study.common.bo.PartitionElements;
/**
 * Callback contract for a split/merge ("divide, compute, merge") batch job.
 *
 * <p>Order of calls as driven by {@code MultiThreadOperationUtils.batchExecute}:
 * <ol>
 *   <li>{@link #count} once, to size the partitions;</li>
 *   <li>{@link #prepare} once, before any partition is scheduled;</li>
 *   <li>{@link #invoke} for each partition. This presumably runs on a pool thread via
 *       {@code OperationThread}, which is not shown here (TODO confirm).</li>
 *   <li>{@link #post} on the calling thread for each partition whose task completed
 *       successfully, in partition order;</li>
 *   <li>{@link #finished} once; its result is returned to the caller.</li>
 * </ol>
 */
public interface OperationThreadService {

    /** Returns the total number of records the job will process. */
    long count(Object[] arguments) throws Exception;

    /**
     * Runs once before partitioning. The returned object is the shared context handed to every
     * {@link #post} call and finally to {@link #finished}.
     */
    Object prepare(Object[] arguments) throws Exception;

    /**
     * Processes one partition and returns its result.
     * NOTE(review): presumably stored into {@code partition.setData(...)} by OperationThread.
     */
    Object invoke(PartitionElements partition) throws Exception;

    /** Merges one finished partition into the context object returned by {@link #prepare}. */
    void post(PartitionElements partition, Object context) throws Exception;

    /** Turns the context object into the job's final result. */
    Object finished(Object context) throws Exception;
}
import cn.henry.study.common.bo.PartitionElements;
import cn.henry.study.common.service.OperationThreadService;
import cn.henry.study.common.thread.OperationThread;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * Runs an {@link OperationThreadService} as a partitioned batch job on a fresh thread pool.
 * Each partition is computed on the pool, and results are merged on the calling thread.
 */
public class MultiThreadOperationUtils {
    private static final Logger logger = LoggerFactory.getLogger(MultiThreadOperationUtils.class);

    /**
     * Runs the job, sizing batches so there is roughly one partition per core pool thread.
     *
     * @param service the job callbacks
     * @param args    arguments forwarded to {@code count} and {@code prepare}
     * @return the value returned by {@code service.finished}
     */
    public static Object batchExecute(OperationThreadService service, Object[] args) throws Exception {
        long totalCounts = service.count(args);
        // Guard against a core size of 0, which the int overloads of getExecutorPool allow.
        int cores = Math.max(1, ThreadPoolExecutorUtils.defaultCoreSize);
        long batchCounts = totalCounts / cores;
        // Fewer records than threads: use batches of one record.
        if (batchCounts == 0) {
            batchCounts = 1L;
        }
        PartitionElements elements = new PartitionElements(batchCounts, totalCounts, args);
        return batchExecute(service, elements);
    }

    /**
     * Runs the job with an explicit batch size.
     *
     * @param batchCounts records per partition; must be positive
     * @throws IllegalArgumentException if {@code batchCounts <= 0}
     */
    public static Object batchExecute(OperationThreadService service, long batchCounts, Object[] args) throws Exception {
        if (batchCounts <= 0) {
            throw new IllegalArgumentException("batchCounts must be > 0, got " + batchCounts);
        }
        long totalCounts = service.count(args);
        PartitionElements elements = new PartitionElements(batchCounts, totalCounts, args);
        return batchExecute(service, elements);
    }

    private static Object batchExecute(OperationThreadService service, PartitionElements elements) throws Exception {
        // Prepare before creating the pool. getExecutorPool() prestarts its core threads, so
        // creating it first would leave those threads running if prepare() threw.
        final Object obj = service.prepare(elements.getArgs());
        ThreadPoolExecutor executor = ThreadPoolExecutorUtils.getExecutorPool();
        try {
            // Presize to the partition count to avoid resizing.
            ArrayList<Future<PartitionElements>> futures = new ArrayList<>((int) elements.getPartitions());
            // Partition indices are 1-based.
            for (long i = 1; i <= elements.getPartitions(); i++) {
                futures.add(executor.submit(new OperationThread(new PartitionElements(i, elements), service)));
            }
            // Accept no more tasks; already queued tasks still run to completion.
            executor.shutdown();
            // Merge in partition order, blocking on each task in turn. If the pool's rejection
            // handler dropped a task without completing its Future, get() here would block.
            for (Future<PartitionElements> f : futures) {
                try {
                    service.post(f.get(), obj);
                } catch (InterruptedException e) {
                    // Stop waiting: abandon unfinished partitions, keep the interrupt, propagate.
                    futures.forEach(pending -> pending.cancel(true));
                    Thread.currentThread().interrupt();
                    throw e;
                } catch (Exception e) {
                    // Best effort: a failed partition is logged and skipped; the rest are merged.
                    logger.error("post routine fail", e);
                }
            }
            return service.finished(obj);
        } finally {
            // Idempotent; releases the pool even if submission or merging threw.
            executor.shutdown();
        }
    }
}
import cn.henry.study.common.bo.PartitionElements;
import cn.henry.study.common.service.OperationThreadService;
import cn.henry.study.common.utils.MultiThreadOperationUtils;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
/**
 * Exercises {@code MultiThreadOperationUtils.batchExecute} with a fake 100-record job.
 * Each partition only logs the batched INSERT statements it would run.
 */
public class MultiThreadServiceTest implements OperationThreadService {
    private static final Logger logger = LoggerFactory.getLogger(MultiThreadServiceTest.class);

    /** The fake job always has 100 records. */
    @Override
    public long count(Object[] args) throws Exception {
        return 100L;
    }

    /** The shared context is a marker string that finished() returns unchanged. */
    @Override
    public Object prepare(Object[] args) throws Exception {
        return "success";
    }

    /**
     * Produces this partition's records, {@code "test_<n>"}.
     * The range of n is [(index - 1) * batch, min(index * batch, total)).
     */
    @Override
    public Object invoke(PartitionElements elements) throws Exception {
        long batch = elements.getBatchCounts();
        long start = Math.max(0L, (elements.getIndex() - 1) * batch);
        long end = Math.min(start + batch, elements.getTotalCounts());
        List<Object> list = new ArrayList<>((int) Math.max(0L, end - start));
        for (long i = start; i < end; i++) {
            list.add("test_" + i);
        }
        return list;
    }

    /**
     * Logs one INSERT statement for every 5 records of the partition, plus one for the remainder.
     * The SQL is only logged, never executed. Values are concatenated here because they are
     * fixed test data; real code must use PreparedStatement parameters instead.
     */
    @Override
    public void post(PartitionElements elements, Object object) throws Exception {
        String insertSql = "insert into test (id) values ";
        StringBuilder sb = new StringBuilder();
        // NOTE(review): assumes OperationThread stores invoke()'s List into elements.data.
        @SuppressWarnings("unchecked")
        List<Object> datas = (List<Object>) elements.getData();
        for (int i = 0; i < datas.size(); i++) {
            if ((i + 1) % 5 == 0 || (i + 1) == datas.size()) {
                sb.append("('" + datas.get(i) + "')");
                logger.info("{}: 测试insert sql: {}", elements, insertSql + sb.toString());
                sb = new StringBuilder();
            } else {
                sb.append("('" + datas.get(i) + "'),");
            }
        }
    }

    /** Returns the context object produced by prepare(). */
    @Override
    public Object finished(Object object) throws Exception {
        return object;
    }

    /** Fails, rather than only printing a stack trace, when the job throws or returns the wrong value. */
    @Test
    public void testBatchExecute() throws Exception {
        Object object = MultiThreadOperationUtils.batchExecute(new MultiThreadServiceTest(), 10, new Object[]{"test"});
        logger.info("测试完成: {}", object);
        if (!"success".equals(object)) {
            throw new AssertionError("expected \"success\" from finished(), got: " + object);
        }
    }
}
总结:这是一个抽象之后的多线程业务流程处理方式,已在生产环境使用,多线程的重点在业务分割和思想上,有清晰的责任划分。
到此这篇关于java项目中的多线程实践的文章就介绍到这了,更多相关java多线程实践内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!