分布式限流ratelimiter

  • Post author:
  • Post category:其他


基于redis+ratelimiter分布式限流

1、导入依赖

<!-- Spring Data Redis starter: provides RedisConnectionFactory / RedisTemplate used by the limiter to run its Lua counter script -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

<!-- Spring AOP starter: needed for the @Aspect-based RateLimiterAop to intercept @RateLimiter methods -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-aop</artifactId>
</dependency>

2、定义注解RateLimiter

/**
 * Declares a distributed, Redis-backed rate limit on a method.
 *
 * <p>Reference: https://www.cnblogs.com/lywJ/p/10715367.html
 *
 * <p>Examples:
 * <ul>
 *   <li>Limit by client IP, at most 5 calls per second per IP:
 *       {@code @RateLimiter(period = 1, ipLimitCount = 5, limitMode = LimiterMode.IP)}</li>
 *   <li>Limit by key, at most 5 calls per second in total:
 *       {@code @RateLimiter(period = 1, keyLimitCount = 5, limitMode = LimiterMode.KEY)}</li>
 * </ul>
 *
 * <p>NOTE(review): {@code TYPE} is an allowed target, but the accompanying aspect reads the
 * annotation from the intercepted method, so class-level use may have no effect — confirm.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
@Documented
public @interface RateLimiter {

    /**
     * First segment of every Redis counter key.
     *
     * @return the key prefix
     */
    String prefix() default "ratelimiter_";

    /**
     * Module / project segment of the counter key.
     *
     * @return the project name segment
     */
    String project() default "project_";

    /**
     * Resource segment of the counter key. Methods that keep this default share the same counter.
     *
     * @return the resource key segment
     */
    String key() default "api_";

    /**
     * Length of the counting window. The aspect passes it to Redis {@code EXPIRE}, so it is in seconds.
     *
     * @return the window length
     */
    int period() default 30;

    /**
     * Maximum number of calls per window for the shared key counter
     * ({@link LimiterMode#KEY} and {@link LimiterMode#COMBINATION}).
     *
     * @return the key limit
     */
    int keyLimitCount() default 900;

    /**
     * Maximum number of calls per window for each client IP
     * ({@link LimiterMode#IP} and {@link LimiterMode#COMBINATION}).
     *
     * @return the per-IP limit
     */
    int ipLimitCount() default 15;

    /**
     * Which counter(s) to enforce.
     *
     * @return the limiting mode
     */
    LimiterMode limitMode() default LimiterMode.IP;
}

/**
 * Selects which counter(s) a {@code @RateLimiter} check uses.
 */
public enum LimiterMode {
    /** One counter per client IP address. */
    IP,
    /** One counter for the configured key, shared by all callers. */
    KEY,
    /** Both counters; a call is rejected if either limit is exceeded. */
    COMBINATION;
}

3、定义AOP

/**
 * Aspect that enforces {@code @RateLimiter} on annotated methods using fixed-window counters in Redis.
 *
 * <p>Each counter is a Redis key that a Lua script atomically increments. The script gives the key
 * a TTL of {@code period} seconds when it creates it, so the count resets when the window expires.
 * If a counter exceeds its configured limit, the advice throws
 * {@code BusinessException(ResponseCodeEnum.API_ACCESS_FREQUENTLY)} before the target method runs.
 */
@Slf4j
@Aspect
public class RateLimiterAop {

    /**
     * Counter script. It runs INCR on KEYS[1]. If that created the key (count == 1), it sets
     * EXPIRE to ARGV[1] seconds. It returns the new count. Running both commands inside one script
     * keeps them atomic. The original script also issued a GET whose result was never used.
     */
    private static final String LUA_SCRIPT =
            "local c = redis.call('incr', KEYS[1])\n"
            + "if tonumber(c) == 1 then\n"
            + "    redis.call('expire', KEYS[1], ARGV[1])\n"
            + "end\n"
            + "return c";

    /**
     * Built once and reused instead of per request. {@code Long.class} maps to Redis' integer
     * reply, so the result is returned directly rather than passed through the template's JSON
     * value serializer.
     */
    private static final RedisScript<Long> COUNTER_SCRIPT = new DefaultRedisScript<>(LUA_SCRIPT, Long.class);

    private static final String UNKNOWN = "unknown";

    private final RedisTemplate<String, Serializable> intRedisTemplate;

    public RateLimiterAop(RedisTemplate<String, Serializable> intRedisTemplate) {
        this.intRedisTemplate = intRedisTemplate;
    }

    /**
     * Increments the counter(s) selected by the method's {@code @RateLimiter}. Throws
     * {@code BusinessException} when a limit is exceeded.
     *
     * @param joinPoint the intercepted method call
     */
    @Before("@annotation(com.demo.lixboot.ratelimit.RateLimiter)")
    public void interceptor(JoinPoint joinPoint) {
        RateLimiter limitAnnotation = resolveAnnotation(joinPoint);
        if (limitAnnotation == null) {
            log.warn("@RateLimiter not found on {}, skipping rate limit check", joinPoint.getSignature());
            return;
        }
        String project = limitAnnotation.project();
        String limitKey = limitAnnotation.key();
        int limitPeriod = limitAnnotation.period();
        int keyLimitCount = limitAnnotation.keyLimitCount();
        int ipLimitCount = limitAnnotation.ipLimitCount();
        // Counter shared by every caller of this resource; per-IP counters append the client IP.
        String resourceKey = limitAnnotation.prefix() + project + limitKey;

        boolean limitFlag;
        switch (limitAnnotation.limitMode()) {
            case KEY: {
                Long keyCount = incrementCounter(resourceKey, limitPeriod);
                log.debug("根据key限流: project={} , key = {}, count = {}", project, limitKey, keyCount);
                limitFlag = exceeds(keyCount, keyLimitCount);
                if (limitFlag) {
                    log.info("keyCount:{},keyLimitCount:{}", keyCount, keyLimitCount);
                }
                break;
            }
            case COMBINATION: {
                String ip = getIpAddress();
                // Both counters are always incremented so every request is counted in both windows.
                Long keyCount = incrementCounter(resourceKey, limitPeriod);
                Long ipCount = incrementCounter(resourceKey + ip, limitPeriod);
                log.debug("key和ip结合限流: project={} , key = {}, ip = {}, keyCount = {}, ipCount = {}",
                        project, limitKey, ip, keyCount, ipCount);
                limitFlag = exceeds(ipCount, ipLimitCount) || exceeds(keyCount, keyLimitCount);
                break;
            }
            default: {
                // LimiterMode.IP
                String ip = getIpAddress();
                Long ipCount = incrementCounter(resourceKey + ip, limitPeriod);
                log.debug("根据ip限流: project={} , key = {}, ip = {}, count = {}", project, limitKey, ip, ipCount);
                limitFlag = exceeds(ipCount, ipLimitCount);
                if (limitFlag) {
                    log.info("ipCount:{},ipLimitCount:{}", ipCount, ipLimitCount);
                }
                break;
            }
        }
        if (limitFlag) {
            throw new BusinessException(ResponseCodeEnum.API_ACCESS_FREQUENTLY);
        }
    }

    /**
     * Finds the {@code @RateLimiter} for the intercepted call. With JDK dynamic proxies the
     * signature exposes the interface method, which does not carry an annotation declared on the
     * implementation. In that case this falls back to the target class's method.
     */
    private static RateLimiter resolveAnnotation(JoinPoint joinPoint) {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        RateLimiter annotation = method.getAnnotation(RateLimiter.class);
        Object target = joinPoint.getTarget();
        if (annotation == null && target != null) {
            try {
                annotation = target.getClass()
                        .getMethod(method.getName(), method.getParameterTypes())
                        .getAnnotation(RateLimiter.class);
            } catch (NoSuchMethodException ignored) {
                // Not a public method on the target; the caller logs and skips the check.
                annotation = null;
            }
        }
        return annotation;
    }

    /**
     * Atomically increments {@code key} and starts its expiry window on first use.
     *
     * @return the new count, or {@code null} if the template returned none (e.g. inside a pipeline)
     */
    private Long incrementCounter(String key, int periodSeconds) {
        List<String> keys = Arrays.asList(key);
        return intRedisTemplate.execute(COUNTER_SCRIPT, keys, periodSeconds);
    }

    /** A missing count is treated as "not limited", which matches the original null checks. */
    private static boolean exceeds(Long count, int limit) {
        return count != null && count > limit;
    }

    /**
     * Resolves the client IP of the current HTTP request.
     *
     * <p>NOTE(review): these headers are client-supplied unless a trusted proxy overwrites them.
     * A caller talking to the app directly can forge them to evade IP-based limiting.
     *
     * @throws IllegalStateException if there is no current request (e.g. a non-web thread)
     */
    private String getIpAddress() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("IP-based rate limiting requires an active HTTP request");
        }
        HttpServletRequest request = attributes.getRequest();
        String ip = firstForwardedAddress(request.getHeader("x-forwarded-for"));
        if (isMissing(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isMissing(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isMissing(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /**
     * X-Forwarded-For may hold "client, proxy1, proxy2". This returns the left-most, trimmed entry
     * (or null), so one client maps to one counter regardless of the proxy chain.
     */
    private static String firstForwardedAddress(String header) {
        if (header == null) {
            return null;
        }
        int comma = header.indexOf(',');
        return (comma < 0 ? header : header.substring(0, comma)).trim();
    }

    /** True when a header value is absent, empty, or the literal "unknown". */
    private static boolean isMissing(String ip) {
        return ip == null || ip.isEmpty() || UNKNOWN.equalsIgnoreCase(ip);
    }
}

4、定义RedisTemplate

默认情况下 spring-boot-starter-data-redis 为我们提供了 StringRedisTemplate,但它满足不了其它类型的转换,所以还是得自己去定义其它类型的模板。

/**
 * Registers the Redis template and the aspect that together implement {@code @RateLimiter}.
 *
 * <p>NOTE(review): {@code @AutoConfigureAfter} only takes effect when this class is loaded as an
 * auto-configuration (e.g. listed in spring.factories), not when it is picked up by component scanning.
 */
@Configuration
@AutoConfigureAfter(RedisAutoConfiguration.class)
public class RateLimiterAutoConfiguration {

    /**
     * Template with String keys and Jackson-JSON values, exposed as bean "intRedisTemplate".
     * It is skipped if the application already defines a bean with that name.
     *
     * @param redisConnectionFactory connection factory from the Redis auto-configuration
     * @return the configured template
     */
    @Bean
    @ConditionalOnMissingBean(name = "intRedisTemplate")
    public RedisTemplate<String, Serializable> intRedisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        return redisTemplate;
    }

    /**
     * The aspect bean that enforces {@code @RateLimiter} using the given template.
     *
     * @param intRedisTemplate template used to run the counter script
     * @return the aspect
     */
    @Bean
    public RateLimiterAop rateLimiterAop(RedisTemplate<String, Serializable> intRedisTemplate) {
        return new RateLimiterAop(intRedisTemplate);
    }
}

5、controller使用

/**
 * Demo endpoint: each client IP may call it at most 2 times per 5-second window.
 * Further calls are rejected by the rate-limiting aspect before this body runs.
 *
 * @return a success response
 */
@GetMapping(value = "/testRateLimit")
@ApiOperation(value = "测试RateLimit", notes = "测试RateLimit")
@RateLimiter(period = 5, ipLimitCount = 2, limitMode = LimiterMode.IP)
public BaseResponse testRateLimit() {
    String message = "测试ratelimit方法,该方法同一个IP,5秒内最多访问2次";
    System.out.println(message);
    return BaseResponse.success(ResponseCodeEnum.OK);
}



版权声明:本文为lixianrich原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。