|
|
@@ -0,0 +1,428 @@
|
|
|
+package com.storlead.ai.websocket;
|
|
|
+import cn.hutool.core.util.StrUtil;
|
|
|
+import com.storlead.ai.core.AiProviderType;
|
|
|
+import com.storlead.ai.core.ChatRequest;
|
|
|
+import com.storlead.ai.service.AiChatService;
|
|
|
+import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+import org.springframework.web.socket.*;
|
|
|
+import reactor.core.Disposable;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+
|
|
|
+
|
|
|
+/**
|
|
|
+ * 修复后的WebSocket处理器
|
|
|
+ *
|
|
|
+ * @author your-name
|
|
|
+ * @since 1.0.0
|
|
|
+ */
|
|
|
+@Component
|
|
|
+public class AiChatWebSocketHandler implements WebSocketHandler {
|
|
|
+
|
|
|
+ private static final Logger logger = LoggerFactory.getLogger(AiChatWebSocketHandler.class);
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private AiChatService aiChatService;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private WebSocketConnectionManager connectionManager;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private WebSocketSessionManager sessionManager;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private ObjectMapper objectMapper;
|
|
|
+
|
|
|
+ private final Map<String, String> sessionProvider = new ConcurrentHashMap<>();
|
|
|
+ private final Map<String, String> sessionModel = new ConcurrentHashMap<>();
|
|
|
+ @Override
|
|
|
+ public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
|
|
+ String sessionId = session.getId();
|
|
|
+ String userId = extractUserId(session);
|
|
|
+ String provider = extractProvider(session);
|
|
|
+ String model = extractModel(session);
|
|
|
+ sessionProvider.put(sessionId,provider);
|
|
|
+ sessionModel.put(sessionId,model);
|
|
|
+ connectionManager.addConnection(sessionId, session);
|
|
|
+ sessionManager.createSession(sessionId, userId);
|
|
|
+
|
|
|
+ logger.info("WebSocket连接建立: {} (用户: {})", sessionId, userId);
|
|
|
+
|
|
|
+ // 发送连接成功消息 - 修复方法调用
|
|
|
+ sendMessage(session, WebSocketMessage.systemMessage(Map.of(
|
|
|
+ "event", "connected",
|
|
|
+ "sessionId", sessionId,
|
|
|
+ "userId", userId,
|
|
|
+ "supportedProviders", aiChatService.getSupportedProviders().stream()
|
|
|
+ .map(type -> Map.of(
|
|
|
+ "code", type.getCode(),
|
|
|
+ "name", type.getDisplayName(),
|
|
|
+ "available", aiChatService.isProviderAvailable(type)
|
|
|
+ ))
|
|
|
+ .toList()
|
|
|
+ )));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void handleMessage(WebSocketSession session, org.springframework.web.socket.WebSocketMessage<?> payload) throws Exception {
|
|
|
+ String sessionId = session.getId();
|
|
|
+
|
|
|
+ try {
|
|
|
+
|
|
|
+ String provider = sessionProvider.get(sessionId);
|
|
|
+ String model = sessionModel.get(sessionId);
|
|
|
+ String message = payload.getPayload().toString();
|
|
|
+
|
|
|
+// String messageType = messageNode.get("type").asText();
|
|
|
+// logger.debug("收到消息: {} - {}", sessionId, messageType);
|
|
|
+ handleChatRequest(session, message,provider,model);
|
|
|
+// switch (messageType) {
|
|
|
+// case "CHAT_REQUEST" -> handleChatRequest(session, message,provider,message);
|
|
|
+// case "CANCEL_CHAT" -> handleCancelChat(session);
|
|
|
+// case "PING" -> handlePing(session);
|
|
|
+// case "PROVIDER_SWITCH" -> handleProviderSwitch(session, provider);
|
|
|
+// case "CONNECTION_TEST" -> handleConnectionTest(session);
|
|
|
+// default -> sendError(session, "未知消息类型: " + messageType);
|
|
|
+// }
|
|
|
+
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("处理消息失败: {}", sessionId, e);
|
|
|
+ sendError(session, "消息处理失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
|
|
|
+ String sessionId = session.getId();
|
|
|
+ logger.error("WebSocket传输错误: {}", sessionId, exception);
|
|
|
+ cleanup(sessionId);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
|
|
|
+ String sessionId = session.getId();
|
|
|
+ logger.info("WebSocket连接关闭: {} - {}", sessionId, closeStatus);
|
|
|
+ cleanup(sessionId);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean supportsPartialMessages() {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理聊天请求
|
|
|
+ */
|
|
|
+ private void handleChatRequest(WebSocketSession session,String message,String provider,String model) {
|
|
|
+ String sessionId = session.getId();
|
|
|
+
|
|
|
+ // 检查是否正在聊天中
|
|
|
+ if (sessionManager.isChatting(sessionId)) {
|
|
|
+ sendError(session, "正在聊天中,请等待完成或先取消当前对话");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ try {
|
|
|
+ // 验证必要字段
|
|
|
+ if (StrUtil.isBlank(message)) {
|
|
|
+ sendError(session, "消息内容不能为空");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 构建聊天请求
|
|
|
+ ChatRequest.Builder requestBuilder = ChatRequest.builder().message(message);
|
|
|
+
|
|
|
+// // 可选参数
|
|
|
+// if (messageNode.has("systemPrompt") && !messageNode.get("systemPrompt").asText().trim().isEmpty()) {
|
|
|
+// requestBuilder.systemPrompt(messageNode.get("systemPrompt").asText());
|
|
|
+// }
|
|
|
+
|
|
|
+ if (StrUtil.isNotBlank(model)) {
|
|
|
+ requestBuilder.model(model);
|
|
|
+ }
|
|
|
+
|
|
|
+ ChatRequest chatRequest = requestBuilder.build();
|
|
|
+
|
|
|
+ // 验证AI提供商
|
|
|
+ AiProviderType providerType;
|
|
|
+ try {
|
|
|
+ providerType = AiProviderType.fromCode(provider);
|
|
|
+ } catch (IllegalArgumentException e) {
|
|
|
+ sendError(session, "不支持的AI提供商: " + provider);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 检查提供商是否可用
|
|
|
+ if (!aiChatService.isProviderAvailable(providerType)) {
|
|
|
+ sendError(session, providerType.getDisplayName() + " 服务当前不可用");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 开始流式聊天
|
|
|
+ startStreamChat(session, sessionId, providerType, chatRequest);
|
|
|
+
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("处理聊天请求失败: {}", sessionId, e);
|
|
|
+ sendError(session, "请求处理失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 开始流式聊天
|
|
|
+ */
|
|
|
+ private void startStreamChat(WebSocketSession session, String sessionId,
|
|
|
+ AiProviderType providerType, ChatRequest chatRequest) {
|
|
|
+
|
|
|
+ logger.debug("开始流式聊天: {} - 提供商: {}", sessionId, providerType.getDisplayName());
|
|
|
+
|
|
|
+ // 订阅流式响应 - 修复方法调用
|
|
|
+ Disposable subscription = aiChatService.chatStream(providerType, chatRequest)
|
|
|
+ .doOnSubscribe(sub -> {
|
|
|
+ logger.debug("流式聊天订阅成功: {}", sessionId);
|
|
|
+ })
|
|
|
+ .subscribe(
|
|
|
+ // 处理每个数据块 - 修复方法调用
|
|
|
+ chunk -> {
|
|
|
+ if (chunk != null && !chunk.trim().isEmpty()) {
|
|
|
+ sendMessage(session, WebSocketMessage.chatChunk(chunk));
|
|
|
+ }
|
|
|
+ },
|
|
|
+ // 处理错误
|
|
|
+ error -> {
|
|
|
+ logger.error("流式聊天错误: {} - 提供商: {}", sessionId, providerType.getDisplayName(), error);
|
|
|
+ sendError(session, "聊天失败: " + error.getMessage());
|
|
|
+ sessionManager.cancelStreamSubscription(sessionId);
|
|
|
+ },
|
|
|
+ // 处理完成 - 修复方法调用
|
|
|
+ () -> {
|
|
|
+ logger.debug("流式聊天完成: {} - 提供商: {}", sessionId, providerType.getDisplayName());
|
|
|
+ sendMessage(session, WebSocketMessage.chatComplete());
|
|
|
+ sessionManager.cancelStreamSubscription(sessionId);
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ // 保存订阅以便后续取消
|
|
|
+ sessionManager.setStreamSubscription(sessionId, subscription);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理取消聊天
|
|
|
+ */
|
|
|
+ private void handleCancelChat(WebSocketSession session) {
|
|
|
+ String sessionId = session.getId();
|
|
|
+
|
|
|
+ if (!sessionManager.isChatting(sessionId)) {
|
|
|
+ sendMessage(session, WebSocketMessage.systemMessage("当前没有进行中的聊天"));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ sessionManager.cancelStreamSubscription(sessionId);
|
|
|
+ sendMessage(session, WebSocketMessage.systemMessage("聊天已取消"));
|
|
|
+ logger.debug("用户取消聊天: {}", sessionId);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理提供商切换 - 修复方法调用
|
|
|
+ */
|
|
|
+ private void handleProviderSwitch(WebSocketSession session,String provider) {
|
|
|
+ try {
|
|
|
+ if (StrUtil.isBlank(provider)) {
|
|
|
+ sendError(session, "缺少provider参数");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ AiProviderType providerType = AiProviderType.fromCode(provider);
|
|
|
+
|
|
|
+ boolean available = aiChatService.isProviderAvailable(providerType);
|
|
|
+
|
|
|
+ // 修复方法调用
|
|
|
+ sendMessage(session, WebSocketMessage.systemMessage(Map.of(
|
|
|
+ "event", "provider_status",
|
|
|
+ "provider", provider,
|
|
|
+ "available", available,
|
|
|
+ "models", aiChatService.getSupportedModels(providerType)
|
|
|
+ )));
|
|
|
+
|
|
|
+ } catch (IllegalArgumentException e) {
|
|
|
+ sendError(session, "无效的提供商: " + provider);
|
|
|
+ } catch (Exception e) {
|
|
|
+ sendError(session, "提供商切换失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理连接测试 - 修复方法调用
|
|
|
+ */
|
|
|
+ private void handleConnectionTest(WebSocketSession session) {
|
|
|
+ sendMessage(session, WebSocketMessage.systemMessage(Map.of(
|
|
|
+ "event", "connection_test",
|
|
|
+ "status", "ok",
|
|
|
+ "timestamp", System.currentTimeMillis(),
|
|
|
+ "message", "连接正常"
|
|
|
+ )));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理心跳包
|
|
|
+ */
|
|
|
+ private void handlePing(WebSocketSession session) {
|
|
|
+ sendMessage(session, WebSocketMessage.pong());
|
|
|
+ logger.debug("心跳响应: {}", session.getId());
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送消息到WebSocket客户端
|
|
|
+ */
|
|
|
+ private void sendMessage(WebSocketSession session, WebSocketMessage message) {
|
|
|
+ if (session == null || !session.isOpen()) {
|
|
|
+ logger.warn("尝试向已关闭的WebSocket会话发送消息: {}",
|
|
|
+ session != null ? session.getId() : "null");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ try {
|
|
|
+ String json = objectMapper.writeValueAsString(message);
|
|
|
+ session.sendMessage(new TextMessage(json));
|
|
|
+ logger.trace("发送WebSocket消息成功: {} - {}", session.getId(), message.getType());
|
|
|
+ } catch (IOException e) {
|
|
|
+ logger.error("发送WebSocket消息失败: {}", session.getId(), e);
|
|
|
+ // 如果发送失败,可能连接已断开,执行清理
|
|
|
+ cleanup(session.getId());
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("序列化WebSocket消息失败: {}", session.getId(), e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送错误消息
|
|
|
+ */
|
|
|
+ private void sendError(WebSocketSession session, String errorMessage) {
|
|
|
+ logger.debug("发送错误消息: {} - {}", session.getId(), errorMessage);
|
|
|
+ sendMessage(session, WebSocketMessage.chatError(errorMessage));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 从WebSocket会话中提取用户ID
|
|
|
+ */
|
|
|
+ private String extractProvider(WebSocketSession session) {
|
|
|
+ try {
|
|
|
+ String query = session.getUri().getQuery();
|
|
|
+ if (query != null && query.contains("provider=")) {
|
|
|
+ String[] parts = query.split("provider=");
|
|
|
+ if (parts.length > 1) {
|
|
|
+ String userId = parts[1].split("&")[0];
|
|
|
+ return java.net.URLDecoder.decode(userId, "UTF-8");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.debug("not found provider : {}", session.getId(), e);
|
|
|
+ }
|
|
|
+ // 如果无法提取用户ID,使用会话ID生成一个
|
|
|
+ return "Storlead-Ai";
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 从WebSocket会话中提取用户ID
|
|
|
+ */
|
|
|
+ private String extractModel(WebSocketSession session) {
|
|
|
+ try {
|
|
|
+ String query = session.getUri().getQuery();
|
|
|
+ if (query != null && query.contains("provider=")) {
|
|
|
+ String[] parts = query.split("provider=");
|
|
|
+ if (parts.length > 1) {
|
|
|
+ String userId = parts[1].split("&")[0];
|
|
|
+ return java.net.URLDecoder.decode(userId, "UTF-8");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.debug("not found provider : {}", session.getId(), e);
|
|
|
+ }
|
|
|
+ // 如果无法提取用户ID,使用会话ID生成一个
|
|
|
+ return "all";
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 从WebSocket会话中提取用户ID
|
|
|
+ */
|
|
|
+ private String extractUserId(WebSocketSession session) {
|
|
|
+ try {
|
|
|
+ String query = session.getUri().getQuery();
|
|
|
+ if (query != null && query.contains("userId=")) {
|
|
|
+ String[] parts = query.split("userId=");
|
|
|
+ if (parts.length > 1) {
|
|
|
+ String userId = parts[1].split("&")[0];
|
|
|
+ return java.net.URLDecoder.decode(userId, "UTF-8");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.debug("提取用户ID失败: {}", session.getId(), e);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 如果无法提取用户ID,使用会话ID生成一个
|
|
|
+ return "user_" + session.getId().substring(0, 8);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 清理会话资源
|
|
|
+ */
|
|
|
+ private void cleanup(String sessionId) {
|
|
|
+ try {
|
|
|
+ connectionManager.removeConnection(sessionId);
|
|
|
+ sessionManager.removeSession(sessionId);
|
|
|
+ sessionProvider.remove(sessionId);
|
|
|
+ sessionModel.remove(sessionId);
|
|
|
+ logger.debug("清理会话资源完成: {}", sessionId);
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("清理会话资源失败: {}", sessionId, e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 广播消息给所有连接的客户端
|
|
|
+ */
|
|
|
+ public void broadcast(WebSocketMessage message) {
|
|
|
+ Map<String, WebSocketSession> connections = connectionManager.getAllConnections();
|
|
|
+ int successCount = 0;
|
|
|
+ int failCount = 0;
|
|
|
+
|
|
|
+ for (Map.Entry<String, WebSocketSession> entry : connections.entrySet()) {
|
|
|
+ try {
|
|
|
+ sendMessage(entry.getValue(), message);
|
|
|
+ successCount++;
|
|
|
+ } catch (Exception e) {
|
|
|
+ failCount++;
|
|
|
+ logger.warn("广播消息失败: {}", entry.getKey(), e);
|
|
|
+ // 清理失败的连接
|
|
|
+ cleanup(entry.getKey());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.debug("广播消息完成 - 成功: {}, 失败: {}", successCount, failCount);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取当前连接数
|
|
|
+ */
|
|
|
+ public int getConnectionCount() {
|
|
|
+ return connectionManager.getConnectionCount();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取会话统计信息
|
|
|
+ */
|
|
|
+ public Map<String, Object> getSessionStats() {
|
|
|
+ return Map.of(
|
|
|
+ "totalConnections", connectionManager.getConnectionCount(),
|
|
|
+ "sessionStats", sessionManager.getSessionStats()
|
|
|
+ );
|
|
|
+ }
|
|
|
+}
|