Preface

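  1. This article implements a WebSocket proxy between an upstream and a downstream connection, plus a client-facing publisher.
  2. Stack: Spring Boot + Kotlin
  3. There are many ways to implement this. The interface code below shows the approach; it can also be adapted to a container-managed @ServerEndpoint implementation.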

Proxy mode

The user's connection is the upstream; the proxied address is the downstream.

  1. Intercept and modify message content in both directions
  2. Authenticate the upstream

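The two duties listed above (authenticate the upstream, then rewrite what passes through) can be modeled without Spring, which makes the control flow easy to test. This is a hypothetical sketch: ProxyFlowSketch and its three lambdas stand in for onUpstreamFirstMessage, transformUpstream and the downstream socket, and messages are plain strings instead of WebSocketMessage objects.

```kotlin
/**
 * Hypothetical, framework-free model of one proxied session (not the article's API).
 * authorize ~ onUpstreamFirstMessage, transform ~ transformUpstream,
 * sendDown ~ writing to the downstream socket.
 */
class ProxyFlowSketch(
    private val authorize: (String) -> Boolean,
    private val transform: (String) -> String,
    private val sendDown: (String) -> Unit
) {
    var authorized = false
        private set
    var closed = false
        private set

    fun onUpstreamMessage(msg: String) {
        if (closed) return
        if (!authorized) {
            // The first message decides whether the session may continue
            if (!authorize(msg)) {
                closed = true
                return
            }
            authorized = true
        }
        // Authorized traffic (including the first message) is rewritten, then forwarded
        sendDown(transform(msg))
    }
}
```

As in the proxier's design, the first message is forwarded too once it passes authorization, and a rejected session ignores everything that follows.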
Sequence design

IWebSocketProxier sequence diagram
Sequence diagram source (Mermaid)
sequenceDiagram
participant UpClient as Upstream client
participant Proxier as IWebSocketProxier
participant DownClient as Downstream service
participant Scheduler2 as Cleanup thread

Note over UpClient,Proxier: 1. Upstream connection established
UpClient->>Proxier: WebSocket handshake
activate Proxier
Proxier-->>Proxier: afterConnectionEstablished(session)
Proxier-->>Proxier: sessions[session.id] = WebSocketProxySession(...)
Proxier-->>Proxier: onUpstreamOpen(proxy)
deactivate Proxier

Note over UpClient,Proxier: 2. First upstream message (authorization)
UpClient->>Proxier: TextMessage(first message)
activate Proxier
Proxier-->>Proxier: handleMessage(session, message)
Proxier-->>Proxier: onUpstreamFirstMessage(proxy, message)
alt Authorization failed
Proxier-->>UpClient: sendMessage(authorization failure notice)
Proxier-->>Proxier: closeSession(session.id)
else Authorization succeeded
Proxier-->>Proxier: proxy.authorized = true
Proxier-->>Proxier: onAuthSuccess(proxy)
Proxier-->>Proxier: connectDownstream(session.id)
Proxier-->>Proxier: downstreamContexts[session.id].pending.offer(clone(message))
end
deactivate Proxier

Note over UpClient,Proxier: 3. Subsequent upstream messages
UpClient->>Proxier: TextMessage(subsequent message)
activate Proxier
Proxier-->>Proxier: handleMessage
alt !downConnected
Proxier-->>Proxier: pending.offer(clone(message))
else Downstream connected
Proxier-->>DownClient: sendToDownstream(transformUpstream(message))
end
deactivate Proxier

Note over DownClient,Proxier: 4. Downstream connection established
DownClient->>Proxier: Handshake complete
activate Proxier
Proxier-->>Proxier: ctx.downConnected = true
Proxier-->>Proxier: flush pending → sendToDownstream(...)
deactivate Proxier

Note over DownClient,Proxier: 5. Downstream messages relayed back
DownClient->>Proxier: TextMessage(downstream response)
activate Proxier
Proxier-->>UpClient: proxy.session.sendMessage(transformDownstream(msg))
deactivate Proxier

Note over Scheduler2,Proxier: 6. Automatic cleanup of timed-out sessions
Scheduler2->>Proxier: cleanupExpired()
activate Proxier
Proxier-->>Proxier: closeSession(timed-out session.id)
deactivate Proxier

Note over UpClient,Proxier: 7. Upstream closes the connection
UpClient->>Proxier: closeConnection
activate Proxier
Proxier-->>Proxier: afterConnectionClosed(session, status)
Proxier-->>Proxier: closeSession(session.id)
deactivate Proxier
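Steps 3 and 4 depend on one detail: messages that arrive while the downstream handshake is still in progress must be buffered and then flushed in arrival order, and no message may be queued after the flush has already run. A minimal sketch of that rule, using a hypothetical PendingBuffer class (not part of the article's code) whose two methods share one lock:

```kotlin
/**
 * Hypothetical buffer for the "pending" queue of steps 3 and 4.
 * Both methods synchronize on the same monitor, so a message can never be queued
 * after the flush has already run (where it would otherwise wait forever).
 */
class PendingBuffer<M>(private val send: (M) -> Unit) {
    private val pending = ArrayDeque<M>()
    private var connected = false

    /** Step 3: queue while the downstream is connecting, send directly once connected */
    @Synchronized
    fun offerOrSend(msg: M) {
        if (connected) send(msg) else pending.addLast(msg)
    }

    /** Step 4: the handshake finished; flush the backlog in arrival order */
    @Synchronized
    fun onConnected() {
        connected = true
        while (pending.isNotEmpty()) send(pending.removeFirst())
    }

    @Synchronized
    fun backlog(): Int = pending.size
}
```

If the "connected?" check and the enqueue are not covered by the same lock as the flush, a message can land in the queue just after the flush loop ends and stay there.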

Interface code

IWebSocketProxier
import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

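/**
 * Wraps an upstream session together with its state, for authorization and heartbeat tracking
 *
 * @param session upstream WebSocket session
 * @param authorized whether the session has passed authorization
 * @param downConnected whether the downstream connection has been established
 * @param lastHeartbeat timestamp of the most recent heartbeat (ms)
 */
data class WebSocketProxySession(
    val session: WebSocketSession,
    // @Volatile: these fields are touched from different threads
    // (message handling, downstream callbacks, the cleanup sweep)
    @Volatile var authorized: Boolean = false,
    @Volatile var downConnected: Boolean = false,
    @Volatile var lastHeartbeat: Long = System.currentTimeMillis()
)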

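/**
 * Generic WebSocket proxy base class
 *
 * Manages the upstream and downstream connection lifecycles, message forwarding and timeout cleanup.
 *
 * Usage:
 * 1. Implement the core abstract members:
 *    - registerPath: the proxy route
 *    - onUpstreamFirstMessage: handle the first upstream message and authorize it
 *    - downstreamUri: resolve the downstream URI
 *    - transformUpstream: upstream → downstream conversion
 *    - transformDownstream: downstream → upstream conversion
 * 2. Optionally override the hooks:
 *    - onUpstreamOpen: upstream connection initialization
 *    - onAuthSuccess: called after successful authorization
 *    - onUpstreamFirstMessageIsNull: authorization failure handling
 *    - onSessionClosed: post-close handling
 *
 * @param objectMapper used for JSON serialization/deserialization
 * @param client WebSocket client used to open downstream connections
 * @author ThatCoder
 */
abstract class IWebSocketProxier(
    val objectMapper: ObjectMapper,
    private val client: WebSocketClient
) : AbstractWebSocketHandler() {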
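    /** Proxy entry path */
    abstract val registerPath: String

    /** Session timeout, 10 minutes by default */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)
    private val sessions = ConcurrentHashMap<String, WebSocketProxySession>()
    private val downstreamContexts = ConcurrentHashMap<String, DownstreamContext>()
    private val scheduler = Executors.newSingleThreadScheduledExecutor(
        NamedThreadFactory("proxy-session-timeout-")
    )

    init {
        // Sweep expired sessions once a minute. Do not use sessionTimeoutMillis as the period:
        // a subclass that overrides it with `override val sessionTimeoutMillis = ...` is not yet
        // initialized while this init block runs, so the period would be 0 and
        // scheduleAtFixedRate would throw IllegalArgumentException.
        scheduler.scheduleAtFixedRate(
            { runCatching { cleanupExpired() }.onFailure { logger.error("Session cleanup failed", it) } },
            1, 1, TimeUnit.MINUTES
        )
    }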

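    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Upstream connected: ${session.id}")
        // The upstream session is written to from the downstream handler thread and from
        // sendHeartbeat(); the decorator serializes those concurrent sends
        // (10 s send time limit and 512 KiB buffer are example values)
        val proxy = WebSocketProxySession(
            org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator(session, 10_000, 512 * 1024)
        )
        sessions[session.id] = proxy
        onUpstreamOpen(proxy)
    }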

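    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        val proxy = sessions[session.id] ?: return
        // Any upstream traffic, including the Pong answering sendHeartbeat(), counts as activity;
        // without this the timeout sweep closes every session after sessionTimeoutMillis
        proxy.lastHeartbeat = System.currentTimeMillis()
        // Pongs only keep the session alive; they are neither credentials nor payload
        if (message is PongMessage) return
        if (!proxy.authorized) {
            if (!onUpstreamFirstMessage(proxy, message)) {
                onUpstreamFirstMessageIsNull(proxy)
                closeSession(session.id, CloseStatus.POLICY_VIOLATION)
                return
            }
            proxy.authorized = true
            onAuthSuccess(proxy)
            // The first message is queued before the handshake starts, so it cannot miss
            // the flush that runs once the downstream is connected
            connectDownstream(session.id, clone(message))
            return
        }
        downstreamContexts[session.id]?.offerOrSend(clone(message))
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Upstream closed: ${session.id}")
        closeSession(session.id)
    }

    /**
     * Send a Ping to every upstream session to keep long-lived connections alive.
     * Nothing in this class calls it; schedule it from outside.
     */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach {
            try {
                it.session.sendMessage(ping)
            } catch (ignored: Exception) {
                // Ignore send failures; dead sessions are removed by the timeout sweep
            }
        }
    }

    // ---------- Overridable hooks ----------

    /** Called after the upstream connection is established */
    protected open fun onUpstreamOpen(proxy: WebSocketProxySession) = Unit

    /**
     * Handle the first upstream message and authorize the session
     * @return true to accept; false triggers the authorization-failure path
     */
    protected abstract fun onUpstreamFirstMessage(
        proxy: WebSocketProxySession,
        message: WebSocketMessage<*>
    ): Boolean

    /** Message sent to the upstream when authorization fails */
    protected open fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        val err = mapOf("finish" to true, "error" to "Authentication failed")
        proxy.session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /** Called after successful authorization */
    protected open fun onAuthSuccess(proxy: WebSocketProxySession) = Unit

    /** Resolve the downstream URI for an upstream session */
    protected abstract fun downstreamUri(proxy: WebSocketProxySession): String

    /** Upstream → downstream message conversion */
    protected abstract fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Downstream → upstream message conversion */
    protected abstract fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Called after a session has been closed */
    protected open fun onSessionClosed(proxy: WebSocketProxySession) = Unit

    // ---------- Internal logic ----------

    /**
     * Open the downstream connection for [sessionId]. [first] is queued now and sent once the
     * handshake completes; later messages are routed through the DownstreamContext.
     */
    private fun connectDownstream(sessionId: String, first: WebSocketMessage<*>) {
        val proxy = sessions[sessionId] ?: return
        val ctx = DownstreamContext(sessionId)
        ctx.offerOrSend(first)
        downstreamContexts[sessionId] = ctx
        client.execute(object : AbstractWebSocketHandler() {
            override fun afterConnectionEstablished(down: WebSocketSession) {
                logger.info("Downstream connected for: $sessionId")
                ctx.onConnected(down)
            }

            override fun handleMessage(down: WebSocketSession, msg: WebSocketMessage<*>) {
                try {
                    proxy.session.sendMessage(transformDownstream(msg))
                } catch (e: Exception) {
                    logger.warn("Forward to upstream failed: $sessionId", e)
                }
            }

            override fun afterConnectionClosed(down: WebSocketSession, status: CloseStatus) {
                logger.info("Downstream closed for $sessionId: ${status.code}")
                closeSession(sessionId)
            }
        }, downstreamUri(proxy)).whenComplete { _, ex ->
            // If the downstream is unreachable, close the upstream instead of leaving it waiting
            if (ex != null) {
                logger.error("Downstream connect failed for $sessionId", ex)
                closeSession(sessionId, CloseStatus.SERVER_ERROR)
            }
        }
    }

    /** Close and clean up one proxied session; idempotent, so either side may trigger it */
    private fun closeSession(sessionId: String, status: CloseStatus = CloseStatus.NORMAL) {
        // Remove the entry first so re-entrant calls from close callbacks find nothing to do
        val proxy = sessions.remove(sessionId)
        downstreamContexts.remove(sessionId)?.closeAll()
        if (proxy != null) {
            // Close the upstream socket too; otherwise a timed-out client stays connected
            // while every message it sends is silently dropped
            runCatching { if (proxy.session.isOpen) proxy.session.close(status) }
            onSessionClosed(proxy)
        }
    }

    /** Close sessions with no upstream activity for longer than [sessionTimeoutMillis] */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.entries
            .filter { now - it.value.lastHeartbeat > sessionTimeoutMillis }
            .map { it.key }
            .forEach {
                logger.info("Session timed out: $it")
                closeSession(it, CloseStatus.SESSION_NOT_RELIABLE)
            }
    }

    /**
     * Copy a message that will be sent later from another thread. Text payloads are immutable
     * Strings; binary payloads are copied because the container may reuse the buffer.
     */
    private fun clone(msg: WebSocketMessage<*>): WebSocketMessage<*> = when (msg) {
        is TextMessage -> TextMessage(msg.payload, msg.isLast)
        is BinaryMessage -> {
            val src = msg.payload.duplicate()
            val copy = java.nio.ByteBuffer.allocate(src.remaining())
            copy.put(src)
            copy.flip()
            BinaryMessage(copy, msg.isLast)
        }
        else -> msg
    }

    /**
     * Downstream state of one upstream session: buffers messages until the handshake
     * completes, then sends them in arrival order on a single thread
     */
    private inner class DownstreamContext(sessionId: String) {
        @Volatile var downstream: WebSocketSession? = null
        val downConnected = AtomicBoolean(false)
        // Only accessed inside @Synchronized methods
        private val pending = ArrayDeque<WebSocketMessage<*>>()
        private var closed = false
        // A single thread preserves order (audio frames must not be shuffled) and avoids
        // concurrent sendMessage calls on the downstream session
        private val executor: ExecutorService =
            Executors.newSingleThreadExecutor(NamedThreadFactory("proxy-send-$sessionId-"))

        /** Buffer [msg] while the downstream is connecting, otherwise transform and send it */
        @Synchronized
        fun offerOrSend(msg: WebSocketMessage<*>) {
            if (closed) return
            if (!downConnected.get()) pending.addLast(msg) else sendToDownstream(transformUpstream(msg))
        }

        /** Handshake finished: flush the backlog under the same lock as offerOrSend */
        @Synchronized
        fun onConnected(down: WebSocketSession) {
            if (closed) {
                // The upstream went away while the handshake was in progress
                runCatching { down.close() }
                return
            }
            downstream = down
            downConnected.set(true)
            while (pending.isNotEmpty()) sendToDownstream(transformUpstream(pending.removeFirst()))
        }

        /** Send asynchronously on this session's send thread */
        private fun sendToDownstream(msg: WebSocketMessage<*>) {
            executor.execute {
                try {
                    downstream?.sendMessage(msg)
                } catch (e: Exception) {
                    logger.error("Send downstream failed", e)
                }
            }
        }

        /** Close the downstream and release resources */
        @Synchronized
        fun closeAll() {
            closed = true
            runCatching { downstream?.close() }
            executor.shutdownNow()
            pending.clear()
        }
    }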

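    /** Thread factory that gives threads readable, numbered names */
    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)

        // Build the name per thread; computing it once in a property gave every thread the same name
        override fun newThread(r: Runnable): Thread = Thread(r, prefix + cnt.getAndIncrement()).apply {
            // Daemon threads do not keep the JVM alive on shutdown
            isDaemon = true
        }
    }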
}
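Sends to one downstream session should stay in order (streamed audio frames in particular) and must not run concurrently. A pool with several worker threads guarantees neither, while a single-thread executor runs tasks one at a time in submission order. The sketch below only demonstrates that JDK property; sendAllInOrder is a hypothetical helper, and "sending" is simulated by recording each frame.

```kotlin
import java.util.Collections
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

/**
 * Hypothetical helper: "sends" each frame from a single-thread executor and returns
 * the order in which the sends actually ran.
 */
fun sendAllInOrder(frames: List<Int>): List<Int> {
    val observed = Collections.synchronizedList(ArrayList<Int>())
    val executor = Executors.newSingleThreadExecutor()
    try {
        // One worker thread: tasks run one at a time, in submission order
        frames.forEach { frame -> executor.execute { observed.add(frame) } }
    } finally {
        executor.shutdown()
    }
    // shutdown() still runs the queued tasks; wait for them before reading the result
    check(executor.awaitTermination(5, TimeUnit.SECONDS)) { "sends did not finish in time" }
    return observed.toList()
}
```

With a multi-thread pool such as ThreadPoolExecutor(4, 16, ...) the same loop has no ordering guarantee, which is why a per-session single-thread executor fits this use better.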

Implementation example

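Taking a FunASR proxy as the example: it unifies the upstream and downstream message formats and authenticates the upstream. The JSON helpers imported from com.bidr.waterx.transpond.config.extension (toObjectNode, fieldRename, fieldRemove, fieldJust, putMap, toTextMessage) and IAKService come from the author's own project and are not shown in this article.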

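The field renaming driven by paramProxy can be illustrated on plain maps. Since the author's fieldRename extension is not shown, renameKeys below is a hypothetical stand-in; treating the second argument as "rename in the reverse direction" is an assumption based on how the example's transformDownstream calls it.

```kotlin
/** Client field -> FunASR field, the same pairs as paramProxy in AsrProxier */
val paramProxy = mapOf("id" to "wav_name", "finish" to "is_speaking", "answer" to "text")

/**
 * Hypothetical stand-in for fieldRename: renames keys found in [mapping] and leaves
 * every other key untouched. With [reverse] = true the mapping is inverted first.
 */
fun renameKeys(
    src: Map<String, Any?>,
    mapping: Map<String, String>,
    reverse: Boolean = false
): Map<String, Any?> {
    val m = if (reverse) mapping.entries.associate { (k, v) -> v to k } else mapping
    return src.entries.associate { (k, v) -> (m[k] ?: k) to v }
}
```

In the real class the keys are ObjectNode fields rather than map entries, and ak is removed separately with fieldRemove.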
AsrProxier
import com.bidr.waterx.transpond.config.extension.fieldJust
import com.bidr.waterx.transpond.config.extension.fieldRemove
import com.bidr.waterx.transpond.config.extension.fieldRename
import com.bidr.waterx.transpond.config.extension.putMap
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ArrayNode
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import org.springframework.web.socket.CloseStatus
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.client.WebSocketClient

/** ASR proxy implementation */
@Component
class AsrProxier(
    objectMapper: ObjectMapper,
    webSocketClient: WebSocketClient,
    private val akService: IAKService
) : IWebSocketProxier(objectMapper, webSocketClient) {
    // Client-facing field name -> FunASR field name
    private val paramProxy = mapOf(
        "id" to "wav_name",
        "finish" to "is_speaking",
        "answer" to "text"
    )

    private val logger = LoggerFactory.getLogger(this::class.java)

    override val registerPath = "/ws/asr"

    // Authentication: the first message must carry an "ak" that akService accepts
    override fun onUpstreamFirstMessage(proxy: WebSocketProxySession, message: WebSocketMessage<*>): Boolean {
        val node = message.toObjectNode(objectMapper) ?: return false
        val ak = node.get("ak")?.asText() ?: return false
        return akService.check(ak)
    }

    override fun downstreamUri(proxy: WebSocketProxySession) = "ws://localhost:10095"

    // Adapt client messages to the format FunASR expects.
    // Explicit return type, since the body contains a non-local `return`.
    override fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> runCatching {
            val forward = message.toObjectNode(objectMapper)?.fieldRename(paramProxy) ?: return message
            forward.get("is_speaking")?.let {
                // After the rename, "is_speaking" still holds the client's "finish" flag
                val finished = it.asBoolean(false)
                if (!finished) forward.putMap(mapOf(
                    "language" to "zn",
                    "itn" to false,
                    "hotwords" to "{\"阿里巴巴\":20,\"hello world\":40}"
                ))
                // FunASR's flag means the opposite: still speaking = not finished
                forward.put("is_speaking", !finished)
            }
            forward.get("mode")?.let {
                // Streaming modes need chunking parameters; "mixed" maps to FunASR's "2pass"
                if (it.asText() in listOf("mixed", "online")) {
                    val arr = objectMapper.createArrayNode().add(5).add(10).add(5)
                    forward.set<ArrayNode>("chunk_size", arr)
                    forward.put("chunk_interval", 10)
                    if (it.asText() == "mixed") forward.put("mode", "2pass")
                }
            }
            // Never forward the access key downstream
            forward.fieldRemove(listOf("ak"))
            logger.info("transformUpstream: $forward")
            forward.toTextMessage(objectMapper)
        }.getOrDefault(message)
        else -> message
    }

    // Adapt FunASR results to the client-facing format
    override fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> {
            // Reverse rename; on this side FunASR's "is_final" becomes the client's "finish"
            val forward = message.toObjectNode(objectMapper)
                ?.fieldRename(paramProxy + ("finish" to "is_final"), true)
                ?: return message
            val mode = forward.get("mode")?.asText() ?: "2pass-offline"
            forward.putMap(mapOf(
                "mode" to when (mode) {
                    "2pass-online" -> "online"
                    "2pass-offline" -> "offline"
                    else -> mode
                },
                "timestamp" to System.currentTimeMillis()
            ))
            // Keep only the client-facing fields. Assuming fieldJust keeps just the listed keys,
            // "timestamp" must be listed too, or the value added above is dropped again
            forward.fieldJust((paramProxy.keys + listOf("mode", "timestamp")).toList())
            logger.info("transformDownstream: $forward")
            forward.toTextMessage(objectMapper)
        }
        else -> message
    }

    override fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        super.onUpstreamFirstMessageIsNull(proxy)
        proxy.session.close(CloseStatus.POLICY_VIOLATION)
    }
}
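The value-level rules inside the two transform methods are easier to check in isolation. The functions below are hypothetical restatements of that logic on plain values, not members of the class: mixed is sent to FunASR as 2pass, streaming modes get the chunk_size/chunk_interval settings, the client's finish flag is inverted into is_speaking, and FunASR's 2pass-online/2pass-offline are reported back as online/offline.

```kotlin
/** Client mode -> FunASR mode, as in transformUpstream */
fun toFunAsrMode(mode: String): String = if (mode == "mixed") "2pass" else mode

/** Streaming settings added for "mixed" and "online", as in transformUpstream */
fun streamingExtras(clientMode: String): Map<String, Any> =
    if (clientMode in listOf("mixed", "online"))
        mapOf("chunk_size" to listOf(5, 10, 5), "chunk_interval" to 10, "mode" to toFunAsrMode(clientMode))
    else
        emptyMap()

/** The client sends "finish"; FunASR expects the opposite flag "is_speaking" */
fun isSpeaking(finish: Boolean): Boolean = !finish

/** FunASR mode -> client mode, as in transformDownstream; a missing mode counts as offline */
fun toClientMode(funAsrMode: String?): String = when (val m = funAsrMode ?: "2pass-offline") {
    "2pass-online" -> "online"
    "2pass-offline" -> "offline"
    else -> m
}
```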

Client mode

In client mode the service itself is the publisher: users are the upstream, and the service acts as the downstream.

  1. User authentication
  2. Session object management
  3. Heartbeat maintenance
  4. Message broadcasting
  5. Filtered broadcasting
  6. One-to-one mode

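Points 2, 4, 5 and 6 above all come down to a map from session id to user plus a choice of recipients. A framework-free sketch with a hypothetical SessionRegistry that returns recipient ids instead of sending anything:

```kotlin
import java.util.concurrent.ConcurrentHashMap

/**
 * Hypothetical model of the publisher's session map: session id -> authenticated user.
 * The recipient sets correspond to publishAll, publishByFilter and publishSender.
 */
class SessionRegistry<T> {
    private val users: MutableMap<String, T> = ConcurrentHashMap()

    fun register(sessionId: String, user: T) {
        users[sessionId] = user
    }

    fun unregister(sessionId: String): T? = users.remove(sessionId)

    /** Broadcast: every registered session */
    fun everyone(): Set<String> = users.keys.toSet()

    /** Filtered broadcast: sessions whose user matches [filter] */
    fun matching(filter: (T) -> Boolean): Set<String> = users.filterValues(filter).keys

    /** One-to-one: just the given session, if it is still registered */
    fun single(sessionId: String): Set<String> =
        if (users.containsKey(sessionId)) setOf(sessionId) else emptySet()
}
```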
Sequence design

IWebSocketPublisher sequence diagram
Sequence diagram source (Mermaid)
sequenceDiagram
participant Client as Client
participant Publisher as IWebSocketPublisher
participant Scheduler as Scheduled cleanup thread

Note over Client,Publisher: 1. Connection setup and initialization
Client->>Publisher: WebSocket handshake, connection established
activate Publisher
Publisher-->>Publisher: afterConnectionEstablished(session)
Publisher-->>Publisher: onUpstreamOpen(session)
deactivate Publisher

Note over Client,Publisher: 2. First message (authentication)
Client->>Publisher: TextMessage(first message)
activate Publisher
Publisher-->>Publisher: handleMessage(session, message)
Publisher-->>Publisher: onUpstreamFirstMessage(session, message)
alt Authentication failed
Publisher-->>Client: onUpstreamFirstMessageIsNull → send error notice
Publisher-->>Client: session.close(POLICY_VIOLATION)
else Authentication succeeded
Publisher-->>Publisher: sessions[session.id] = WebSocketSenderSession(...)
Publisher-->>Publisher: onAuthSuccess(...)
end
deactivate Publisher

Note over Client,Publisher: 3. Subsequent business messages
Client->>Publisher: TextMessage(business message) or PongMessage(heartbeat reply)
activate Publisher
Publisher-->>Publisher: handleMessage
alt Heartbeat
Publisher-->>Publisher: Update lastHeartbeat
else Business message
Publisher-->>Publisher: onUpstreamMessage(...)
end
deactivate Publisher

Note over Scheduler,Publisher: 4. Session timeout cleanup
Scheduler->>Publisher: cleanupExpired()
activate Publisher
Publisher-->>Publisher: Close expired session → onSessionClosed
deactivate Publisher

Note over Publisher,Client: 5. Publish / broadcast / heartbeat
Publisher->>Client: publishAll/publishByFilter/publishSender
Publisher-->>Publisher: transformPublish(...)
Publisher-->>Client: sendMessage(transformed message)
Publisher->>Client: sendHeartbeat() → PingMessage()
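Neither base class calls sendHeartbeat() on its own, so something has to trigger step 5's Ping. In a Spring application a @Scheduled(fixedRate = ...) method together with @EnableScheduling is the usual route; the framework-free sketch below does the same with a JDK ScheduledExecutorService. HeartbeatSource is a hypothetical interface standing in for either class, and the period is up to you.

```kotlin
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit

/** Hypothetical stand-in for anything with a sendHeartbeat() method */
fun interface HeartbeatSource {
    fun sendHeartbeat()
}

/**
 * Calls [source].sendHeartbeat() every [periodMillis] ms. Failures are swallowed because an
 * exception escaping a scheduleAtFixedRate task cancels all later runs.
 * The caller owns the returned executor and must shut it down.
 */
fun scheduleHeartbeat(source: HeartbeatSource, periodMillis: Long): ScheduledExecutorService {
    val executor = Executors.newSingleThreadScheduledExecutor()
    executor.scheduleAtFixedRate(
        { runCatching { source.sendHeartbeat() } },
        periodMillis, periodMillis, TimeUnit.MILLISECONDS
    )
    return executor
}
```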

Interface code

import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger

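/**
 * Wrapper for a session and its user information
 */
data class WebSocketSenderSession<T>(
    val session: WebSocketSession,
    val user: T,
    /** Heartbeat timestamp (ms), compared against sessionTimeoutMillis by the timeout sweep */
    val lastHeartbeat: Long = System.currentTimeMillis()
)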

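/**
 * Abstract WebSocket publisher for building generic WebSocket services with user
 * authentication, heartbeat maintenance and message broadcasting.
 *
 * ### Usage
 *
 * #### Required
 * > Extend this class and implement the core abstract members:
 * - [registerPath]: the registered path, i.e. the WebSocket entry point
 * - [onUpstreamFirstMessage]: handles the client's first message, usually for authentication; the returned user identifies the session; if it returns null, the connection is closed
 * - [onUpstreamMessage]: handles the client's subsequent messages
 *
 * #### Optional overrides
 * - [onUpstreamOpen]: initialization callback when the connection is open but no message has arrived yet
 * - [onSessionClosed]: callback after the connection closes
 * - [onUpstreamFirstMessageIsNull]: callback when first-message authentication fails; sends an error message by default
 * - [onAuthSuccess]: callback after first-message authentication succeeds
 * - [transformPublish]: transforms a message before it is sent
 *
 * ### Session management
 * - Sessions are wrapped in [WebSocketSenderSession], which holds the `session`, the user and the heartbeat time
 * - Sessions inactive for 10 minutes are closed by default; adjust this with [sessionTimeoutMillis]
 *
 * ### Publishing
 * - [publishAll]: publish to all connections
 * - [publishByFilter]: publish to connections matching a filter
 * - [publishSender]: send to a single connection
 * - [sendHeartbeat]: send a Ping to all connections to keep them alive
 *
 * @param objectMapper Jackson mapper for JSON serialization/deserialization
 * @param T user type, supplied by [onUpstreamFirstMessage]
 * @author ThatCoder
 */
abstract class IWebSocketPublisher<T>(
    private val objectMapper: ObjectMapper
) : AbstractWebSocketHandler() {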

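    abstract val registerPath: String

    /** Session timeout in milliseconds, 10 minutes by default (open, so subclasses can change it) */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)

    /** All registered sessions, keyed by session id */
    val sessions = ConcurrentHashMap<String, WebSocketSenderSession<T>>()
    private val scheduler = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("session-timeout-"))

    init {
        // Sweep expired sessions once a minute. The period is a constant on purpose: an
        // overridden sessionTimeoutMillis is not yet initialized while this init block runs,
        // so using it here could pass a period of 0, which scheduleAtFixedRate rejects.
        scheduler.scheduleAtFixedRate(
            { runCatching { cleanupExpired() }.onFailure { logger.error("Session cleanup failed", it) } },
            1, 1, TimeUnit.MINUTES
        )
    }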

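    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Client connected: ${session.id}")
        onUpstreamOpen(session)
    }

    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        // First message: authenticate and register the session
        if (!sessions.containsKey(session.id)) {
            val user = onUpstreamFirstMessage(session, message)
            if (user == null) {
                onUpstreamFirstMessageIsNull(session)
                session.close(CloseStatus.POLICY_VIOLATION)
                return
            }
            // Broadcasts can run on several threads at once; the decorator serializes sends on
            // this connection (10 s send time limit and 512 KiB buffer are example values)
            val safe = org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator(session, 10_000, 512 * 1024)
            val sender = WebSocketSenderSession(safe, user)
            sessions[session.id] = sender
            onAuthSuccess(sender)
            logger.info("Session registered: ${session.id} -> $user")
            return
        }
        // Any traffic, including a Pong answering sendHeartbeat(), refreshes the heartbeat.
        // WebSocketSenderSession is immutable, so the entry is replaced with a refreshed copy
        val sender = sessions.computeIfPresent(session.id) { _, s ->
            s.copy(lastHeartbeat = System.currentTimeMillis())
        } ?: return
        // Pongs only keep the session alive; business messages go to the subclass
        if (message !is PongMessage) onUpstreamMessage(sender, message)
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Client closed: ${session.id} (${status.reason})")
        sessions.remove(session.id)?.let { onSessionClosed(it) }
    }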

    /**
     * Publish a message to every session
     * @param message the message
     */
    fun publishAll(message: WebSocketMessage<*>) {
        sessions.values.forEach { sender -> send(sender, message) }
    }

    /**
     * Publish to the sessions matching a filter
     * @param filter the filter
     * @param message the message
     */
    fun publishByFilter(filter: (WebSocketSenderSession<T>) -> Boolean, message: WebSocketMessage<*>) {
        sessions.values.filter(filter).forEach { send(it, message) }
    }

    /**
     * Send a message to one session
     * @param sender the session
     * @param message the message
     */
    fun publishSender(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        send(sender, message)
    }

    /** Send a Ping to every session; nothing in this class schedules it */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach {
            try { it.session.sendMessage(ping) } catch (ignored: Exception) {}
        }
    }

    // ========== Subclass extension points ===========

    /**
     * Called when the session is first created
     * @param session the session
     */
    protected open fun onUpstreamOpen(session: WebSocketSession) = Unit

    /**
     * Handle a message from an authenticated session
     * @param sender the session object
     * @param message the message
     */
    protected abstract fun onUpstreamMessage(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>)

    /**
     * Called when the session is closed
     * @param sender the session object
     */
    protected open fun onSessionClosed(sender: WebSocketSenderSession<T>) = Unit

    /**
     * Handle the session's first message
     *
     * Typically used to verify the user's credentials
     * @param session the session
     * @param message the message
     * @return the user information; returning null triggers onUpstreamFirstMessageIsNull
     * @see onUpstreamFirstMessageIsNull
     */
    protected abstract fun onUpstreamFirstMessage(
        session: WebSocketSession,
        message: WebSocketMessage<*>
    ): T?

    /**
     * Called when handling the first message returned null
     * @param session the session
     */
    protected open fun onUpstreamFirstMessageIsNull(session: WebSocketSession) {
        val err = mapOf("error" to "Authentication failed")
        session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /**
     * Called after successful authentication
     * @param sender the session
     */
    protected open fun onAuthSuccess(sender: WebSocketSenderSession<T>) = Unit

    /** Close sessions that have been silent for longer than [sessionTimeoutMillis] */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.values.filter { now - it.lastHeartbeat > sessionTimeoutMillis }
            .forEach { expired ->
                // Remove first, and only if the entry was not refreshed in the meantime, so that
                // afterConnectionClosed finds nothing and onSessionClosed runs exactly once
                if (!sessions.remove(expired.session.id, expired)) return@forEach
                runCatching { expired.session.close(CloseStatus.SESSION_NOT_RELIABLE) }
                onSessionClosed(expired)
                logger.info("Session timeout removed: ${expired.session.id}")
            }
    }

    private fun send(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        if (!sender.session.isOpen) return
        try {
            sender.session.sendMessage(transformPublish(message, sender))
        } catch (e: Exception) {
            logger.error("Publish to ${sender.session.id} failed", e)
        }
    }

    /** Per-recipient transformation applied before sending; identity by default */
    protected open fun transformPublish(
        message: WebSocketMessage<*>,
        sender: WebSocketSenderSession<T>
    ): WebSocketMessage<*> = message

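    /** Thread factory that gives threads readable, numbered names */
    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)

        // Build the name per thread; computing it once in a property gave every thread the same name
        override fun newThread(r: Runnable): Thread = Thread(r, prefix + cnt.getAndIncrement()).apply {
            // Daemon threads do not keep the JVM alive on shutdown
            isDaemon = true
        }
    }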
}
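Because WebSocketSenderSession is immutable, refreshing a heartbeat means replacing the map entry with a copy. Doing that with computeIfPresent, and removing expired entries with remove(key, value), keeps a sweep from evicting a session that was refreshed after the sweep took its snapshot. A standalone sketch of that pattern, with a hypothetical Tracked record in place of WebSocketSenderSession and timestamps passed in explicitly:

```kotlin
import java.util.concurrent.ConcurrentHashMap

/** Hypothetical immutable record standing in for WebSocketSenderSession */
data class Tracked(val id: String, val lastHeartbeat: Long)

/** Hypothetical session table: copy-on-refresh plus a race-safe sweep */
class HeartbeatTable(private val timeoutMillis: Long) {
    private val table = ConcurrentHashMap<String, Tracked>()

    fun register(id: String, now: Long) {
        table[id] = Tracked(id, now)
    }

    /** Replace the entry with a refreshed copy; null if the id is not registered */
    fun refresh(id: String, now: Long): Tracked? =
        table.computeIfPresent(id) { _, t -> t.copy(lastHeartbeat = now) }

    /**
     * Remove entries idle for longer than timeoutMillis and return their ids.
     * remove(key, value) only succeeds while the entry is still the stale copy from the
     * snapshot, so a session refreshed in the meantime survives.
     */
    fun sweep(now: Long): List<String> =
        table.values.toList()
            .filter { now - it.lastHeartbeat > timeoutMillis }
            .filter { table.remove(it.id, it) }
            .map { it.id }

    fun contains(id: String): Boolean = table.containsKey(id)
}
```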

Implementation example

The example implements a chat room and demonstrates both one-to-one and group messages.

One-to-one mode works too: use only publishSender, and the publisher behaves as a one-to-one service.

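  • Once it is running, open several tabs of a WebSocket test page to try it out.
  • Connect to ws://localhost:8080/ws/chat and send a body such as {"ak": "123456", "message": "我是卢本伟", "name": "卢本伟"} to authenticate and join the room.
  • After joining there is no need to send ak again, because the session id is already mapped to a user; just send {"message": "欢迎来到卢本伟广场"}.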
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import org.springframework.stereotype.Component
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.WebSocketSession

data class ChatUser(val userid: String, val name: String)

/**
 * Chat room publisher
 */
@Component
class ChatPublisher(private val objectMapper: ObjectMapper) : IWebSocketPublisher<ChatUser>(objectMapper) {

    override val registerPath = "/ws/chat"

    override fun onUpstreamFirstMessage(session: WebSocketSession, message: WebSocketMessage<*>): ChatUser? {
        // Named `node` so it does not shadow the `message` parameter
        val node = message.toObjectNode() ?: return null
        val ak = node.get("ak")?.asText() ?: return null
        val name = node.get("name")?.asText() ?: return null
        // Demo only: a hard-coded access key
        if (ak != "123456") return null
        // Create the user
        val user = ChatUser(session.id, name)
        // Send a welcome message to this user only
        session.sendMessage(TextMessage("Hi, $name. Please keep the chat friendly!"))
        // Announce the newcomer to everyone already in the room
        publishAll(TextMessage("$name has joined the chat room."))
        return user
    }

    override fun onSessionClosed(sender: WebSocketSenderSession<ChatUser>) {
        // Announce the departure to the room
        publishAll(TextMessage("${sender.user.name} has left the chat room."))
    }

    override fun onUpstreamMessage(sender: WebSocketSenderSession<ChatUser>, message: WebSocketMessage<*>) {
        // Broadcast the user's message to the room
        publishAll(objectMapper.createObjectNode().apply {
            put("type", "chat")
            putPOJO("user", sender.user)
            putPOJO("message", message.toObjectNode())
        }.toTextMessage())
    }
}
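Besides a browser test page, the join-then-chat flow described before the class can be driven from the JDK's built-in client (java.net.http.WebSocket, Java 11+). A sketch under stated assumptions: the app runs locally with ChatPublisher on /ws/chat, 123456 is the demo key hard-coded in the example, and jsonString, joinPayload, chatPayload and joinChat are hypothetical helpers (a real client would use a JSON library rather than hand-escaping).

```kotlin
import java.net.URI
import java.net.http.HttpClient
import java.net.http.WebSocket
import java.util.concurrent.CompletionStage

/** Minimal JSON string literal: escapes quotes, backslashes and control characters */
fun jsonString(s: String): String = buildString {
    append('"')
    for (c in s) {
        when {
            c == '"' -> append("\\\"")
            c == '\\' -> append("\\\\")
            c < ' ' -> append("\\u%04x".format(c.code))
            else -> append(c)
        }
    }
    append('"')
}

/** First message: credentials plus a greeting, like the join body in the bullets above */
fun joinPayload(ak: String, name: String, message: String): String =
    "{\"ak\":${jsonString(ak)},\"message\":${jsonString(message)},\"name\":${jsonString(name)}}"

/** Later messages only need "message" */
fun chatPayload(message: String): String = "{\"message\":${jsonString(message)}}"

/** Connects, sends the join message and prints whatever the server pushes back */
fun joinChat(uri: String, ak: String, name: String): WebSocket {
    val listener = object : WebSocket.Listener {
        override fun onText(webSocket: WebSocket, data: CharSequence, last: Boolean): CompletionStage<*>? {
            println("<- $data")
            webSocket.request(1) // ask for the next message
            return null
        }
    }
    val ws = HttpClient.newHttpClient().newWebSocketBuilder()
        .buildAsync(URI.create(uri), listener)
        .join()
    ws.sendText(joinPayload(ak, name, "hello"), true).join()
    return ws
}
```

For example, joinChat("ws://localhost:8080/ws/chat", "123456", "卢本伟") joins the room, after which ws.sendText(chatPayload("..."), true) posts further messages.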

Route registration

Both abstract classes expose registerPath, so Spring can collect every IWebSocketPublisher and IWebSocketProxier implementation and register their routes automatically.

WebSocketConfig
package cn.uwant.auto.config

import IWebSocketProxier
import IWebSocketPublisher
import jakarta.websocket.ContainerProvider
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.client.standard.StandardWebSocketClient
import org.springframework.web.socket.config.annotation.EnableWebSocket
import org.springframework.web.socket.config.annotation.WebSocketConfigurer
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry
import org.springframework.context.annotation.Lazy

/**
 * WebSocket configuration
 * @author ThatCoder
 */
@Configuration
@EnableWebSocket
class WebSocketConfig(
    // @Lazy breaks the construction cycle: the proxies need the WebSocketClient bean
    // that is defined by this same configuration class
    @Lazy private val proxies: List<IWebSocketProxier>,
    @Lazy private val publishers: List<IWebSocketPublisher<*>>
) : WebSocketConfigurer {
    override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) {
        // "*" accepts connections from any origin; list the real origins in production
        proxies.forEach {
            registry.addHandler(it, it.registerPath).setAllowedOrigins("*")
        }
        publishers.forEach {
            registry.addHandler(it, it.registerPath).setAllowedOrigins("*")
        }
    }

    @Bean
    fun webSocketClient(): WebSocketClient {
        val container = ContainerProvider.getWebSocketContainer()
        return StandardWebSocketClient(container)
    }
}
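Because each implementation supplies its own registerPath, two beans that accidentally declare the same path are easy to overlook. A small startup check is one option. duplicatePaths and requireUniquePaths are hypothetical helpers; inside registerWebSocketHandlers they could be called as requireUniquePaths(proxies.map { it.registerPath } + publishers.map { it.registerPath }) before any handler is added.

```kotlin
/** Returns every path that appears more than once */
fun duplicatePaths(paths: List<String>): Set<String> =
    paths.groupingBy { it }.eachCount().filterValues { it > 1 }.keys

/** Fails fast at startup instead of registering two handlers on one path */
fun requireUniquePaths(paths: List<String>) {
    val dups = duplicatePaths(paths)
    require(dups.isEmpty()) { "Duplicate WebSocket paths: $dups" }
}
```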

Related errors

See the BUG series.


This site was built by 钟意 with the Stellar 1.28.1 theme.
又拍云 provides CDN acceleration and cloud storage.
Vercel provides hosting.
湘ICP备2023019799号-1