自定义注解 + Redis + Lua脚本实现滑动窗口限流

  • Post author:
  • Post category:其他


核心思想:通过redis的zset来实现窗口滑动,从而达到限流的作用。

yml核心配置如下:

spring:
  application:
    name: demo
  redis:
    host: 127.0.0.1
    port: 6379
    database: 1
    
server:
  port: 8181
  servlet:
    context-path: /demo

核心pom文件如下:

	    <!-- redis相关  -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
	    <!-- aop相关 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
       <!-- lombak相关-->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>



1、自定义注解

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {

    /**
     * 资源描述(日志打印使用无实际意义)
     */
    String name() default "";

    /**
     * 资源前缀
     */
    String prefix() default "";

    /**
     * 资源的key
     */
    String key() default "";

    /**
     * 时间窗口 单位 s
     */
    int period();

    /**
     * 限制访问次数
     */
    int count();
}



2、自定义统一返回 + 自定义异常 + 全局异常拦截

  • 统一结构返回
@Data
public class BaseResponse<T>  {

    private int code;

    private T data;

    private String message;

    public BaseResponse() {
        this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
        this.message = HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase();
    }

    public BaseResponse(int code, String message, T data) {
        this.code = code;
        this.message = message;
        this.data = data;
    }

    public static <T> BaseResponse<T> with(int code, String msg, T data) {
        return new BaseResponse<>(code, msg, data);
    }

    public static <T> BaseResponse<T> success(T data) {
        return with(HttpStatus.OK.value(), "success", data);
    }
  
}
  • 自定义异常
public class LimitException extends RuntimeException{

    public LimitException(String msg) {
        super(msg);
    }
}
  • 全局异常拦截器
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler(LimitException.class)
    public BaseResponse<?> limitException(LimitException limitException) {
        log.error("【限流发生异常】");
        return BaseResponse.with(HttpStatus.INTERNAL_SERVER_ERROR.value(), limitException.getMessage());
    }

}



3、自定义切面 + lua脚本

  • 自定义切面处理限流
@Aspect
@Component
public class LimitAspect {

    @Resource
    private RedisTemplate<String, Object> redisTemplate;

    private DefaultRedisScript<Long> redisScript;

    @PostConstruct
    public void init() {
	    // 仅加载一次lua脚本资源
        redisScript = new DefaultRedisScript<>();
        redisScript.setLocation(new ClassPathResource("lua/limit.lua"));
        redisScript.setResultType(Long.class);
    }

    @Before("@annotation(com.learn.demo.annotation.Limit)")
    public void handle() {
		// 先去获取资源信息
        Method method = resolveMethod(point);
        Limit limit = method.getAnnotation(Limit.class);
        if (limit != null) {
            // 去执行脚本 然后判断是否超过限流的时间了
            String key = limit.prefix() + "_" + limit.key();
            log.info("{}访问了系统{}", LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")), limit.name());
            Long number = redisTemplate.execute(redisScript, Collections.singletonList(key), limit.period(), System.currentTimeMillis(), limit.count());
            if (number != null && number.intValue() >= limit.count()) {
                throw new LimitException("超过接口访问限制,请稍后重试");
            }
        }
    }

    private Method resolveMethod(JoinPoint point) {
	    // 获取方法签名
        MethodSignature signature = (MethodSignature) point.getSignature();
        Class<?> targetClass = point.getTarget().getClass();
        Method method = getDeclaredMethod(targetClass, signature.getName(), signature.getMethod().getParameterTypes());
        if (method == null) {
            throw new IllegalStateException("无法解析目标方法: " + signature.getMethod().getName());
        }
        return method;
    }

    private Method getDeclaredMethod(Class<?> clazz, String name, Class<?>... parameterTypes) {
        try {
            return clazz.getDeclaredMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            Class<?> superClass = clazz.getSuperclass();
            if (superClass != null) {
                return getDeclaredMethod(superClass, name, parameterTypes);
            }
        }
        return null;
    }
}
  • lua脚本
--获取KEY
local key = KEYS[1];
--获取ARGV内的参数
-- 缓存时间
local expire = tonumber(ARGV[1]);

-- 当前时间
local currentMs = tonumber(ARGV[2]);

-- 最大次数
local count = tonumber(ARGV[3]);

--窗口开始时间
local windowStartMs = currentMs - expire * 1000;

--获取key的次数
local current = redis.call('zcount', key, windowStartMs, currentMs);

--如果key的次数存在且大于预设值直接返回当前key的次数
if current and tonumber(current) >= count then
    return tonumber(current);
end

-- 清除上一个窗口的数据
redis.call("ZREMRANGEBYSCORE", key, 0, windowStartMs);

-- 添加当前成员
redis.call("zadd", key, tostring(currentMs), currentMs);
redis.call("expire", key, expire);

--返回key的次数
return tonumber(current)



4、应用和测试

@RestController
@RequestMapping("/test")
public class TestController {

    @GetMapping("/limit")
    @Limit(name = "测试限流", prefix = "test", key = "limit", period = 10, count = 5)
    public BaseResponse<?> limit() {
        return BaseResponse.success("success");
    }
}

可以看到 在第6次访问的时候 发生了异常

idea控制台打印

并且通过postman测试得到的返回也是限流的返回

postman测试返回结果



5、踩坑记录

在使用过程中,发现执行lua脚本时一直获取不到ARGV的值,报错信息为:@user_script:20: user_script:20: attempt to perform arithmetic on local ‘expire’ (a nil value)。这个是因为lua脚本中的值为空,并且还使用该值。

解决方案:指定RedisTemplate的序列化

@Configuration
public class RedisConfig {

    @Bean(name = "redisTemplate")
    @ConditionalOnClass(RedisOperations.class)
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(factory);

        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.activateDefaultTyping(mapper.getPolymorphicTypeValidator(), ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(mapper);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // 指定 key 的序列化格式
        template.setKeySerializer(stringRedisSerializer);
        // 指定 hash 的 key 的序列化格式
        template.setHashKeySerializer(stringRedisSerializer);
        // 指定 value 的序列化格式
        template.setValueSerializer(jackson2JsonRedisSerializer);
        // 指定 hash 的 value 的序列化格式
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }
}

PS:这个只是最基础版的窗口窗口限流,无用户、ip等限制,支持的窗口时间也为秒,看官们可以通过根据修改Limit注解来实现功能的扩展。如果写的有什么问题和错误欢迎大家指出,如果有什么好的建议也欢迎大家提出,一同进步!!!



版权声明:本文为weixin_45472766原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。