最近需要自己做一个限流功能,其他业务代码都好说。唯一的终点就是限流实现,想到redis可以实现令牌桶。一拍脑门,就用它了!
场景描述
真实开发中才发现想的太简单,如果是基于redis提供的命令在代码中调用的话,效率是小事。原子性根本无法保证!如下:
- 线程1 获取到令牌桶获取总数为10
- 线程1 消耗1个令牌,剩余9个令牌
- 线程2 获取到令牌桶获取总数为10
- 线程1 刷新令牌为9个
- 线程2 消耗1个令牌,剩余9个令牌
- 线程2 刷新令牌为9个
原子性完全无法保证。加锁? 那多节点岂不是需要在引入分布式锁?看来服务器中实现不可取,保证原子性的话,势必要写LUA代码执行了。网上翻阅了一些demo,没找到想要的。想到自己搭建的spring cloud gateway是有内置令牌桶实现的。开始翻阅源码,终于找到最好的方案。
gateway中redis令牌桶实现类是:org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter
public Mono<Response> isAllowed(String routeId, String id) {
...
// 这行是通过lua判断是否被限流
Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
...
}
顺着这个类找到lua代码是从org.springframework.cloud.gateway.config.GatewayRedisAutoConfiguration注入进来的
public RedisScript redisRequestRateLimiterScript() {
DefaultRedisScript redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(
new ResourceScriptSource(new ClassPathResource("META-INF/scripts/request_rate_limiter.lua")));
redisScript.setResultType(List.class);
return redisScript;
}
核心在
request_rate_limiter.lua
这个文件中
-- 获取到限流资源令牌数的key和响应时间戳的key
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
-- 分别获取填充速率、令牌桶容量、当前时间戳、消耗令牌数
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
-- 计算出失效时间,大概是令牌桶填满时间的两倍
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
-- 获取到最近一次的剩余令牌数,如果不存在说明令牌桶是满的
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
-- 上次消耗令牌的时间戳,不存在视为0
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
-- 计算出间隔时间
local delta = math.max(0, now-last_refreshed)
-- 剩余令牌数量 = “令牌桶容量” 和 “最后令牌数+(填充速率*时间间隔)”之间的最小值
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- 如果剩余令牌数量大于等于消耗令牌的数量则流量通过,否则不通过
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
-- 最后保存数据现场
if ttl > 0 then
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
end
-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }
解决方案
好了,可以开始写自己的限流工具类了。
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
/**
* redis限流器
*
*/
public class MyRedisLimiter{
private RedisTemplate redisTemplate;
private static final Long SUCCESS_FLAG = 1L;
/**
* 判断是否允许访问
*@param id 这次获取令牌桶的id
*@param rate 每秒填充速率
*@param capacity 令牌桶最大容量
*@param tokens 每次访问消耗几个令牌
*@return true 允许访问 false 不允许访问
*/
public boolean isAllowed(String id,int rate,int capacity,int tokens){
RedisScript<Long> redisScript = new DefaultRedisScript<>(SCRIPT,Long.class);
Object result = redisTemplate.execute(redisScript,
getKey(id),rate, capacity,
Instant.now().getEpochSecond(), tokens);
return SUCCESS_FLAG.equals(result);
}
private List<String> getKey(String id){
String prefix = "limiter:"+id;
String tokenKey = prefix + ":tokens";
String timestampKey = prefix + ":timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
private static final String SCRIPT = "local tokens_key = KEYS[1]\n" +
"local timestamp_key = KEYS[2]\n" +
"local rate = tonumber(ARGV[1])\n" +
"local capacity = tonumber(ARGV[2])\n" +
"local now = tonumber(ARGV[3])\n" +
"local requested = tonumber(ARGV[4])\n" +
"local fill_time = capacity/rate\n" +
"local ttl = math.floor(fill_time*2)\n" +
"local last_tokens = tonumber(redis.call('get', tokens_key))\n" +
"if last_tokens == nil then\n" +
" last_tokens = capacity\n" +
"end\n" +
"local last_refreshed = tonumber(redis.call('get', timestamp_key))\n" +
"if last_refreshed == nil then\n" +
" last_refreshed = 0\n" +
"end\n" +
"local diff_time = math.max(0, now-last_refreshed)\n" +
"local filled_tokens = math.min(capacity, last_tokens+(diff_time*rate))\n" +
"local allowed = filled_tokens >= requested\n" +
"local new_tokens = filled_tokens\n" +
"local allowed_num = 0\n" +
"if allowed then\n" +
" new_tokens = filled_tokens - requested\n" +
" allowed_num = 1\n" +
"end\n" +
"if ttl > 0 then\n" +
" redis.call('setex', tokens_key, ttl, new_tokens)\n" +
" redis.call('setex', timestamp_key, ttl, now)\n" +
"end\n" +
"return allowed_num\n";
}
完事儿,这个令牌桶实现完全是参考的spring cloud gateway,食用起来也比较放心。
简单描述一下这个工具类使用的效果吧。
isAllowed("testId1",1,60,1);
上面这个描述代表:令牌桶testId1。每分钟可通过访问60次。
当然这是理想情况。极限情况的话,应该是可以访问120次的。极限场景如下
- testId1令牌桶未被使用的时间>=60秒
- testId1令牌桶在某一刻开始被使用
- 令牌被消耗的同时也在被填充
- 在最初被使用的60秒内,令牌桶初始有60个令牌,使用期限有填充了60个
我个人理解,令牌桶主要是为了保证使用速率。对于上面这个场景,到底算不算bug。看每个人的使用情况。不过我已经对上述情况做了解决,只需要如下的几点小改动。
- 传入周期时间内最大流量数(周期时间:桶容量60,填充速率1/s,那么周期时间=60s)
- 获取key的方法增加一个记录周期时间内数量的key
- 修改lua脚本
改动如下:
/**
* 判断是否允许访问
*@param id 这次获取令牌桶的id
*@param rate 每秒填充速率
*@param capacity 令牌桶最大容量
*@param tokens 每次访问消耗几个令牌
*@param maxCount 周期时间内最大访问量
*@return true 允许访问 false 不允许访问
*/
public boolean isAllowed(String id,int rate,int capacity,int tokens,int maxCount){
RedisScript<Long> redisScript = new DefaultRedisScript<>(SCRIPT,Long.class);
Object result = redisTemplate.execute(redisScript,
getKey(id),rate, capacity,
Instant.now().getEpochSecond(), tokens,maxCount);
return SUCCESS_FLAG.equals(result);
}
private List<String> getKey(String id){
String prefix = "limiter:"+id;
String tokenKey = prefix + ":tokens";
String timestampKey = prefix + ":timestamp";
String countKey = prefix + ":count";
return Arrays.asList(tokenKey, timestampKey,countKey);
}
private static final String SCRIPT = "local tokens_key = KEYS[1]\n" +
"local timestamp_key = KEYS[2]\n" +
"local count_key = KEYS[3]\n" +
"local rate = tonumber(ARGV[1])\n" +
"local capacity = tonumber(ARGV[2])\n" +
"local now = tonumber(ARGV[3])\n" +
"local requested = tonumber(ARGV[4])\n" +
"local min_max = tonumber(ARGV[5])\n" +
"local fill_time = capacity/rate\n" +
"local ttl = math.floor(fill_time*2)\n" +
"local has_count = tonumber(redis.call('get', count_key))\n" +
"if has_count == nil then\n" +
" has_count = 0\n" +
"end\n" +
"if has_count >= min_max then\n" +
"return 0\n" +
"end\n" +
"local last_tokens = tonumber(redis.call('get', tokens_key))\n" +
"if last_tokens == nil then\n" +
" last_tokens = capacity\n" +
"end\n" +
"local last_refreshed = tonumber(redis.call('get', timestamp_key))\n" +
"if last_refreshed == nil then\n" +
" last_refreshed = 0\n" +
"end\n" +
"local diff_time = math.max(0, now-last_refreshed)\n" +
"local filled_tokens = math.min(capacity, last_tokens+(diff_time*rate))\n" +
"local allowed = filled_tokens >= requested\n" +
"local new_tokens = filled_tokens\n" +
"local allowed_num = 0\n" +
"if allowed then\n" +
" new_tokens = filled_tokens - requested\n" +
" allowed_num = 1\n" +
"end\n" +
"if ttl > 0 then\n" +
" redis.call('setex', tokens_key, ttl, new_tokens)\n" +
" redis.call('setex', timestamp_key, ttl, now)\n" +
"end\n" +
"local count_ttl = tonumber(redis.call('ttl',count_key))\n" +
"if count_ttl < 0 then\n" +
" count_ttl = fill_time\n" +
"end\n" +
"redis.call('setex', count_key,count_ttl , has_count+1)\n" +
"return allowed_num\n";
}
这样的改动做到了保持访问速率和吞吐量的可控,但是有没有必要这样就看自己的需求了。
最后想说一句,优秀的开源项目真的是宝藏。学会使用它,然后站在巨人的肩膀上前进