Java实现分布式Redis限流记录
本章将简单记录Redis分布式限流实现,还有很多不懂的地方先记录
背景
- Spring Framework 3.2.8.RELEASE版本
- Jedis 2.72 版本 (2.6 版本 Jedis 就支持 Lua 脚本调用了,方法是 eval )
因为 Redis 的使用是需要多个数据源,所以整体 Redis 的限流并未向网上常见的通过 Spring 的 redisTemplate 这种方式进行实现,甚至都没引入 spring-data-redis 这个包,所以实现仅使用了 Jedis 依赖包提供的。
实现代码细节
- 准备个限流器接口
/**
* 限流器
*
*/
public interface LimiterStrategy {
/**,
* 返回是否应该通过
*
* @param key
* @return
*/
boolean access(String key);
}
- 准备个限流器抽象类实现上面接口
import xxx..framework.redis.RedisManager;
import xxx..limiter.policy.LimiterPolicy;
import xxx..common.utils.ServerValidUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.ByteStreams;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.scripting.support.ResourceScriptSource;
import redis.clients.jedis.Jedis;
import java.io.InputStream;
import java.util.concurrent.ConcurrentMap;
/**
* 限流器的抽象父类
*/
@Slf4j
public abstract class AbstractLimiterStrategy implements LimiterStrategy {
// 这里是避免重复读取文件与防止并发问题
private final static ConcurrentMap<String,String> scriptMapping = Maps.newConcurrentMap();
// lua 脚本路径
private String scriptPath;
// lua 脚本所需参数
private LimiterPolicy limiterPolicy;
// lua 脚本内容
private String script;
/**
* 抽象父类限流器的构造器
*
* @param scriptPath
* @param limiterPolicy 一个参数的封装类
*/
public AbstractLimiterStrategy(String scriptPath, LimiterPolicy limiterPolicy) {
// ServerValidUtils 自己实现的断言
ServerValidUtils.validBlank(scriptPath, "scriptPathv is null");
this.scriptPath = scriptPath;
ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL");
this.limiterPolicy = limiterPolicy;
this.init();
}
public AbstractLimiterStrategy(LimiterPolicy limiterPolicy) {
this.scriptPath = this.LimiterFilePath();
ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL");
this.limiterPolicy = limiterPolicy;
this.init();
}
// 这个抽象方法是获取文件路径的
public abstract String LimiterFilePath();
/**
* 初始化限流器脚本内容
*/
private void init() {
String mapScript = scriptMapping.get(this.scriptPath);
if(StringUtils.isBlank(mapScript)){
try {
// 构建获取 lua 脚本的脚本
// classpath: 扫描的是resources目录下的
// 获取资源
ResourceScriptSource resourceScriptSource = new ResourceScriptSource(new ClassPathResource(this.scriptPath));
InputStream inputStream = resourceScriptSource.getResource().getInputStream();
byte[] scriptBytes = ByteStreams.toByteArray(inputStream);
scriptMapping.putIfAbsent(this.scriptPath,new String(scriptBytes));
} catch (Exception e) {
log.error("init limiter error: The file may not exist", e);
throw new RuntimeException(e);
}
}
this.script = scriptMapping.get(this.scriptPath);
}
@Override
public boolean access(String key) {
// RedisManager.getJedisPool().getResource() 是自己内部封装的,重点是获取到 Jedis
try(Jedis jedis = RedisManager.getJedisPool().getResource()){
// 调用 eval 方法,参数分别是:脚本内容、Keys集合,传入参数集合
Long remain = (Long)jedis.eval(this.script, Lists.asList(key, new String[]{}), limiterPolicy.toParams());
// remain 这个脚本返回的不是剩余数量(具体看脚本实现)
log.info("限流器类别:{} | key :{} 限流器内许可数量为:{} ", limiterPolicy.getClass().getSimpleName(), key, remain);
return remain > 0;
}catch (Exception e){
log.error("限流器调用错误",e);
return false;
}
}
}
- 再继承这个抽象类实现一个令牌桶限流器
import xxx.limiter.policy.LimiterPolicy;
/**
* 令牌桶限流器
*
*/
public class TokenBucketLimiterStrategy extends AbstractLimiterStrategy {
/**
* lua脚本路径
* 该脚本每次调用 access 仅减少一个令牌 (脚本内觉得的)
*/
static final String SCRIPT_FILE_NAME = "lua/Barrel-Token.lua";
// LimiterPolicy 脚本所需参数类
public TokenBucketLimiterStrategy(LimiterPolicy limiterPolicy) {
super(limiterPolicy);
}
@Override
public String LimiterFilePath() {
return SCRIPT_FILE_NAME;
}
}
- 脚本参数接口
import java.util.List;
/**
* 限制器脚本所需参数接口
*/
public interface LimiterPolicy {
/**
* 转成字符串数组,数组顺序与脚本取参顺序有关
* @return
*/
List<String> toParams();
}
- 令牌限流脚本所需参数的实现类
import com.google.common.collect.Lists;
import java.util.List;
/**
* 令牌桶限流器的执行对象
*/
public class TokenBucketLimiterPolicy implements LimiterPolicy {
/**
* 限流时间间隔
* (重置桶内令牌的时间间隔)
*/
private final long resetBucketInterval;
/**
* 最大令牌数量
*/
private final long bucketMaxTokens;
/**
* 初始可存储数量
*/
private final long initTokens;
/**
* 每个令牌产生的时间
*/
private final long intervalPerPermit;
/**
* 令牌桶对象的构造器
* @param bucketMaxTokens 桶的令牌上限
* @param resetBucketInterval 限流时间间隔 (单位毫秒)
* @param initTokens 初始化令牌数
*/
public TokenBucketLimiterPolicy(long bucketMaxTokens, long resetBucketInterval, long initTokens) {
// 最大令牌数
this.bucketMaxTokens = bucketMaxTokens;
// 限流时间间隔
this.resetBucketInterval = resetBucketInterval;
// 令牌的产生间隔 = 限流时间 / 最大令牌数
this.intervalPerPermit = resetBucketInterval / bucketMaxTokens;
// 初始令牌数
this.initTokens = initTokens;
}
public long getResetBucketInterval() {
return resetBucketInterval;
}
public long getBucketMaxTokens() {
return bucketMaxTokens;
}
public long getInitTokens() {
return initTokens;
}
public long getIntervalPerPermit() {
return intervalPerPermit;
}
//这个顺序和脚本取值有关系
@Override
public List<String> toParams() {
List<String > list = Lists.newArrayList();
list.add(String.valueOf(getIntervalPerPermit()));
list.add(String.valueOf(System.currentTimeMillis()));
list.add(String.valueOf(getInitTokens()));
list.add(String.valueOf(getBucketMaxTokens()));
list.add(String.valueOf(getResetBucketInterval()));
return list;
}
}
- 在src.main.resources目录创建lua目录放个Barrel-Token.lua文件,内容如下
--[[
1. key - 令牌桶的 key
2. intervalPerTokens - 生成令牌的间隔(ms)
3. curTime - 当前时间
4. initTokens - 令牌桶初始化的令牌数
5. bucketMaxTokens - 令牌桶的上限
6. resetBucketInterval - 重置桶内令牌的时间间隔
7. currentTokens - 当前桶内令牌数
8. bucket - 当前 key 的令牌桶对象
]] --
local key = KEYS[1]
local intervalPerTokens = tonumber(ARGV[1])
local curTime = tonumber(ARGV[2])
local initTokens = tonumber(ARGV[3])
local bucketMaxTokens = tonumber(ARGV[4])
local resetBucketInterval = tonumber(ARGV[5])
local bucket = redis.call('hgetall', key)
local currentTokens
-- 若当前桶未初始化,先初始化令牌桶
if table.maxn(bucket) == 0 then
-- 初始桶内令牌
currentTokens = initTokens
-- 设置桶最近的填充时间是当前
redis.call('hset', key, 'lastRefillTime', curTime)
-- 初始化令牌桶的过期时间, 设置为间隔的 1.5 倍
redis.call('pexpire', key, resetBucketInterval * 1.5)
-- 若桶已初始化,开始计算桶内令牌
-- 为什么等于 4 ? 因为有两对 field, 加起来长度是 4
-- { "lastRefillTime(上一次更新时间)","curTime(更新时间值)","tokensRemaining(当前保留的令牌)","令牌数" }
elseif table.maxn(bucket) == 4 then
-- 上次填充时间
local lastRefillTime = tonumber(bucket[2])
-- 剩余的令牌数
local tokensRemaining = tonumber(bucket[4])
-- 当前时间大于上次填充时间
if curTime > lastRefillTime then
-- 拿到当前时间与上次填充时间的时间间隔
-- 举例理解: curTime = 2620 , lastRefillTime = 2000, intervalSinceLast = 620
local intervalSinceLast = curTime - lastRefillTime
-- 如果当前时间间隔 大于 令牌的生成间隔
-- 举例理解: intervalSinceLast = 620, resetBucketInterval = 1000
if intervalSinceLast > resetBucketInterval then
-- 将当前令牌填充满
currentTokens = initTokens
-- 更新重新填充时间
redis.call('hset', key, 'lastRefillTime', curTime)
-- 如果当前时间间隔 小于 令牌的生成间隔
else
-- 可授予的令牌 = 向下取整数( 上次填充时间与当前时间的时间间隔 / 两个令牌许可之间的时间间隔 )
-- 举例理解 : intervalPerTokens = 200 ms , 令牌间隔时间为 200ms
-- intervalSinceLast = 620 ms , 当前距离上一个填充时间差为 620ms
-- grantedTokens = 620/200 = 3.1 = 3
local grantedTokens = math.floor(intervalSinceLast / intervalPerTokens)
-- 可授予的令牌 > 0 时
-- 举例理解 : grantedTokens = 620/200 = 3.1 = 3
if grantedTokens > 0 then
-- 生成的令牌 = 上次填充时间与当前时间的时间间隔 % 两个令牌许可之间的时间间隔
-- 举例理解 : padMillis = 620%200 = 20
-- curTime = 2620
-- curTime - padMillis = 2600
local padMillis = math.fmod(intervalSinceLast, intervalPerTokens)
-- 将当前令牌桶更新到上一次生成时间
redis.call('hset', key, 'lastRefillTime', curTime - padMillis)
end
-- 更新当前令牌桶中的令牌数
-- Math.min(根据时间生成的令牌数 + 剩下的令牌数, 桶的限制) => 超出桶最大令牌的就丢弃
currentTokens = math.min(grantedTokens + tokensRemaining, bucketMaxTokens)
end
else
-- 如果当前时间小于或等于上次更新的时间, 说明刚刚初始化, 当前令牌数量等于桶内令牌数
-- 不需要重新填充
currentTokens = tokensRemaining
end
end
-- 如果当前桶内令牌小于 0,抛出异常
assert(currentTokens >= 0)
-- 如果当前令牌 == 0 ,更新桶内令牌, 返回 0
if currentTokens == 0 then
redis.call('hset', key, 'tokensRemaining', currentTokens)
return 0
else
-- 如果当前令牌 大于 0, 更新当前桶内的令牌 -1 , 再返回当前桶内令牌数
redis.call('hset', key, 'tokensRemaining', currentTokens - 1)
return currentTokens
end
- 脚本思路图
单元测试
下面是1秒20个令牌,初始化0个,所以是每过50毫秒创建一个令牌,所以看最后stopWatch.getTotalTimeMillis()输出时间数下抢到令牌的数量,去掉部分开始与结束时间消耗,是没有问题的。
@Test
public void test5() throws InterruptedException {
final TokenBucketLimiterStrategy tokenBucketLimiterStrategy = new TokenBucketLimiterStrategy(new TokenBucketLimiterPolicy(20, 1000, 0));
StopWatch stopWatch = new StopWatch();
stopWatch.start();
execute(new Runnable() {
@Override
public void run() {
ServiceContext.getContext().setFbAccessNo("xna");
tokenBucketLimiterStrategy.access("test-lua");
}
},20,10);
stopWatch.stop();
System.out.println(stopWatch.getTotalTimeMillis());
}
public static void execute(final Runnable run, int threadSize, int loop){
AtomicInteger count = new AtomicInteger();
for (int j = 0; j <loop ; j++) {
System.out.println("第"+(j+1)+"轮并发测试,每轮并发数"+threadSize);
final CountDownLatch countDownLatch = new CountDownLatch(1);
Set<Thread> threads = new HashSet<>(threadSize);
//批量新建线程
for (int i = 0; i <threadSize ; i++) {
threads.add(
new Thread(new Runnable() {
@Override
public void run() {
try {
countDownLatch.await();
run.run();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
},"Thread"+count.getAndIncrement()));
}
//开启所有线程并确保其进入Waiting状态
for (Thread thread : threads) {
thread.start();
while(thread.getState() != Thread.State.WAITING);
}
//唤醒所有在countDownLatch上等待的线程
countDownLatch.countDown();
//等待所有线程执行完毕,开启下一轮
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
public static void execute(Runnable run){
execute(run,1000,1);
}
public static void execute(Runnable run,int threadSize){
execute(run,threadSize,1);
}
总结
Get到了Redis的脚本调用方式,也走了不少弯路,最后还是成功了,总之有时间学下更高级Redis指令使用吧