ThreadPoolTaskExecutor 线程池低感知池透传参数实现.md

ThreadPoolTaskExecutor 线程池是一个非常好用的线程,我们可以通过装饰器设计模式为他拓展贯穿业务链路的参数,透传时使用slf4j MDC还可以直接输出一些记录到日志里面

技术简述

ThreadPoolTaskExecutor

一个 Spring 推荐的线程池,Spring 为其提供拒绝策略和处理流程,底层调用 ThreadPoolExecutor。

slf4j MDC

MDC是将线程中你想体现在日志文件中的数据统一管理,根据你的日志配置是否输出。主要做日志链路追踪,支持多线程安全操作。如可以做在日志中体现请求用户 IP 地址,http 客户端的 user-agent ,或者加一个日志跟踪编号等,不同的日志系统有不同的底层实现,如 logback 使用的是 ThreadLocal 。所以也可以仅使用 ThreadLocal 透传参数

装饰器设计模式

装饰器模式又叫包装模式,数据结构型模式;是指在不改变现有对象结构的情况下,动态的给改对象增加一些职责(即增加其额外功能)的模式,像 AOP 就是大家最熟悉的一种实现了

实现思路

  1. 实现 BeanPostProcessor 接口,在 bean 初始化结束后,将 ThreadPoolTaskExecutor 实例的对象替换为一个继承 ThreadPoolTaskExecutor 类的包装类
  2. 包装类重写执行 Runnable 的 execute 与 submit 方法,其目的是将 Runnable 替换为一个实现了 Runnable 接口的 Runnable 包装类 (同理 Callable 也是如此)
  3. 这个 Runnable 包装类做的第一件事是初始化的时候获取当前的线程ID与本地线程参数
  4. Runnable 包装类第二件事是线程池执行它时,先调用前置工作将初始化时获取的参数设置到这根线程中,再调用原来 Runnable.run() 方法
  5. 在执行结束后 Runnable 包装类删除线程中设置的参数

代码实现

按照思路来

声明与定义

ServiceContext:服务上下文,本地线程存储的对象

BzThreadPoolTaskExecutor:继承 ThreadPoolTaskExecutor 类的包装类

BzRunnable:Runnable 的 包装类

ServiceContext代码

/**
 * 服务上下文
 */
public class ServiceContext {
    //准备一个全局静态的 ThreadLocal 变量
    private static final ThreadLocal<ServiceContext> contexts = new ThreadLocal<ServiceContext>() {
        // 每个线程开始的时候创建一个 ServiceContext
        @Override
        protected ServiceContext initialValue() {
            return new ServiceContext();
        }
    };
    // 常用信息
    private Map<String, String> headers = new HashMap<String, String>();

    // 获取当前线程的 ThreadLocal 变量
    public static ServiceContext getContext() {
        return contexts.get();
    }
    // 返回一个克隆版本
    public Map<String, String> getHeaders() {
        Map<String, String> m = new HashMap<String, String>(headers.size());
        for (Map.Entry<String, String> e : headers.entrySet()) {
            m.put(e.getKey(), e.getValue());
        }
        return m;
    }

    public void setHeaders(Map<String, String> headers) {
        this.headers = headers;
    }

    // 通过 getContext() 获取本地线程变量,在开启新的线程初始进去
    // 这方法一般用在各种拦截器的时候初始这根线程的本地线程变量,如:Dubbo拦截器、Spring拦截器
    public void initBy(ServiceContext parentContext) {
        // 如果为空,或内存地址一直就代表是一个对象
        if (parentContext == null || parentContext == this) {
            return;
        }
        headers.clear();
        headers.putAll(parentContext.headers);
    }
    // 删除此线程局部变量的当前线程值
    public static void removeContext() {
        contexts.remove();
    }
}

BzThreadPoolTaskExecutor代码

import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
public final class BzThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {

    private final ThreadPoolTaskExecutor executor;

    public BzThreadPoolTaskExecutor(ThreadPoolTaskExecutor executor){
        if(Objects.equals(null, executor)){
            throw new IllegalArgumentException("executor can't be null");
        }
        this.executor = executor;
    }

    @Override
    public void execute(Runnable task) {
        executor.execute(BzRunnable.newInstance(task));
    }

    @Override
    public void execute(Runnable task, long startTimeout) {
        executor.execute(BzbRunnable.newInstance(task), startTimeout);
    }

    @Override
    public Future<?> submit(Runnable task) {
        return executor.submit(BzRunnable.newInstance(task));
    }

    @Override
    public <T> Future<T> submit(Callable<T> task) {
        // Callable 也是如此
        return executor.submit(BzCallable.newInstance(task));
    }
}

BzRunnable 代码

import org.slf4j.MDC;
import java.util.Map;
import java.util.Objects;

public class BzRunnable implements Runnable{

    private final Runnable task;
    // 下面获取的时候还是在第一个线程,所以获取到的是第一个线程的本地线程变量
    // 获取当前线程ID
    private final long mainThreadId = Thread.currentThread().getId();
    // 获取当前线程本地环境变量
    private Map<String,String> serviceContextHeaders = ServiceContext.getContext().getHeaders();

    public BzRunnable(Runnable task){
        if (Objects.equals(null, task)) {
            throw new IllegalArgumentException("task can't be null");
        }
        this.task = task;
    }

    public static BzRunnable newInstance(Runnable runnable){
        return new BzRunnable(runnable);
    }

    @Override
    public void run() {
        try{
            beforeExecute();
            // 注意这里使用的是run方法而非start
            task.run();
        }finally {
            afterExecute();
        }
    }

    private void beforeExecute(){
        try{
            ServiceContext context;
            //提交当前线程和执行当前线程为同一线程,且ServiceContext未清空时,不需要重新设置和清空
            if(Objects.equals(Thread.currentThread().getId() , mainThreadId)
                    && !ServiceContext.getContext().getHeaders().isEmpty()){
                return;
            }
            ServiceContext.removeContext();
            context = ServiceContext.getContext();
            context.setHeaders(serviceContextHeaders);
            for (String str: serviceContextHeaders.keySet()){
                // 写进MDC里面
                // logback中可以在日志文件中通过 %X{str} 输出和和当前线程相关联的MDC
                MDC.put(str,serviceContextHeaders.get(str));
            }
        }catch (Exception e){
            throw e;
        }
    }
    private void afterExecute(){
        try {
            ServiceContext.removeContext();
            // 移除MDC信息
            for (String str: serviceContextHeaders.keySet()){
                // 写进MDC里面
                MDC.remove(str);
            }
        } catch (Exception e) {
            throw e;
        }
    }

}

替换 ThreadPoolTaskExecutor Bean

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

/**
 * ThreadPoolTaskExecutor Bean实例后置处理器
 * @author laijunjie
 * @date 2020/12/23
 */
public class ThreadPoolTaskExecutorBeanPostProcessor implements BeanPostProcessor {
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if(bean instanceof ThreadPoolTaskExecutor){
            ThreadPoolTaskExecutor executor = (ThreadPoolTaskExecutor)bean;
            // 替换为自定义的
            return new BzThreadPoolTaskExecutor(executor);
        }else{
            return bean;
        }
    }
}

总结

  • 这里仅介绍了 ThreadPoolTaskExecutor 使用 ThreadLocal 实现线程之间的参数传递,再通过 MDC 将关键信息可控的输出到日志文件中
  • ThreadLocal 一开始的上下文信息设置应该是各种过滤器与拦截器中实现。
  • 有道绕点的地方一:Runnable 包装类初始化获取上下文信息,这是 ThreadLocal 的功能,根据线程获取该线程的上下文信息。但第二次执行是线程池使用线程执行这个包装类的 run 方法,这时此线程非原来的线程,上下文信息自然不同。
  • 有道绕的地方二: 执行原 Runnable 的 run() 这方法会让有些人觉得是开启了一个新线程,其实它就只是一个方法而已
  • BzCallable 类并没有补全,因为我觉得它与 BzRunnable 代码非常相似。

不止 ThreadPoolTaskExecutor 可以这也替换,其他对象也可以这么替换,装饰器设计模式配上 BeanPostProcessor 感觉可以为很多源码增加功能。有种 ”神“ 的感觉(●'◡'●)