写在前面 最近发现线上发送验证码的接口被莫名刷了,1天就把上周刚充值的1万条用完了,我内心其实是崩溃的,于是决定对该接口实现限流,这里选择使用拦截器配合Redis来实现。
实战 项目初始化 第一步,新建一个名为limit-redis
的SpringBoot项目,然后在POM文件中添加如下依赖:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-data-redis</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-aop</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-test</artifactId> <scope>test</scope> </dependency>
第二步,修改application.yml
配置文件信息:
1 2 3 4 5 6 spring: redis: host: 127.0.0.1 port: 6379 password: 1234 database: 1
重写RedisTempplate的序列化 第三步,重写RedisTempplate的序列化逻辑。一般来说我们更倾向于在SpringBoot中使用 Spring Data Redis来操作Redis,但是随着而来的则是它的序列化问题,默认使用的是JdkSerializationRedisSerializer
,采用的是二进制方式,且会自动的给存入的key和value添加一些前缀,导致实际情况与开发者预想的不一致。针对这种情况我们可以使用Jackson2JsonRedisSerializer
这一序列化方式,不建议使用StringRedisTemplate
来替代RedisTemplate
,因为它提供的数据类型和操作都有限,无法满足日常需要。
定义一个名为RedisConfig的类,该类用于重写RedisTempplate的序列化逻辑,使用Jackson2JsonRedisSerializer
取代默认的JdkSerializationRedisSerializer
,这样利于后续开发和使用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 @Configuration public class RedisConfig { @Bean public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) { RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>(); redisTemplate.setConnectionFactory(connectionFactory); // 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化) Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class); ObjectMapper mapper = new ObjectMapper(); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY); mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); jackson2JsonRedisSerializer.setObjectMapper(mapper); redisTemplate.setKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer); return redisTemplate; } }
自定义限流注解 第四步,自定义限流注解。我们要求这个限流分为两种,一种是针对某一个接口的全局限流,另一种是针对IP地址的限流: (1)针对当前某一接口的全局限流。举个例子,/test
接口可以在1分钟内访问60次; (2)针对IP地址的限流。举个例子,192.168.56.1
这一IP地址可以在1分钟内访问60次; 针对上述情况,可以创建一个枚举类LimitType
,用于记录限流的类型:
1 2 3 4 5 6 7 8 9 10 11 12 13 /** * 限流类型 */ public enum LimitType { /** * 默认类型:全局限流 */ DEFAULT, /** * 根据请求IP地址限流 */ IP; }
接着自定义一个限流注解RateLimiter
,里面设置限流key,注意这个key仅仅是一个前缀,后续我们会拼接其他的变量组成完整的key,进而存入Redis中。完整key的格式如下:
1 rate_limit:IP地址-注解所添加方法所在的类的名称-注解所添加方法的名称
举个例子,如下所示:
1 rate_limit:192.168.30.10-com.melody.limitredis.controller.RateLimiterController-test
time则是限流的时间,单位为秒;count则是限流的次数,默认100次;limitType则是限流的类型,默认为全局限流:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 @Target({ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface RateLimiter { /** * 限流Key */ String key()default "rate_limit:"; /** * 限流时间,单位秒 */ int time()default 60; /** * 限流次数,默认100次 */ int count()default 100; /** * 限流类型,默认全局限流 */ LimitType limitType()default LimitType.DEFAULT; }
这样后续我们需要对某个接口进行限流,只需在该接口上添加@RateLimiter
注解,并设置上述对应的参数即可,总的来说还是比较简单的。
第五步,编写Lua脚本。我们知道Redis的单个操作是具备原子性的,而多个操作就无法保证,但是我们可以借助于Lua脚本来实现。一般来说调用Lua脚本有两种方式: (1)在Redis服务端定义Lua脚本,然后计算出一个hash值,接着在Java代码中通过这个hash值来确定需要执行的Lua脚本; (2)在Java代码中将Lua脚本定义好,然后将其发送到Redis服务端,进而去执行。 笔者比较倾向于第二种方式,因此可以先在客户端定义好Lua脚本,然后通过Spring Data Redis提供的redisTemplate.execute()
方法,传入脚本实例和对应的参数就可以执行对应的Lua脚本了。
在项目的resources目录下新建一个名为lua的目录,并在该lua目录下新建一个名为limit.lua
的脚本:
1 2 3 4 5 6 7 8 9 10 11 12 local key = KEYS[1] local count = tonumber(ARGV[1]) local time = tonumber(ARGV[2]) local current = redis.call('get', key) if current and tonumber(current) > count then return tonumber(current) end current = redis.call('incr', key) if tonumber(current) == 1 then redis.call('expire', key, time) end return tonumber(current)
这个KEYS和ARGV就是用户后续在执行该脚本时传入来的参数,tonumber方法用于将字符串转成数字,redis.call()
方法通过传入方法名称和参数,进而实现调用不同方法的逻辑。上述脚本所要表达的含义如下: (a)获取用户传递进来的key,限流次数count和限流时间time; (b)调用get(key)
方法来获取当前key的值,即当前接口在当前时间内已经访问的次数; (c)如果该接口是第一次访问,那么(b)得到的结果将是nil,否则得到的是一个数字。如果是数字那么我们将判断它和限流次数count的大小,如果该数字大于count,则说明已经超过最大访问次数,需要限流了,可以直接返回当前请求的次数; (d)如果得到的结果是nil,则说明是第一次访问该接口,那么给当前的key进行自增加1,并设置一个过期时间; (e)最后只需将自增后的值返回即可。
第六步,在RedisConfig类中定义一个Bean来加载这个Lua脚本:
1 2 3 4 5 6 7 @Bean public DefaultRedisScript<Long> limitScript() { DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua"))); redisScript.setResultType(Long.class); return redisScript; }
注意Lua脚本的存放位置和名称需要与开发者实际存放的位置相匹配,否则后续无法成功调用该脚本。
自定义切面类 其实拦截请求可以有不同的实现思路,可以用拦截器或者AOP,考虑到拦截器也是AOP思想的体现,因此这里就直接使用AOP来实现。
第七步,自定义切面类RateLimiterAspect
,该类用于拦截所有添加了@RateLimiter
注解的方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 @Component @Aspect public class RateLimiterAspect { private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class); @Autowired private RedisTemplate<Object,Object> redisTemplate; @Autowired private RedisScript<Long> limitScript; @Before("@annotation(rateLimiter)") public void before(JoinPoint point, RateLimiter rateLimiter){ int time = rateLimiter.time(); int count = rateLimiter.count(); String combineKey = getCombineKey(rateLimiter,point); List<Object> keys = Collections.singletonList(combineKey); try{ Long number = redisTemplate.execute(limitScript, keys, count, time); if(number == null || number.intValue() > count){ throw new BizException("请求过于频繁,请稍后重试",500); } logger.info("当前请求次数'{}',限定次数'{}'", number.intValue(), count); }catch (BizException e){ throw e; }catch (Exception e){ throw new RuntimeException("服务器限流异常,请稍候再试"); } } private String getCombineKey(RateLimiter rateLimiter, JoinPoint point) { StringBuffer stringBuffer = new StringBuffer(rateLimiter.key()); //IP限制 if(rateLimiter.limitType() == LimitType.IP){ stringBuffer.append(IpUtils.getRequestIp(((ServletRequestAttributes)RequestContextHolder.getRequestAttributes()).getRequest())).append("-"); } MethodSignature signature = (MethodSignature)point.getSignature(); Method method = signature.getMethod(); Class<?> targetClass = method.getDeclaringClass(); stringBuffer.append(targetClass.getName()).append("-").append(method.getName()); logger.info("{}",stringBuffer.toString()); return stringBuffer.toString(); } }
此处我们使用了前置通知,并在前置通知中对请求进行了处理,逻辑如下: (1)获取注解中的time、count 、key和limitType这四个属性; (2)调用getCombineKey()
方法来获取一个完整的限流Key,先判断是否包含IP,如果有IP就将IP添加到里面,否则就不添加,最终完整的限流Key格式如下:
1 rate_limit:IP地址-注解所添加方法所在的类的名称-注解所添加方法的名称
(3)将限流Key放入一个集合中,因为此处我们调用的redisTemplate.execute()
的完整方法如下:
1 <T> T execute(RedisScript<T> script, List<K> keys, Object... args);
里面的keys就是一个集合,所以此处就生成了一个单例模式的集合。其实这个也对应于之前我们在Lua脚本中定义的参数,即上述方法中的第二个参数keys就是脚本中的KEYS,可变长度就是脚本中的ARGV,注意值下标从1开始计数:
1 2 3 local key = KEYS[1] local count = tonumber(ARGV[1]) local time = tonumber(ARGV[2])
(4)之后就是将执行Lua脚本之后的返回值与限流值count进行比较,如果值大于count则说明超出最大访问次数,应当限流,此时抛出一个异常,这里我们自定义了一个异常BizException。
第八步,定义一个用于获取用户IP地址信息的工具类IpUtils
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 public class IpUtils { /** * 获取请求真实IP地址 */ public static String getRequestIp(HttpServletRequest request) { //通过HTTP代理服务器转发时添加 String ipAddress = request.getHeader("x-forwarded-for"); if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { ipAddress = request.getHeader("Proxy-Client-IP"); } if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { ipAddress = request.getHeader("WL-Proxy-Client-IP"); } if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { ipAddress = request.getRemoteAddr(); // 从本地访问时根据网卡取本机配置的IP if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) { InetAddress inetAddress = null; try { inetAddress = InetAddress.getLocalHost(); } catch (UnknownHostException e) { e.printStackTrace(); } ipAddress = inetAddress.getHostAddress(); } } // 通过多个代理转发的情况,第一个IP为客户端真实IP,多个IP会按照','分割 if (ipAddress != null && ipAddress.length() > 15) { if (ipAddress.indexOf(",") > 0) { ipAddress = ipAddress.substring(0, ipAddress.indexOf(",")); } } return ipAddress; } }
当然了,这里只是获取了用户的IP地址,之后可以向一些网站发起请求,得到这些IP的归属地,这里笔者就不展示了。
全局异常捕获 第九步,定义一个异常BizException:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 public class BizException extends RuntimeException{ private int code; public int getCode() { return code; } public void setCode(int status) { this.code = code; } public BizException(String message){ super(message); } public BizException(String message, int code){ super(message); this.code = code; } }
第十步,针对该自定义异常,定义一个全局异常处理类GlobalExceptionHandler
,用于处理这个异常:
1 2 3 4 5 6 7 8 9 10 11 @RestControllerAdvice public class GlobalExceptionHandler { @ResponseBody @ExceptionHandler(BizException.class) public Map<String,Object> handler(BizException e){ Map<String,Object> map = new HashMap<>(); map.put("message",e.getMessage()); map.put("code",e.getCode()); return map; } }
定义测试接口 第十一步,新建接口测试类RateLimiterController
,然后在里面定义一个接口,并在该接口上添加自定义的@RateLimiter
注解,用于测试我们的限流效果:
1 2 3 4 5 6 7 8 @RestController public class RateLimiterController { @GetMapping("/test") @RateLimiter(time = 2,count = 5,limitType = LimitType.IP) public String test(){ return "test >>>" +new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()); } }
项目测试 第十二步,启动项目,访问http://localhost:8080/test
,可以看到页面显示信息如下:
1 test >>>2021-05-17 13:57:57
但是当用户访问较为频繁的时候,页面会给出相应的提示:
1 2 3 4 { code: 500, message: "请求过于频繁,请稍后重试" }
同时查看IDEA的控制台,可以看到也输出了相应的日志信息:
小结 本篇基于AOP思想并配合Redis实现了接口的限流功能,在项目中实际上也有提供基于拦截器的实现,这一点需要的小伙伴可以自取。本篇实现的接口限流其实非常简陋,而我们经常使用的微服务组件Sentinel则是比较完备的限流工具,因此后续有更细粒度的限流需求可以参考Sentinel的限流逻辑进行实现。