Yang Zeng's Notes

Yang Zeng

SpringBoot中集成WebSocket「支持多终端、共享session」

知识点

WebSocket 和 Http:
WebSocket 是 HTML5 开始提供的一种在单个 TCP 连接上进行全双工通讯的协议,使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在 WebSocket API 中,浏览器和服务器只需要完成一次握手,两者之间就直接可以创建持久性的连接,并进行双向数据传输。
它是为了解决客户端发起多个 http 请求到服务器资源浏览器必须要经过长时间的轮训问题而生的,他实现了多路复用,他是全双工通信。在 webSocket 协议下客服端和浏览器可以同时发送信息。
HTTP 协议是用在应用层的协议,他是基于 tcp 协议的,http 协议建立链接也必须要有三次握手才能发送信息。http 链接分为短链接,长链接,短链接是每次请求都要三次握手才能发送自己的信息。即每一个 request 对应一个 response。长链接是在一定的期限内保持链接。保持 TCP 连接不断开。客户端与服务器通信,必须要有客户端发起然后服务器返回结果。客户端是主动的,服务器是被动的。
HTTP 的长链接和 WebSocket 持久连接的区别:

  1. HTTP1.1 的连接默认使用长连接(persistent connection),即在一定的期限内保持链接,客户端会需要在短时间内向服务端请求大量的资源,保持 TCP 连接不断开。客户端与服务器通信,必须要有客户端发起然后服务器返回结果。客户端是主动的,服务器是被动的。在一个 TCP 连接上可以传输多个 Request/Response 消息对,所以本质上还是 Request/Response 消息对,仍然会造成资源的浪费、实时性不强等问题。如果不是持续连接,即短连接,那么每个资源都要建立一个新的连接,HTTP 底层使用的是 TCP,那么每次都要使用三次握手建立 TCP 连接,即每一个 request 对应一个 response,将造成极大的资源浪费。长轮询,即客户端发送一个超时时间很长的 Request,服务器 hold 住这个连接,在有新数据到达时返回 Response
  2. Websocket 的持久连接 只需建立一次 Request/Response 消息对,之后都是 TCP 连接,避免了需要多次建立 Request/Response 消息对而产生的冗余头部信息。Websocket 只需要一次 HTTP 握手,所以说整个通讯过程是建立在一次连接/状态中,而且 websocket 可以实现服务端主动联系客户端,这是 http 做不到的

1
添加 pom 依赖

1
2
3
4
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

2
自定义 Service
自定义 WebSocketServer,使用底层的 websocket 方法,提供对应的 onOpen、onClose、onMessage、onError 方法

1、添加配置类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
@Configuration
public class WebSocketConfig {
/**
* ServerEndpointExporter 作用
*
* 这个Bean会自动注册使用@ServerEndpoint注解声明的websocket endpoint
*
* @return
*/
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}

注意:ServerEndpointExporter 一定要注入,这个 bean 会自动注册使用了@ServerEndpoint 注解声明的 Websocket endpoint;如果使用独立的 servlet 容器,而不是直接使用 springboot 的内置容器,就不要注入 ServerEndpointExporter,因为它将由容器自己提供和管理。

2、自定义 WebSocketServer(核心 Server)

这里有两个问题:

  • 如果是单台实例,其实我们大可不必使用 redis 进行共享 session;但如果是分布式,客户端创建的连接是随机的,可能与服务器 A 创建了连接,也可能是服务器 B,如果仅仅将连接信息存到内存,那就有问题了
  • WebsocketSession 不支持序列化,所以不能直接将 Session 对象存储到 redis 中

这里只是其中的一种解决方法,A、B 服务器的 session 依然保存在各自的服务器中,然后将 userid、sessionId、服务端服务器的 IP 的关系保存在 redis 中;当然这个 sessionId 你也可以不用,你可以在拿到 websocket 的 session 的时候,给他赋予一个唯一 ID,并把这个 ID 和 websocket session 存入内存,同时将该关系以及当前创建的服务器 IP 保存到 redis 中;
如何使用?需要发送消息的时候,根据 userId 从 redis 中获取对应的关系,再根据对应的 IP 转发到对用的 websocket 服务器上即可

sessionId 是由 org.apache.tomcat.websocket.WsSession 生成的一个递增的 16 进制并转为字符串,每次重启服务,这个 id 的计数又会重新从 0 开始。如果建立了多个通道,那他们的 id 可能为(0,1818,70cc).因为通道断开,对应的 webSocketSession 对象被释放,所以不同通道直接的 id 可能是不连续的.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.alibaba.fastjson.JSONObject;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

@Component
@ServerEndpoint("/websocket/{userId}")
@Slf4j
public class WebSocketServer {
/**
* 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
*/
private static int onlineCount = 0;
/**
* concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
* <p>
* Map的key为userId,List用于存放一个userId的多个终端,比如同一个userId,同时在手机端和PC登陆
*/
private static ConcurrentHashMap<String, List<WebSocketServer>> webSocketMap = new ConcurrentHashMap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
/**
* 接收userId
*/
private String userId = "";

/**
* 自己封装的redis的工具类,用于将数据存放于redis,解决分布式下多服务器共享session
*/
private static RedisUtils redisUtils;

@Autowired
public void setChatService(RedisUtils redisUtils) {
WebSocketServer.redisUtils = redisUtils;
}

/**
* 读取配置文件中的配置数据
*/
private static ValueConfig valueConfig;

@Autowired
public void setValueConfig(ValueConfig valueConfig) {
WebSocketServer.valueConfig = valueConfig;
}

/**
* 存储websocket客户端接入的信息key前缀
*/
public static final String websocketRedisKeyPrefix = "WebSocket_";

/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("userId") String userId) {
this.session = session;
this.userId = userId;
List<WebSocketServer> servers;
List<WebSocketServer> webSocketServers = new ArrayList<>();
//将接入的客户端信息添加到内存
if (webSocketMap.containsKey(userId)) {
//查询当前userId以及当前的session是否已经存在,如果存在,先移除再新增,如果不存在,直接新增
webSocketServers = webSocketMap.get(userId).stream().filter(o -> o.session.getId().equals(session.getId())).collect(Collectors.toList());
}
if (webSocketMap.containsKey(userId) && webSocketServers.size() > 0) {
webSocketServers = webSocketMap.get(userId);
webSocketServers.removeIf(webSocketServer -> webSocketServer.session.getId().equals(session.getId()));
servers = webSocketServers;
servers.add(this);
webSocketMap.put(userId, servers);
} else {
servers = null == webSocketMap.get(userId) ? new ArrayList<>() : webSocketMap.get(userId);
servers.add(this);
webSocketMap.put(userId, servers);
addOnlineCount();//在线数加1
}
log.info("用户【" + userId + "】sessionId:[" + session.getId() + "]连接成功" + ",当前在线人数为:" + getOnlineCount());
//region 将客户端连接信息存入redis
try {
/**
* SocketUserInfoDTO
*
* 存储在redis中按照如下结构
* |-- Ip
* -- sessionId1 |-- 其他信息
* --
* userId -- -- sessionId2 |- ...
* --
* -- sessionId3 |-- Ip
* -- |-- 其他信息
*/
SocketUserInfoDTO suid = null;
SocketUserInfoSessionDTO socketUserInfoSessionDTO = new SocketUserInfoSessionDTO();
socketUserInfoSessionDTO.setSessionId(session.getId());
socketUserInfoSessionDTO.setUserId(userId);
socketUserInfoSessionDTO.setIp(IpUtils.getServerIpAddress() + ":" + valueConfig.getServicePort());
//需要从redis拉最新的客户端连接信息
Object object = redisUtils.get(getSocketRedisKey(userId));
if (null != object) {
suid = JSONObject.parseObject(object.toString(), SocketUserInfoDTO.class);
}
SocketUserInfoDTO socketUserInfoDTO = new SocketUserInfoDTO();
if (null == suid) { //当前user没有保存的socket信息
Map<String, Map<String, SocketUserInfoSessionDTO>> listMap = new HashMap<>();
Map<String, SocketUserInfoSessionDTO> map = new HashMap<>();
map.put(session.getId(), socketUserInfoSessionDTO);
listMap.put(userId, map);
socketUserInfoDTO.setListMap(listMap);
//保存到redis
redisUtils.set(getSocketRedisKey(userId), JSONObject.toJSONString(socketUserInfoDTO));
} else { //当前user有保存的socket信息
Map<String, Map<String, SocketUserInfoSessionDTO>> map = suid.getListMap();
Map<String, SocketUserInfoSessionDTO> sessionDTOMap = map.get(userId);
sessionDTOMap.put(session.getId(), socketUserInfoSessionDTO);
map.put(userId, sessionDTOMap);
socketUserInfoDTO.setListMap(map);
redisUtils.set(getSocketRedisKey(userId), JSONObject.toJSONString(socketUserInfoDTO));
}
sendMessage("连接成功");
} catch (IOException e) {
log.error("用户:" + userId + ",网络异常!!!!!!");
}
//endregion
}


/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
List<WebSocketServer> webSocketServers = new ArrayList<>();
if (webSocketMap.containsKey(userId)) {
webSocketServers = webSocketMap.get(userId).stream().filter(o -> o.session.getId().equals(session.getId())).collect(Collectors.toList());
}
if (webSocketMap.containsKey(userId) && webSocketServers.size() > 0) {
webSocketServers = webSocketMap.get(userId);
Iterator<WebSocketServer> iterator = webSocketServers.iterator();
while (iterator.hasNext()) {
if (iterator.next().session.getId().equals(session.getId())) {
iterator.remove();
}
}
webSocketMap.put(userId, webSocketServers);
subOnlineCount();
log.info("用户【" + userId + "】sessionId:[" + session.getId() + "]断开连接,当前在线人数为:" + getOnlineCount());
}
//从redis中移除对应的客户端
Object redisSocketObj = redisUtils.get(getSocketRedisKey(userId));
SocketUserInfoDTO suid = null;
if (null != redisSocketObj) {
String resultStr = redisSocketObj.toString();
suid = JSONObject.parseObject(resultStr, SocketUserInfoDTO.class);
}
if (null != suid) {
Map<String, Map<String, SocketUserInfoSessionDTO>> map = suid.getListMap();
Map<String, SocketUserInfoSessionDTO> sessionDTOMap = map.get(userId);
Iterator<Map.Entry<String, SocketUserInfoSessionDTO>> entryIterator = sessionDTOMap.entrySet().iterator();
while (entryIterator.hasNext()) {
Map.Entry<String, SocketUserInfoSessionDTO> entry = entryIterator.next();
if (session.getId().equals(entry.getValue().getSessionId())) {
entryIterator.remove();
}
}
if (sessionDTOMap.size() <= 0) {
map.remove(userId);
}
SocketUserInfoDTO socketUserInfoDTO = new SocketUserInfoDTO();
if (map.size() <= 0) {
redisUtils.del(getSocketRedisKey(userId));
} else {
socketUserInfoDTO.setListMap(map);
redisUtils.set(getSocketRedisKey(userId), JSONObject.toJSONString(socketUserInfoDTO));
}
}
}

/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, Session session) {
log.info("用户【" + userId + "】sessionId:[" + session.getId() + "]发送消息给服务端报文:" + message);
// log.info("session" + session.getBasicRemote() + "|" + session.getId());
try {
sendMessage("服务端消息:" + "用户" + userId + "收到客户端的消息");
} catch (IOException e) {
e.printStackTrace();
}
//可以群发消息
//消息保存到数据库、redis
}

/**
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
log.error("用户错误:" + this.userId + ",原因:" + error.getMessage());
error.printStackTrace();
}

/**
* 实现服务器主动推送
*/
public void sendMessage(String message) throws IOException {
this.session.getBasicRemote().sendText(message);
}


/**
* 发送自定义消息
*/
public static void sendInfo(String message, @PathParam("userId") String userId, String sessionId) throws IOException {
log.info("发送消息到用户【" + userId + "】sessionId:[" + sessionId + "]发送消息给客户端报文:" + message);
log.info(JSON.toJSONString("当前客户" + userId + "的所有客户端:" + webSocketMap.get(userId)));
if (StringUtils.isNotBlank(userId) && webSocketMap.containsKey(userId)) {
for (WebSocketServer webSocketServer : webSocketMap.get(userId)) {
//1、如果不考虑websocket有多台服务器的情况下,可以不用判断,推送消息的时候对该用户的所有终端都推送
//2、当然如果业务需求不需要多终端推送,哪个终端有消息,就推送哪个,这里就不需要修改
if (sessionId.equals(webSocketServer.session.getId())) {
webSocketServer.sendMessage(message);
}
}
} else {
log.error("用户" + userId + ",不在线!");
}
}

/**
* 构造redis的key值
*
* @param userId
* @return
*/
private String getSocketRedisKey(String userId) {
return websocketRedisKeyPrefix + userId;
}

public static synchronized int getOnlineCount() {
return onlineCount;
}

public static synchronized void addOnlineCount() {
WebSocketServer.onlineCount++;
}

public static synchronized void subOnlineCount() {
WebSocketServer.onlineCount--;
}
}

相关的 POJO:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
* 为了使用方便,将保存的客户端信息保存为map
*/
@Data
public class SocketUserInfoDTO implements Serializable {
private static final long serialVersionUID = 1L;
private Map<String, Map<String, SocketUserInfoSessionDTO>> listMap;
}


/**
* 存储客户端连接的信息,根据自己的情况,可以存一些业务参数,方便后面使用
*/
@Data
public class SocketUserInfoSessionDTO implements Serializable {
private static final long serialVersionUID = 1L;
//WebSocket Session 的sessionId
private String sessionId;
//userId
private String userId;
//客户端创建连接时候的服务器IP+端口号
private String ip;
}


@Data
public class ForwardSocketUserInfoRequest {
private String sessionId;
private String userId;
//需要发送消息时携带的信息,可以自定义
private Object object;
}

3
实现 HTTP 消息转发接口

  • **socketSend 接口:**消息服务统一将消息通过该接口发送处理
  • **forwardSend 接口:**由 websocket 服务内部根据 IP 做消息转发

Controller:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@RestController
@RequestMapping("/websocket")
public class WebSocketController {

@Resource
private WebSocketService webSocketService;

/**
* @param request
* @return
*/
@RequestMapping(value = "/socketSend", method = RequestMethod.POST)
@ResponseBody
Result socketSend(@RequestBody ForwardSocketUserInfoRequest request) {
return webSocketService.socketSend(request);
}

/**
* 该接口供本服务内部通过转发到对用的服务器发送http调用
* @param request
* @return
*/
@RequestMapping(value = "/forwardSend", method = RequestMethod.POST)
@ResponseBody
Result forwardSend(@RequestBody ForwardSocketUserInfoRequest request) {
return webSocketService.forwardSend(request);
}
}

Service Interface:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public interface WebSocketService {
/**
* 处理需要发送socket消息的任务
*
* @param request
*/
Result socketSend(ForwardSocketUserInfoRequest request);

/**
* 该接口供本服务内部通过转发到对用的服务器发送http调用
*
* @param request
* @return
*/
Result forwardSend(ForwardSocketUserInfoRequest request);
}

实现 Impl:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@Slf4j
@Service
public class WebSocketServiceImpl implements WebSocketService {

@Resource
private RedisUtils redisUtils;

@Value("${server.port}")
private String servicePort;

/**
* 根据redis中存入的websocket客户端接入信息,转发到对应的服务器做消息推送
*
* @param request
*/
@Override
public Result socketSend(ForwardSocketUserInfoRequest request) {
String redisKey = WebSocketServer.websocketRedisKeyPrefix + request.getUserId();
Object redisSocketObj = redisUtils.get(redisKey);
SocketUserInfoDTO socketUserInfoDTO = null;
if (null != redisSocketObj) {
String resultStr = redisSocketObj.toString();
socketUserInfoDTO = JSONObject.parseObject(resultStr, SocketUserInfoDTO.class);
}
if (null == socketUserInfoDTO) {
log.error("给用户推送websocket消息失败");
return Result.buildSuccess();
}
Map<String, Map<String, SocketUserInfoSessionDTO>> mapMap = socketUserInfoDTO.getListMap();
Map<String, SocketUserInfoSessionDTO> sessionDTOMap = mapMap.get(request.getUserId());
//遍历map,根据Ip通过HTTP方式,传入sessionId参数,调用推送信息接口(这里的接口需要根据sessionId从内存中的map拿到对应的websocket,然后发送消息),推送信息到客户端
Iterator<Map.Entry<String, SocketUserInfoSessionDTO>> entryIterator = sessionDTOMap.entrySet().iterator();
while (entryIterator.hasNext()) {
Map.Entry<String, SocketUserInfoSessionDTO> entry = entryIterator.next();
String ipAndPort = IpUtils.getServerIpAddress() + ":" + servicePort;
if (ipAndPort.equals(entry.getValue().getIp())) { //如果是本台服务器,直接发送
try {
WebSocketServer.sendInfo(JSON.toJSONString(request.getObject()), request.getUserId(), entry.getValue().getSessionId());
} catch (IOException e) {
e.printStackTrace();
}
} else {
request.setSessionId(entry.getValue().getSessionId());
//通过HTTP转发到对应的服务器做处理
JnHttpUtils.post(JSON.toJSONString(request),
"http://" + entry.getValue().getIp() + "/websocket/forwardSend");
}
}
return Result.buildSuccess();
}

/**
* 该接口供本服务内部通过转发到对用的服务器发送http调用
*
* @param request
* @return
*/
@Override
public Result forwardSend(ForwardSocketUserInfoRequest request) {
try {
WebSocketServer.sendInfo(JSON.toJSONString(request.getObject()), request.getUserId(), request.getSessionId());
} catch (IOException e) {
e.printStackTrace();
}
return Result.buildSuccess();
}
}

4
使用

1、模拟客户端创建连接

启动服务,使用在线 websocket 请求工具:
image.png
image.png
可以看到服务器的控制台,打印的日志:

image.png

具体的内存和 redis 中数据的变化可以 debug 观察一下