Java实现分布式Redis限流记录

本章将简单记录Redis分布式限流实现,还有很多不懂的地方先记录

背景

  • Spring Framework 3.2.8.RELEASE版本
  • Jedis 2.72 版本 (2.6 版本 Jedis 就支持 Lua 脚本调用了,方法是 eval )

因为 Redis 的使用是需要多个数据源,所以整体 Redis 的限流并未向网上常见的通过 Spring 的 redisTemplate 这种方式进行实现,甚至都没引入 spring-data-redis 这个包,所以实现仅使用了 Jedis 依赖包提供的。

实现代码细节

  1. 准备个限流器接口
/**
 * 限流器
 *
 */
public interface LimiterStrategy {
    /**,
     * 返回是否应该通过
     *
     * @param key
     * @return
     */
    boolean access(String key);

}
  1. 准备个限流器抽象类实现上面接口
import xxx..framework.redis.RedisManager;
import xxx..limiter.policy.LimiterPolicy;
import xxx..common.utils.ServerValidUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.ByteStreams;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.scripting.support.ResourceScriptSource;
import redis.clients.jedis.Jedis;

import java.io.InputStream;
import java.util.concurrent.ConcurrentMap;

/**
 * 限流器的抽象父类
 */
@Slf4j
public abstract class AbstractLimiterStrategy implements LimiterStrategy {

    // 这里是避免重复读取文件与防止并发问题
    private final static ConcurrentMap<String,String> scriptMapping = Maps.newConcurrentMap();

    // lua 脚本路径
    private String scriptPath;

    // lua 脚本所需参数
    private LimiterPolicy limiterPolicy;

    // lua 脚本内容
    private String script;

    /**
     * 抽象父类限流器的构造器
     *
     * @param scriptPath
     * @param limiterPolicy 一个参数的封装类
     */
    public AbstractLimiterStrategy(String scriptPath, LimiterPolicy limiterPolicy) {
        // ServerValidUtils 自己实现的断言
        ServerValidUtils.validBlank(scriptPath, "scriptPathv is null");
        this.scriptPath = scriptPath;
        ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL");
        this.limiterPolicy = limiterPolicy;
        this.init();
    }

    public AbstractLimiterStrategy(LimiterPolicy limiterPolicy) {
        this.scriptPath = this.LimiterFilePath();
        ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL");
        this.limiterPolicy = limiterPolicy;
        this.init();
    }

    // 这个抽象方法是获取文件路径的
    public abstract String LimiterFilePath();

    /**
     * 初始化限流器脚本内容
     */
    private void init() {
        String mapScript = scriptMapping.get(this.scriptPath);
        if(StringUtils.isBlank(mapScript)){
            try {
                // 构建获取 lua 脚本的脚本
                // classpath: 扫描的是resources目录下的
                // 获取资源
                ResourceScriptSource resourceScriptSource = new ResourceScriptSource(new ClassPathResource(this.scriptPath));
                InputStream inputStream = resourceScriptSource.getResource().getInputStream();
                byte[] scriptBytes = ByteStreams.toByteArray(inputStream);
                scriptMapping.putIfAbsent(this.scriptPath,new String(scriptBytes));
            } catch (Exception e) {
                log.error("init limiter error: The file may not exist", e);
                throw new RuntimeException(e);
            }
        }
        this.script = scriptMapping.get(this.scriptPath);
    }

    @Override
    public boolean access(String key) {
        // RedisManager.getJedisPool().getResource() 是自己内部封装的,重点是获取到 Jedis 
        try(Jedis jedis = RedisManager.getJedisPool().getResource()){
            // 调用 eval 方法,参数分别是:脚本内容、Keys集合,传入参数集合
            Long remain = (Long)jedis.eval(this.script, Lists.asList(key, new String[]{}), limiterPolicy.toParams());
            // remain 这个脚本返回的不是剩余数量(具体看脚本实现)
            log.info("限流器类别:{} | key :{} 限流器内许可数量为:{} ", limiterPolicy.getClass().getSimpleName(), key, remain);
            return remain > 0;
        }catch (Exception e){
            log.error("限流器调用错误",e);
            return false;
        }
    }

}
  1. 再继承这个抽象类实现一个令牌桶限流器
import xxx.limiter.policy.LimiterPolicy;

/**
 * 令牌桶限流器
 *
 */
public class TokenBucketLimiterStrategy extends AbstractLimiterStrategy {

    /**
     * lua脚本路径
     * 该脚本每次调用 access 仅减少一个令牌 (脚本内觉得的)
     */
    static final String SCRIPT_FILE_NAME = "lua/Barrel-Token.lua";

    // LimiterPolicy 脚本所需参数类
    public TokenBucketLimiterStrategy(LimiterPolicy limiterPolicy) {
        super(limiterPolicy);
    }

    @Override
    public String LimiterFilePath() {
        return SCRIPT_FILE_NAME;
    }
}
  1. 脚本参数接口
import java.util.List;

/**
 * 限制器脚本所需参数接口
 */
public interface LimiterPolicy {

    /**
     * 转成字符串数组,数组顺序与脚本取参顺序有关
     * @return
     */
    List<String> toParams();

}
  1. 令牌限流脚本所需参数的实现类
import com.google.common.collect.Lists;

import java.util.List;

/**
 * 令牌桶限流器的执行对象
 */
public class TokenBucketLimiterPolicy implements LimiterPolicy {

    /**
     * 限流时间间隔
     * (重置桶内令牌的时间间隔)
     */
    private final long resetBucketInterval;
    /**
     * 最大令牌数量
     */
    private final long bucketMaxTokens;

    /**
     * 初始可存储数量
     */
    private final long initTokens;

    /**
     * 每个令牌产生的时间
     */
    private final long intervalPerPermit;

    /**
     * 令牌桶对象的构造器
     * @param bucketMaxTokens 桶的令牌上限
     * @param resetBucketInterval 限流时间间隔 (单位毫秒)
     * @param initTokens 初始化令牌数
     */
    public TokenBucketLimiterPolicy(long bucketMaxTokens, long resetBucketInterval, long initTokens) {
        // 最大令牌数
        this.bucketMaxTokens = bucketMaxTokens;
        // 限流时间间隔
        this.resetBucketInterval = resetBucketInterval;
        // 令牌的产生间隔 = 限流时间 / 最大令牌数
        this.intervalPerPermit = resetBucketInterval / bucketMaxTokens;
        // 初始令牌数
        this.initTokens = initTokens;
    }

    public long getResetBucketInterval() {
        return resetBucketInterval;
    }

    public long getBucketMaxTokens() {
        return bucketMaxTokens;
    }

    public long getInitTokens() {
        return initTokens;
    }

    public long getIntervalPerPermit() {
        return intervalPerPermit;
    }

    //这个顺序和脚本取值有关系
    @Override
    public List<String> toParams() {
        List<String > list = Lists.newArrayList();
        list.add(String.valueOf(getIntervalPerPermit()));
        list.add(String.valueOf(System.currentTimeMillis()));
        list.add(String.valueOf(getInitTokens()));
        list.add(String.valueOf(getBucketMaxTokens()));
        list.add(String.valueOf(getResetBucketInterval()));
        return list;
    }

}
  1. 在src.main.resources目录创建lua目录放个Barrel-Token.lua文件,内容如下
--[[
  1. key - 令牌桶的 key
  2. intervalPerTokens - 生成令牌的间隔(ms)
  3. curTime - 当前时间
  4. initTokens - 令牌桶初始化的令牌数
  5. bucketMaxTokens - 令牌桶的上限
  6. resetBucketInterval - 重置桶内令牌的时间间隔
  7. currentTokens - 当前桶内令牌数
  8. bucket - 当前 key 的令牌桶对象
]] --

local key = KEYS[1]
local intervalPerTokens = tonumber(ARGV[1])
local curTime = tonumber(ARGV[2])
local initTokens = tonumber(ARGV[3])
local bucketMaxTokens = tonumber(ARGV[4])
local resetBucketInterval = tonumber(ARGV[5])

local bucket = redis.call('hgetall', key)
local currentTokens

-- 若当前桶未初始化,先初始化令牌桶
if table.maxn(bucket) == 0 then
    -- 初始桶内令牌
    currentTokens = initTokens
    -- 设置桶最近的填充时间是当前
    redis.call('hset', key, 'lastRefillTime', curTime)
    -- 初始化令牌桶的过期时间, 设置为间隔的 1.5 倍
    redis.call('pexpire', key, resetBucketInterval * 1.5)

-- 若桶已初始化,开始计算桶内令牌
-- 为什么等于 4 ? 因为有两对 field, 加起来长度是 4
-- { "lastRefillTime(上一次更新时间)","curTime(更新时间值)","tokensRemaining(当前保留的令牌)","令牌数" }
elseif table.maxn(bucket) == 4 then

    -- 上次填充时间
    local lastRefillTime = tonumber(bucket[2])
    -- 剩余的令牌数
    local tokensRemaining = tonumber(bucket[4])

    -- 当前时间大于上次填充时间
    if curTime > lastRefillTime then

        -- 拿到当前时间与上次填充时间的时间间隔
        -- 举例理解: curTime = 2620 , lastRefillTime = 2000, intervalSinceLast = 620
        local intervalSinceLast = curTime - lastRefillTime

        -- 如果当前时间间隔 大于 令牌的生成间隔
        -- 举例理解: intervalSinceLast = 620, resetBucketInterval = 1000
        if intervalSinceLast > resetBucketInterval then

            -- 将当前令牌填充满
            currentTokens = initTokens

            -- 更新重新填充时间
            redis.call('hset', key, 'lastRefillTime', curTime)
            
        -- 如果当前时间间隔 小于 令牌的生成间隔
        else

            -- 可授予的令牌 = 向下取整数( 上次填充时间与当前时间的时间间隔 / 两个令牌许可之间的时间间隔 )
            -- 举例理解 : intervalPerTokens = 200 ms , 令牌间隔时间为 200ms
            --           intervalSinceLast = 620 ms , 当前距离上一个填充时间差为 620ms
            --           grantedTokens = 620/200 = 3.1 = 3
            local grantedTokens = math.floor(intervalSinceLast / intervalPerTokens)

            -- 可授予的令牌 > 0 时
            -- 举例理解 : grantedTokens = 620/200 = 3.1 = 3
            if grantedTokens > 0 then

                -- 生成的令牌 = 上次填充时间与当前时间的时间间隔 % 两个令牌许可之间的时间间隔
                -- 举例理解 : padMillis = 620%200 = 20
                --           curTime = 2620
                --           curTime - padMillis = 2600
                local padMillis = math.fmod(intervalSinceLast, intervalPerTokens)

                -- 将当前令牌桶更新到上一次生成时间
                redis.call('hset', key, 'lastRefillTime', curTime - padMillis)
            end

            -- 更新当前令牌桶中的令牌数
            -- Math.min(根据时间生成的令牌数 + 剩下的令牌数, 桶的限制) => 超出桶最大令牌的就丢弃
            currentTokens = math.min(grantedTokens + tokensRemaining, bucketMaxTokens)
        end
    else
        -- 如果当前时间小于或等于上次更新的时间, 说明刚刚初始化, 当前令牌数量等于桶内令牌数
        -- 不需要重新填充
        currentTokens = tokensRemaining
    end
end

-- 如果当前桶内令牌小于 0,抛出异常
assert(currentTokens >= 0)

-- 如果当前令牌 == 0 ,更新桶内令牌, 返回 0
if currentTokens == 0 then
    redis.call('hset', key, 'tokensRemaining', currentTokens)
    return 0
else
    -- 如果当前令牌 大于 0, 更新当前桶内的令牌 -1 , 再返回当前桶内令牌数
    redis.call('hset', key, 'tokensRemaining', currentTokens - 1)
    return currentTokens
end
  1. 脚本思路图

image-20220601204721992

单元测试

下面是1秒20个令牌,初始化0个,所以是每过50毫秒创建一个令牌,所以看最后stopWatch.getTotalTimeMillis()输出时间数下抢到令牌的数量,去掉部分开始与结束时间消耗,是没有问题的。

@Test
public void test5() throws InterruptedException {
    final TokenBucketLimiterStrategy tokenBucketLimiterStrategy = new TokenBucketLimiterStrategy(new TokenBucketLimiterPolicy(20, 1000, 0));
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    execute(new Runnable() {
        @Override
        public void run() {
            ServiceContext.getContext().setFbAccessNo("xna");
            tokenBucketLimiterStrategy.access("test-lua");
        }
    },20,10);
    stopWatch.stop();
    System.out.println(stopWatch.getTotalTimeMillis());
}
public static void execute(final Runnable  run, int threadSize, int loop){
    AtomicInteger count = new AtomicInteger();
    for (int j = 0; j <loop ; j++) {
        System.out.println("第"+(j+1)+"轮并发测试,每轮并发数"+threadSize);
        final CountDownLatch countDownLatch = new CountDownLatch(1);
        Set<Thread> threads = new HashSet<>(threadSize);
        //批量新建线程
        for (int i = 0; i <threadSize ; i++) {
            threads.add(
                    new Thread(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                countDownLatch.await();
                                run.run();
                            } catch (InterruptedException e) {
                                e.printStackTrace();
                            }
                        }
                    },"Thread"+count.getAndIncrement()));
        }
        //开启所有线程并确保其进入Waiting状态
        for (Thread thread : threads) {
            thread.start();
            while(thread.getState() != Thread.State.WAITING);
        }
        //唤醒所有在countDownLatch上等待的线程
        countDownLatch.countDown();
        //等待所有线程执行完毕,开启下一轮
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
}
public static void execute(Runnable  run){
    execute(run,1000,1);
}
public static void execute(Runnable  run,int threadSize){
    execute(run,threadSize,1);
}

总结

Get到了Redis的脚本调用方式,也走了不少弯路,最后还是成功了,总之有时间学下更高级Redis指令使用吧