Redis实现分布式锁来防止重复提交

  • Post author:
  • Post category:其他


前言:

在系统中,有些接口如果重复提交,可能会造成脏数据或者其他的严重的问题,所以我们一般会对与数据库有交互的接口进行重复处理。我们首先会想到在前端做一层控制。当前端触发操作时,或弹出确认界面,或disable入口并倒计时等等,但是这并不能彻底限制,因此我们这里使用Redis来对某些操作加锁

场景:

场景一:在网络延迟的情况下让用户有时间点击多次submit按钮导致表单重复提交

场景二:表单提交后用户点击【刷新】按钮导致表单重复提交

场景三:用户提交表单后,点击浏览器的【后退】按钮回退到表单页面后进行再次提交

应用:这里我们用到Redis的SETNX key value命令,对于该命令的解释是

将 key 的值设为 value ,当且仅当 key 不存在。

若给定的 key 已经存在,则 SETNX 不做任何动作。

SETNX 是『SET if Not eXists』(如果不存在,则 SET)的简写。

思路(如果不想每一个请求都单独处理,以下行为,可以在自定义拦截器里面统一处理,1,2步在preHandle中处理,第3步再afterCompletion中处理):

把参数组装好,进行MD5加密作为key,这样如果重复提交的话,这个请求生成的key就是一样的

在请求之前,改action先去拿锁,拿到锁再继续进行下去

请求结束之后,必须释放锁,虽然我们已经对锁做了过期处理,防止死锁,但是不建议只靠这样的操作解锁

代码实现:


@Component
public class RedisLock {
 
    public static final int LOCK_EXPIRE = 3000; // ms
 
    @Autowired
    private StringRedisTemplate redisTemplate;
 
 
    /**
     *  分布式锁
     *
     * @param key key值
     * @return 是否获取到
     */
    public boolean lock(String key) {
        String lock = key;
        try {
            return (Boolean) redisTemplate.execute((RedisCallback) connection -> {
                long expireAt = System.currentTimeMillis() + LOCK_EXPIRE;
                Boolean acquire = connection.setNX(lock.getBytes(), String.valueOf(expireAt).getBytes());
                if (acquire) {
                    return true;
                } else {
                    //判断该key上的值是否过期了
                    byte[] value = connection.get(lock.getBytes());
                    if (Objects.nonNull(value) && value.length > 0) {
                        long expireTime = Long.parseLong(new String(value));
                        if (expireTime < System.currentTimeMillis()) {
                            // 如果锁已经过期
                            byte[] oldValue = connection.getSet(lock.getBytes(), String.valueOf(System.currentTimeMillis() + LOCK_EXPIRE).getBytes());
                            // 防止死锁
                            return Long.parseLong(new String(oldValue)) < System.currentTimeMillis();
                        }
                    }
                }
                return false;
            });
        } finally {
            RedisConnectionUtils.unbindConnection(redisTemplate.getConnectionFactory());
        }
    }
 
 
    @Autowired
    private RedisService redisService;
 
    /**
     * 删除锁
     *
     * @param key
     */
    public void delete(String key) {
        try {
            redisTemplate.delete(key);
        } finally {
            RedisConnectionUtils.unbindConnection(redisTemplate.getConnectionFactory());
        }
    }
 
}
测试Controller,如果拿不到锁,则等待0.5秒后继续拿,重复5次

@RestController
public class RedisLockTestController {
 
    @Autowired
    private RedisLock redisLock;
 
    @PostMapping("createOrder")
    public String createOrder(HttpServletRequest request){
        String lockKey = MapUtil.getRedisKeyByParam(request.getParameterMap());
        if (redisLock.lock(lockKey)){
            //处理逻辑
            redisLock.delete(lockKey);
            return "success";
        }else {
            // 设置失败次数计数器, 当到达5次时, 返回失败
            int failCount = 1;
            while(failCount <= 5){
                // 等待100ms重试
                try {
                    Thread.sleep(500);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                if (redisLock.lock(lockKey)){
                    // 执行逻辑操作
                    //处理逻辑
                    redisLock.delete(lockKey);
                    return "success";
                }else{
                    failCount ++;
                }
            }
            return "请勿重复提交请求";
        }
 
    }
 
}
请求参数工具类

public class MapUtil {
 
    public static String getRedisKeyByParam(Map<String, String[]> requestParams) {
        //除去数组中的空值
        Map<String, String> sPara = paraFilter(toVerifyMap(requestParams,false));
        //把数组所有元素,按照“参数=参数值”的模式用“&”字符拼接成字符串
        String prestr = createLinkString(sPara);
        //生成签名结果
        String mysign = DigestUtils.md5Hex(getContentBytes(prestr, "UTF-8"));
        return mysign;
    }
 
    /**
     * 除去数组中的空值
     * @param sArray 参数组
     * @return 去掉空值后新的参数组
     */
    public static Map<String, String> paraFilter(Map<String, String> sArray) {
        Map<String, String> result = new HashMap<>();
        if (sArray == null || sArray.size() <= 0) {
            return result;
        }
        for (String key : sArray.keySet()) {
            String value = sArray.get(key);
            if (value == null || value.equals("")) {
                continue;
            }
            result.put(key, value);
        }
        return result;
    }
 
    /**
     * 把数组所有元素排序,并按照“参数=参数值”的模式用“&”字符拼接成字符串
     * @param params 需要排序并参与字符拼接的参数组
     * @return 拼接后字符串
     */
    public static String createLinkString(Map<String, String> params) {
        List<String> keys = new ArrayList<>(params.keySet());
        Collections.sort(keys);
        String prestr = "";
        for (int i = 0; i < keys.size(); i++) {
            String key = keys.get(i);
            String value = params.get(key);
            if (i == keys.size() - 1) {//拼接时,不包括最后一个&字符
                prestr = prestr + key + "=" + value;
            } else {
                prestr = prestr + key + "=" + value + "&";
            }
        }
        return prestr;
    }
 
    private static byte[] getContentBytes(String content, String charset) {
        if (charset == null || "".equals(charset)) {
            return content.getBytes();
        }
        try {
            return content.getBytes(charset);
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException("MD5签名过程中出现错误,指定的编码集不对,您目前指定的编码集是:" + charset);
        }
    }
 
 
    /**
     * 请求参数Map转换验证Map
     * @param requestParams 请求参数Map
     * @param charset 是否要转utf8编码
     * @return
     * @throws UnsupportedEncodingException
     */
    public static Map<String,String> toVerifyMap(Map<String, String[]> requestParams, boolean charset) {
        Map<String,String> params = new HashMap<>();
        for (Iterator iter = requestParams.keySet().iterator(); iter.hasNext();) {
            String name = (String) iter.next();
            String[] values = requestParams.get(name);
            String valueStr = "";
            for (int i = 0; i < values.length; i++) {
                valueStr = (i == values.length - 1) ? valueStr + values[i] : valueStr + values[i] + ",";
            }
            //乱码解决,这段代码在出现乱码时使用。如果mysign和sign不相等也可以使用这段代码转化
            if(charset)
                valueStr = getContentString(valueStr, "UTF-8");
            params.put(name, valueStr);
        }
        return params;
    }
 
    /**
     * 编码转换
     * @param content
     * @param charset
     * @return
     */
    private static String getContentString(String content, String charset) {
        if (charset == null || "".equals(charset)) {
            return new String(content.getBytes());
        }
        try {
            return new String(content.getBytes("ISO-8859-1"), charset);
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException("指定的编码集不对,您目前指定的编码集是:" + charset);
        }
    }
}