| 
					
				 | 
			
			
				@@ -0,0 +1,428 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+package com.storlead.ai.websocket; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import cn.hutool.core.util.StrUtil; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.storlead.ai.core.AiProviderType; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.storlead.ai.core.ChatRequest; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.storlead.ai.service.AiChatService; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.fasterxml.jackson.databind.ObjectMapper; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.slf4j.Logger; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.slf4j.LoggerFactory; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.springframework.beans.factory.annotation.Autowired; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.springframework.stereotype.Component; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.springframework.web.socket.*; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import reactor.core.Disposable; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.io.IOException; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.Map; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.concurrent.ConcurrentHashMap; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+/** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * 修复后的WebSocket处理器 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * @author your-name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * @since 1.0.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+@Component 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+public class AiChatWebSocketHandler implements WebSocketHandler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private static final Logger logger = LoggerFactory.getLogger(AiChatWebSocketHandler.class); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Autowired 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private AiChatService aiChatService; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Autowired 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private WebSocketConnectionManager connectionManager; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Autowired 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private WebSocketSessionManager sessionManager; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Autowired 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private ObjectMapper objectMapper; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private final Map<String, String> sessionProvider = new ConcurrentHashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private final Map<String, String> sessionModel = new ConcurrentHashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public void afterConnectionEstablished(WebSocketSession session) throws Exception { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String userId = extractUserId(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String provider = extractProvider(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String model = extractModel(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sessionProvider.put(sessionId,provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sessionModel.put(sessionId,model); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        connectionManager.addConnection(sessionId, session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sessionManager.createSession(sessionId, userId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info("WebSocket连接建立: {} (用户: {})", sessionId, userId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 发送连接成功消息 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sendMessage(session, WebSocketMessage.systemMessage(Map.of( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "event", "connected", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "sessionId", sessionId, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "userId", userId, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "supportedProviders", aiChatService.getSupportedProviders().stream() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        .map(type -> Map.of( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                "code", type.getCode(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                "name", type.getDisplayName(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                "available", aiChatService.isProviderAvailable(type) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        .toList() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ))); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public void handleMessage(WebSocketSession session, org.springframework.web.socket.WebSocketMessage<?> payload) throws Exception { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String provider = sessionProvider.get(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String model = sessionModel.get(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String message = payload.getPayload().toString(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            String messageType = messageNode.get("type").asText(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            logger.debug("收到消息: {} - {}", sessionId, messageType); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            handleChatRequest(session, message,provider,model); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            switch (messageType) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                case "CHAT_REQUEST" -> handleChatRequest(session, message,provider,message); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                case "CANCEL_CHAT" -> handleCancelChat(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                case "PING" -> handlePing(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                case "PROVIDER_SWITCH" -> handleProviderSwitch(session, provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                case "CONNECTION_TEST" -> handleConnectionTest(session); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                default -> sendError(session, "未知消息类型: " + messageType); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.error("处理消息失败: {}", sessionId, e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendError(session, "消息处理失败: " + e.getMessage()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.error("WebSocket传输错误: {}", sessionId, exception); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        cleanup(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info("WebSocket连接关闭: {} - {}", sessionId, closeStatus); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        cleanup(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public boolean supportsPartialMessages() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return false; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 处理聊天请求 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void handleChatRequest(WebSocketSession session,String message,String provider,String model) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 检查是否正在聊天中 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if (sessionManager.isChatting(sessionId)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendError(session, "正在聊天中,请等待完成或先取消当前对话"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 验证必要字段 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (StrUtil.isBlank(message)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sendError(session, "消息内容不能为空"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 构建聊天请求 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ChatRequest.Builder requestBuilder = ChatRequest.builder().message(message); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            // 可选参数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            if (messageNode.has("systemPrompt") && !messageNode.get("systemPrompt").asText().trim().isEmpty()) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                requestBuilder.systemPrompt(messageNode.get("systemPrompt").asText()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (StrUtil.isNotBlank(model)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                requestBuilder.model(model); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ChatRequest chatRequest = requestBuilder.build(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 验证AI提供商 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            AiProviderType providerType; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                providerType = AiProviderType.fromCode(provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } catch (IllegalArgumentException e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sendError(session, "不支持的AI提供商: " + provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 检查提供商是否可用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (!aiChatService.isProviderAvailable(providerType)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sendError(session, providerType.getDisplayName() + " 服务当前不可用"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 开始流式聊天 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            startStreamChat(session, sessionId, providerType, chatRequest); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.error("处理聊天请求失败: {}", sessionId, e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendError(session, "请求处理失败: " + e.getMessage()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 开始流式聊天 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void startStreamChat(WebSocketSession session, String sessionId, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                 AiProviderType providerType, ChatRequest chatRequest) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.debug("开始流式聊天: {} - 提供商: {}", sessionId, providerType.getDisplayName()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 订阅流式响应 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Disposable subscription = aiChatService.chatStream(providerType, chatRequest) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .doOnSubscribe(sub -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    logger.debug("流式聊天订阅成功: {}", sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                }) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .subscribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 处理每个数据块 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        chunk -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            if (chunk != null && !chunk.trim().isEmpty()) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                sendMessage(session, WebSocketMessage.chatChunk(chunk)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 处理错误 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        error -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            logger.error("流式聊天错误: {} - 提供商: {}", sessionId, providerType.getDisplayName(), error); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            sendError(session, "聊天失败: " + error.getMessage()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            sessionManager.cancelStreamSubscription(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 处理完成 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        () -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            logger.debug("流式聊天完成: {} - 提供商: {}", sessionId, providerType.getDisplayName()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            sendMessage(session, WebSocketMessage.chatComplete()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            sessionManager.cancelStreamSubscription(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 保存订阅以便后续取消 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sessionManager.setStreamSubscription(sessionId, subscription); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 处理取消聊天 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void handleCancelChat(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String sessionId = session.getId(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if (!sessionManager.isChatting(sessionId)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendMessage(session, WebSocketMessage.systemMessage("当前没有进行中的聊天")); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sessionManager.cancelStreamSubscription(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sendMessage(session, WebSocketMessage.systemMessage("聊天已取消")); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.debug("用户取消聊天: {}", sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 处理提供商切换 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void handleProviderSwitch(WebSocketSession session,String provider) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (StrUtil.isBlank(provider)) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sendError(session, "缺少provider参数"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            AiProviderType providerType = AiProviderType.fromCode(provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            boolean available = aiChatService.isProviderAvailable(providerType); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendMessage(session, WebSocketMessage.systemMessage(Map.of( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "event", "provider_status", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "provider", provider, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "available", available, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "models", aiChatService.getSupportedModels(providerType) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ))); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (IllegalArgumentException e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendError(session, "无效的提供商: " + provider); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sendError(session, "提供商切换失败: " + e.getMessage()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 处理连接测试 - 修复方法调用 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void handleConnectionTest(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sendMessage(session, WebSocketMessage.systemMessage(Map.of( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "event", "connection_test", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "status", "ok", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "timestamp", System.currentTimeMillis(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "message", "连接正常" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ))); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 处理心跳包 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void handlePing(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sendMessage(session, WebSocketMessage.pong()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.debug("心跳响应: {}", session.getId()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 发送消息到WebSocket客户端 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void sendMessage(WebSocketSession session, WebSocketMessage message) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if (session == null || !session.isOpen()) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.warn("尝试向已关闭的WebSocket会话发送消息: {}", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    session != null ? session.getId() : "null"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String json = objectMapper.writeValueAsString(message); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            session.sendMessage(new TextMessage(json)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.trace("发送WebSocket消息成功: {} - {}", session.getId(), message.getType()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (IOException e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.error("发送WebSocket消息失败: {}", session.getId(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 如果发送失败,可能连接已断开,执行清理 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            cleanup(session.getId()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.error("序列化WebSocket消息失败: {}", session.getId(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 发送错误消息 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void sendError(WebSocketSession session, String errorMessage) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.debug("发送错误消息: {} - {}", session.getId(), errorMessage); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sendMessage(session, WebSocketMessage.chatError(errorMessage)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 从WebSocket会话中提取用户ID 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private String extractProvider(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String query = session.getUri().getQuery(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (query != null && query.contains("provider=")) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                String[] parts = query.split("provider="); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if (parts.length > 1) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String userId = parts[1].split("&")[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    return java.net.URLDecoder.decode(userId, "UTF-8"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }  catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.debug("not found provider : {}", session.getId(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 如果无法提取用户ID,使用会话ID生成一个 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return "Storlead-Ai"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 从WebSocket会话中提取用户ID 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private String extractModel(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String query = session.getUri().getQuery(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (query != null && query.contains("provider=")) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                String[] parts = query.split("provider="); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if (parts.length > 1) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String userId = parts[1].split("&")[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    return java.net.URLDecoder.decode(userId, "UTF-8"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.debug("not found provider : {}", session.getId(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 如果无法提取用户ID,使用会话ID生成一个 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return "all"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 从WebSocket会话中提取用户ID 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private String extractUserId(WebSocketSession session) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String query = session.getUri().getQuery(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (query != null && query.contains("userId=")) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                String[] parts = query.split("userId="); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if (parts.length > 1) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String userId = parts[1].split("&")[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    return java.net.URLDecoder.decode(userId, "UTF-8"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.debug("提取用户ID失败: {}", session.getId(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 如果无法提取用户ID,使用会话ID生成一个 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return "user_" + session.getId().substring(0, 8); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 清理会话资源 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private void cleanup(String sessionId) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            connectionManager.removeConnection(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sessionManager.removeSession(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sessionProvider.remove(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sessionModel.remove(sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.debug("清理会话资源完成: {}", sessionId); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.error("清理会话资源失败: {}", sessionId, e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 广播消息给所有连接的客户端 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public void broadcast(WebSocketMessage message) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Map<String, WebSocketSession> connections = connectionManager.getAllConnections(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        int successCount = 0; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        int failCount = 0; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for (Map.Entry<String, WebSocketSession> entry : connections.entrySet()) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sendMessage(entry.getValue(), message); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                successCount++; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                failCount++; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                logger.warn("广播消息失败: {}", entry.getKey(), e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                // 清理失败的连接 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                cleanup(entry.getKey()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.debug("广播消息完成 - 成功: {}, 失败: {}", successCount, failCount); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 获取当前连接数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public int getConnectionCount() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return connectionManager.getConnectionCount(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     * 获取会话统计信息 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public Map<String, Object> getSessionStats() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return Map.of( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "totalConnections", connectionManager.getConnectionCount(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "sessionStats", sessionManager.getSessionStats() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+} 
			 |