写在前面

最近发现线上发送验证码的接口被莫名刷了,1天就把上周刚充值的1万条用完了,我内心其实是崩溃的,于是决定对该接口实现限流,这里选择使用拦截器配合Redis来实现。

实战

项目初始化

第一步,新建一个名为limit-redis的SpringBoot项目,然后在POM文件中添加如下依赖:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

第二步,修改application.yml配置文件信息:

1
2
3
4
5
6
spring:
redis:
host: 127.0.0.1
port: 6379
password: 1234
database: 1

重写RedisTempplate的序列化

第三步,重写RedisTempplate的序列化逻辑。一般来说我们更倾向于在SpringBoot中使用 Spring Data Redis来操作Redis,但是随着而来的则是它的序列化问题,默认使用的是JdkSerializationRedisSerializer,采用的是二进制方式,且会自动的给存入的key和value添加一些前缀,导致实际情况与开发者预想的不一致。针对这种情况我们可以使用Jackson2JsonRedisSerializer这一序列化方式,不建议使用StringRedisTemplate来替代RedisTemplate,因为它提供的数据类型和操作都有限,无法满足日常需要。

定义一个名为RedisConfig的类,该类用于重写RedisTempplate的序列化逻辑,使用Jackson2JsonRedisSerializer取代默认的JdkSerializationRedisSerializer,这样利于后续开发和使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@Configuration
public class RedisConfig {
@Bean
public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(connectionFactory);
// 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化)
Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
ObjectMapper mapper = new ObjectMapper();
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
jackson2JsonRedisSerializer.setObjectMapper(mapper);
redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
return redisTemplate;
}
}

自定义限流注解

第四步,自定义限流注解。我们要求这个限流分为两种,一种是针对某一个接口的全局限流,另一种是针对IP地址的限流:
(1)针对当前某一接口的全局限流。举个例子,/test接口可以在1分钟内访问60次;
(2)针对IP地址的限流。举个例子,192.168.56.1这一IP地址可以在1分钟内访问60次;
针对上述情况,可以创建一个枚举类LimitType,用于记录限流的类型:

1
2
3
4
5
6
7
8
9
10
11
12
13
/**
* 限流类型
*/
public enum LimitType {
/**
* 默认类型:全局限流
*/
DEFAULT,
/**
* 根据请求IP地址限流
*/
IP;
}

接着自定义一个限流注解RateLimiter,里面设置限流key,注意这个key仅仅是一个前缀,后续我们会拼接其他的变量组成完整的key,进而存入Redis中。完整key的格式如下:

1
rate_limit:IP地址-注解所添加方法所在的类的名称-注解所添加方法的名称

举个例子,如下所示:

1
rate_limit:192.168.30.10-com.melody.limitredis.controller.RateLimiterController-test

time则是限流的时间,单位为秒;count则是限流的次数,默认100次;limitType则是限流的类型,默认为全局限流:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
/**
* 限流Key
*/
String key()default "rate_limit:";

/**
* 限流时间,单位秒
*/
int time()default 60;

/**
* 限流次数,默认100次
*/
int count()default 100;

/**
* 限流类型,默认全局限流
*/
LimitType limitType()default LimitType.DEFAULT;
}

这样后续我们需要对某个接口进行限流,只需在该接口上添加@RateLimiter注解,并设置上述对应的参数即可,总的来说还是比较简单的。

第五步,编写Lua脚本。我们知道Redis的单个操作是具备原子性的,而多个操作就无法保证,但是我们可以借助于Lua脚本来实现。一般来说调用Lua脚本有两种方式:
(1)在Redis服务端定义Lua脚本,然后计算出一个hash值,接着在Java代码中通过这个hash值来确定需要执行的Lua脚本;
(2)在Java代码中将Lua脚本定义好,然后将其发送到Redis服务端,进而去执行。
笔者比较倾向于第二种方式,因此可以先在客户端定义好Lua脚本,然后通过Spring Data Redis提供的redisTemplate.execute()方法,传入脚本实例和对应的参数就可以执行对应的Lua脚本了。

在项目的resources目录下新建一个名为lua的目录,并在该lua目录下新建一个名为limit.lua的脚本:

1
2
3
4
5
6
7
8
9
10
11
12
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current)

这个KEYS和ARGV就是用户后续在执行该脚本时传入来的参数,tonumber方法用于将字符串转成数字,redis.call()方法通过传入方法名称和参数,进而实现调用不同方法的逻辑。上述脚本所要表达的含义如下:
(a)获取用户传递进来的key,限流次数count和限流时间time;
(b)调用get(key)方法来获取当前key的值,即当前接口在当前时间内已经访问的次数;
(c)如果该接口是第一次访问,那么(b)得到的结果将是nil,否则得到的是一个数字。如果是数字那么我们将判断它和限流次数count的大小,如果该数字大于count,则说明已经超过最大访问次数,需要限流了,可以直接返回当前请求的次数;
(d)如果得到的结果是nil,则说明是第一次访问该接口,那么给当前的key进行自增加1,并设置一个过期时间;
(e)最后只需将自增后的值返回即可。

第六步,在RedisConfig类中定义一个Bean来加载这个Lua脚本:

1
2
3
4
5
6
7
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}

注意Lua脚本的存放位置和名称需要与开发者实际存放的位置相匹配,否则后续无法成功调用该脚本。

自定义切面类

其实拦截请求可以有不同的实现思路,可以用拦截器或者AOP,考虑到拦截器也是AOP思想的体现,因此这里就直接使用AOP来实现。

第七步,自定义切面类RateLimiterAspect,该类用于拦截所有添加了@RateLimiter注解的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@Component
@Aspect
public class RateLimiterAspect {
private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);

@Autowired
private RedisTemplate<Object,Object> redisTemplate;

@Autowired
private RedisScript<Long> limitScript;

@Before("@annotation(rateLimiter)")
public void before(JoinPoint point, RateLimiter rateLimiter){
int time = rateLimiter.time();
int count = rateLimiter.count();

String combineKey = getCombineKey(rateLimiter,point);
List<Object> keys = Collections.singletonList(combineKey);
try{
Long number = redisTemplate.execute(limitScript, keys, count, time);
if(number == null || number.intValue() > count){
throw new BizException("请求过于频繁,请稍后重试",500);
}
logger.info("当前请求次数'{}',限定次数'{}'", number.intValue(), count);
}catch (BizException e){
throw e;
}catch (Exception e){
throw new RuntimeException("服务器限流异常,请稍候再试");
}
}

private String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
//IP限制
if(rateLimiter.limitType() == LimitType.IP){
stringBuffer.append(IpUtils.getRequestIp(((ServletRequestAttributes)RequestContextHolder.getRequestAttributes()).getRequest())).append("-");
}
MethodSignature signature = (MethodSignature)point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
logger.info("{}",stringBuffer.toString());
return stringBuffer.toString();
}
}

此处我们使用了前置通知,并在前置通知中对请求进行了处理,逻辑如下:
(1)获取注解中的time、count 、key和limitType这四个属性;
(2)调用getCombineKey()方法来获取一个完整的限流Key,先判断是否包含IP,如果有IP就将IP添加到里面,否则就不添加,最终完整的限流Key格式如下:

1
rate_limit:IP地址-注解所添加方法所在的类的名称-注解所添加方法的名称

(3)将限流Key放入一个集合中,因为此处我们调用的redisTemplate.execute()的完整方法如下:

1
<T> T execute(RedisScript<T> script, List<K> keys, Object... args);

里面的keys就是一个集合,所以此处就生成了一个单例模式的集合。其实这个也对应于之前我们在Lua脚本中定义的参数,即上述方法中的第二个参数keys就是脚本中的KEYS,可变长度就是脚本中的ARGV,注意值下标从1开始计数:

1
2
3
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])

(4)之后就是将执行Lua脚本之后的返回值与限流值count进行比较,如果值大于count则说明超出最大访问次数,应当限流,此时抛出一个异常,这里我们自定义了一个异常BizException。

第八步,定义一个用于获取用户IP地址信息的工具类IpUtils

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
public class IpUtils {
/**
* 获取请求真实IP地址
*/
public static String getRequestIp(HttpServletRequest request) {
//通过HTTP代理服务器转发时添加
String ipAddress = request.getHeader("x-forwarded-for");
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getRemoteAddr();
// 从本地访问时根据网卡取本机配置的IP
if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) {
InetAddress inetAddress = null;
try {
inetAddress = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
e.printStackTrace();
}
ipAddress = inetAddress.getHostAddress();
}
}
// 通过多个代理转发的情况,第一个IP为客户端真实IP,多个IP会按照','分割
if (ipAddress != null && ipAddress.length() > 15) {
if (ipAddress.indexOf(",") > 0) {
ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));
}
}
return ipAddress;
}
}

当然了,这里只是获取了用户的IP地址,之后可以向一些网站发起请求,得到这些IP的归属地,这里笔者就不展示了。

全局异常捕获

第九步,定义一个异常BizException:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public class BizException extends RuntimeException{
private int code;

public int getCode() {
return code;
}

public void setCode(int status) {
this.code = code;
}

public BizException(String message){
super(message);
}

public BizException(String message, int code){
super(message);
this.code = code;
}
}

第十步,针对该自定义异常,定义一个全局异常处理类GlobalExceptionHandler,用于处理这个异常:

1
2
3
4
5
6
7
8
9
10
11
@RestControllerAdvice
public class GlobalExceptionHandler {
@ResponseBody
@ExceptionHandler(BizException.class)
public Map<String,Object> handler(BizException e){
Map<String,Object> map = new HashMap<>();
map.put("message",e.getMessage());
map.put("code",e.getCode());
return map;
}
}

定义测试接口

第十一步,新建接口测试类RateLimiterController,然后在里面定义一个接口,并在该接口上添加自定义的@RateLimiter注解,用于测试我们的限流效果:

1
2
3
4
5
6
7
8
@RestController
public class RateLimiterController {
@GetMapping("/test")
@RateLimiter(time = 2,count = 5,limitType = LimitType.IP)
public String test(){
return "test >>>" +new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date());
}
}

项目测试

第十二步,启动项目,访问http://localhost:8080/test,可以看到页面显示信息如下:

1
test >>>2021-05-17 13:57:57

但是当用户访问较为频繁的时候,页面会给出相应的提示:

1
2
3
4
{
code: 500,
message: "请求过于频繁,请稍后重试"
}

同时查看IDEA的控制台,可以看到也输出了相应的日志信息:

小结

本篇基于AOP思想并配合Redis实现了接口的限流功能,在项目中实际上也有提供基于拦截器的实现,这一点需要的小伙伴可以自取。本篇实现的接口限流其实非常简陋,而我们经常使用的微服务组件Sentinel则是比较完备的限流工具,因此后续有更细粒度的限流需求可以参考Sentinel的限流逻辑进行实现。