1811872455@163.com 1 月之前
父节点
当前提交
07a8f5c711
共有 21 个文件被更改,包括 1082 次插入49 次删除
  1. 104 49
      .idea/workspace.xml
  2. 6 0
      storlead-ai-api/pom.xml
  3. 71 0
      storlead-ai-api/src/main/java/com/storlead/ai/controller/WebSocketEndpointController.java
  4. 1 0
      storlead-ai-api/src/main/java/com/storlead/ai/core/AiProviderType.java
  5. 56 0
      storlead-ai-api/src/main/java/com/storlead/ai/service/WebSocketService.java
  6. 428 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/AiChatWebSocketHandler.java
  7. 19 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/MessageType.java
  8. 41 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConfig.java
  9. 82 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConnectionManager.java
  10. 134 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketMessage.java
  11. 140 0
      storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketSessionManager.java
  12. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/controller/WebSocketEndpointController.class
  13. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/service/WebSocketService.class
  14. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/service/impl/OpenAiService.class
  15. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/AiChatWebSocketHandler.class
  16. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/MessageType.class
  17. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketConfig.class
  18. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketConnectionManager.class
  19. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketMessage.class
  20. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketSessionManager$SessionInfo.class
  21. 二进制
      storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketSessionManager.class

+ 104 - 49
.idea/workspace.xml

@@ -4,28 +4,19 @@
     <option name="autoReloadType" value="SELECTIVE" />
   </component>
   <component name="ChangeListManager">
-    <list default="true" id="2dc641fc-1465-479e-874d-97069c194ded" name="Changes" comment="">
-      <change afterPath="$PROJECT_DIR$/pom.xml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/pom.xml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/AiPlatformApplication.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/config/properties/OpenAiProperties.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/controller/AiChatController.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/AiProviderType.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/AiService.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/ChatRequest.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/ChatResponse.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/exception/AiServiceException.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/factory/AiServiceFactory.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/factory/AiServiceFactoryManager.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/factory/impl/OpenAiServiceFactory.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/service/AiChatService.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/service/impl/OpenAiService.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/util/HttpClientUtil.java" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/resources/application-dev.yml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/resources/application-prod.yml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/resources/application-test.yml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/resources/application-uat.yml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/resources/application.yml" afterDir="false" />
+    <list default="true" id="2dc641fc-1465-479e-874d-97069c194ded" name="Changes" comment="ai项目">
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/controller/WebSocketEndpointController.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/service/WebSocketService.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/AiChatWebSocketHandler.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/MessageType.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConfig.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConnectionManager.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketMessage.java" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketSessionManager.java" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/storlead-ai-api/pom.xml" beforeDir="false" afterPath="$PROJECT_DIR$/storlead-ai-api/pom.xml" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/AiProviderType.java" beforeDir="false" afterPath="$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/core/AiProviderType.java" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/storlead-ai-api/target/classes/com/storlead/ai/service/impl/OpenAiService.class" beforeDir="false" afterPath="$PROJECT_DIR$/storlead-ai-api/target/classes/com/storlead/ai/service/impl/OpenAiService.class" afterDir="false" />
     </list>
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -35,6 +26,7 @@
   <component name="FileTemplateManagerImpl">
     <option name="RECENT_TEMPLATES">
       <list>
+        <option value="Enum" />
         <option value="Class" />
       </list>
     </option>
@@ -57,33 +49,33 @@
   <component name="ProjectViewState">
     <option name="showLibraryContents" value="true" />
   </component>
-  <component name="PropertiesComponent"><![CDATA[{
-  "keyToString": {
-    "RequestMappingsPanelOrder0": "0",
-    "RequestMappingsPanelOrder1": "1",
-    "RequestMappingsPanelWidth0": "75",
-    "RequestMappingsPanelWidth1": "75",
-    "RunOnceActivity.CodyAccountHistoryMigration": "true",
-    "RunOnceActivity.CodyConvertUrlToCodebaseName": "true",
-    "RunOnceActivity.CodyHistoryLlmMigration": "true",
-    "RunOnceActivity.CodyProjectSettingsMigration": "true",
-    "RunOnceActivity.OpenProjectViewOnStart": "true",
-    "RunOnceActivity.ShowReadmeOnStart": "true",
-    "WebServerToolWindowFactoryState": "false",
-    "jdk.selected.JAVA_MODULE": "ms-17",
-    "last_opened_file_path": "D:/chenkq-work/git/storlead-ai-platform/storlead-ai-api/src/main/resources",
-    "node.js.detected.package.eslint": "true",
-    "node.js.detected.package.tslint": "true",
-    "node.js.selected.package.eslint": "(autodetect)",
-    "node.js.selected.package.tslint": "(autodetect)",
-    "project.structure.last.edited": "Modules",
-    "project.structure.proportion": "0.0",
-    "project.structure.side.proportion": "0.0",
-    "settings.editor.selected.configurable": "MavenSettings",
-    "spring.configuration.checksum": "1259999022d84b755f5d094d936c15a0",
-    "vue.rearranger.settings.migration": "true"
+  <component name="PropertiesComponent">{
+  &quot;keyToString&quot;: {
+    &quot;RequestMappingsPanelOrder0&quot;: &quot;0&quot;,
+    &quot;RequestMappingsPanelOrder1&quot;: &quot;1&quot;,
+    &quot;RequestMappingsPanelWidth0&quot;: &quot;75&quot;,
+    &quot;RequestMappingsPanelWidth1&quot;: &quot;75&quot;,
+    &quot;RunOnceActivity.CodyAccountHistoryMigration&quot;: &quot;true&quot;,
+    &quot;RunOnceActivity.CodyConvertUrlToCodebaseName&quot;: &quot;true&quot;,
+    &quot;RunOnceActivity.CodyHistoryLlmMigration&quot;: &quot;true&quot;,
+    &quot;RunOnceActivity.CodyProjectSettingsMigration&quot;: &quot;true&quot;,
+    &quot;RunOnceActivity.OpenProjectViewOnStart&quot;: &quot;true&quot;,
+    &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
+    &quot;WebServerToolWindowFactoryState&quot;: &quot;false&quot;,
+    &quot;jdk.selected.JAVA_MODULE&quot;: &quot;ms-17&quot;,
+    &quot;last_opened_file_path&quot;: &quot;D:/chenkq-work/git/storlead-ai-platform/storlead-ai-api/src/main/resources&quot;,
+    &quot;node.js.detected.package.eslint&quot;: &quot;true&quot;,
+    &quot;node.js.detected.package.tslint&quot;: &quot;true&quot;,
+    &quot;node.js.selected.package.eslint&quot;: &quot;(autodetect)&quot;,
+    &quot;node.js.selected.package.tslint&quot;: &quot;(autodetect)&quot;,
+    &quot;project.structure.last.edited&quot;: &quot;Modules&quot;,
+    &quot;project.structure.proportion&quot;: &quot;0.0&quot;,
+    &quot;project.structure.side.proportion&quot;: &quot;0.0&quot;,
+    &quot;settings.editor.selected.configurable&quot;: &quot;MavenSettings&quot;,
+    &quot;spring.configuration.checksum&quot;: &quot;1259999022d84b755f5d094d936c15a0&quot;,
+    &quot;vue.rearranger.settings.migration&quot;: &quot;true&quot;
   }
-}]]></component>
+}</component>
   <component name="ReactorSettings">
     <option name="notificationShown" value="true" />
   </component>
@@ -128,13 +120,76 @@
       <workItem from="1756176360247" duration="6014000" />
       <workItem from="1756347439468" duration="3161000" />
       <workItem from="1757642133544" duration="17020000" />
-      <workItem from="1758250810392" duration="10129000" />
+      <workItem from="1758250810392" duration="10282000" />
+      <workItem from="1758505253719" duration="15768000" />
+      <workItem from="1758591153350" duration="8846000" />
+      <workItem from="1758780969521" duration="11096000" />
     </task>
+    <task id="LOCAL-00001" summary="ai项目">
+      <option name="closed" value="true" />
+      <created>1758276829909</created>
+      <option name="number" value="00001" />
+      <option name="presentableId" value="LOCAL-00001" />
+      <option name="project" value="LOCAL" />
+      <updated>1758276829909</updated>
+    </task>
+    <option name="localTasksCounter" value="2" />
     <servers />
   </component>
   <component name="TypeScriptGeneratedFilesManager">
     <option name="version" value="3" />
   </component>
+  <component name="Vcs.Log.Tabs.Properties">
+    <option name="OPEN_GENERIC_TABS">
+      <map>
+        <entry key="acbd7e43-8d6f-44e7-8543-30ca2008e0a7" value="TOOL_WINDOW" />
+      </map>
+    </option>
+    <option name="TAB_STATES">
+      <map>
+        <entry key="acbd7e43-8d6f-44e7-8543-30ca2008e0a7">
+          <value>
+            <State>
+              <option name="FILTERS">
+                <map>
+                  <entry key="branch">
+                    <value>
+                      <list>
+                        <option value="HEAD" />
+                      </list>
+                    </value>
+                  </entry>
+                  <entry key="structure">
+                    <value>
+                      <list>
+                        <option value="dir:D:/chenkq-work/git/storlead-ai-platform/storlead-ai-api/src/main/java/com/storlead/ai/controller" />
+                      </list>
+                    </value>
+                  </entry>
+                </map>
+              </option>
+              <option name="SHOW_ONLY_AFFECTED_CHANGES" value="true" />
+            </State>
+          </value>
+        </entry>
+      </map>
+    </option>
+  </component>
+  <component name="VcsManagerConfiguration">
+    <MESSAGE value="ai项目" />
+    <option name="LAST_COMMIT_MESSAGE" value="ai项目" />
+  </component>
+  <component name="XDebuggerManager">
+    <breakpoint-manager>
+      <breakpoints>
+        <line-breakpoint enabled="true" type="java-line">
+          <url>file://$PROJECT_DIR$/storlead-ai-api/src/main/java/com/storlead/ai/websocket/AiChatWebSocketHandler.java</url>
+          <line>1</line>
+          <option name="timeStamp" value="13" />
+        </line-breakpoint>
+      </breakpoints>
+    </breakpoint-manager>
+  </component>
   <component name="XSLT-Support.FileAssociations.UIState">
     <expand />
     <select />

+ 6 - 0
storlead-ai-api/pom.xml

@@ -62,6 +62,12 @@
             <scope>test</scope>
         </dependency>
 
+        <!-- Spring Boot WebSocket -->
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-starter-websocket</artifactId>
+        </dependency>
+
         <!-- Spring Boot WebFlux (响应式编程) -->
         <dependency>
             <groupId>org.springframework.boot</groupId>

+ 71 - 0
storlead-ai-api/src/main/java/com/storlead/ai/controller/WebSocketEndpointController.java

@@ -0,0 +1,71 @@
+package com.storlead.ai.controller;
+
+
+import com.storlead.ai.service.WebSocketService;
+import io.swagger.v3.oas.annotations.Operation;
+import io.swagger.v3.oas.annotations.tags.Tag;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Controller;
+import org.springframework.web.bind.annotation.*;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+/**
+ * @program: WebSocket端点控制器
+ * @description: 提供WebSocket连接信息和测试页面
+ * @author: chenkq
+ * @create: 2025-09-25 14:42
+ */
+@Controller
+@Tag(name = "WebSocket端点", description = "WebSocket连接和测试相关")
+public class WebSocketEndpointController {
+
+    @Autowired
+    private WebSocketService webSocketService;
+
+    /**
+     * WebSocket测试页面
+     */
+    @GetMapping("/websocket")
+    @Operation(summary = "WebSocket测试页面", description = "提供WebSocket连接测试的HTML页面")
+    public String websocketTestPage() {
+        return "websocket-test"; // 返回静态HTML页面
+    }
+
+    /**
+     * 获取WebSocket端点信息
+     */
+    @GetMapping("/api/v1/websocket/endpoints")
+    @ResponseBody
+    @Operation(summary = "获取WebSocket端点", description = "返回可用的WebSocket连接端点")
+    public Mono<Map<String, Object>> getWebSocketEndpoints() {
+        return Mono.fromCallable(() -> {
+            String baseUrl = "ws://localhost:8080"; // 在生产环境中应该动态获取
+
+            return Map.of(
+                    "endpoints", Map.of(
+                            "websocket", baseUrl + "/ws/ai/chat",
+                            "websocket_native", baseUrl + "/ws/ai/chat/native",
+                            "sockjs", baseUrl + "/ws/ai/chat"
+                    ),
+                    "usage", Map.of(
+                            "connect", "ws://localhost:8080/ws/ai/chat?userId=your_user_id",
+                            "protocols", java.util.List.of("websocket", "sockjs"),
+                            "example", Map.of(
+                                    "url", "ws://localhost:8080/ws/ai/chat?userId=test123",
+                                    "message", Map.of(
+                                            "type", "CHAT_REQUEST",
+                                            "message", "你好",
+                                            "provider", "openai"
+                                    )
+                            )
+                    ),
+                    "status", Map.of(
+                            "active_connections", webSocketService.getConnectionCount(),
+                            "available", true
+                    )
+            );
+        });
+    }
+}

+ 1 - 0
storlead-ai-api/src/main/java/com/storlead/ai/core/AiProviderType.java

@@ -8,6 +8,7 @@ package com.storlead.ai.core;
  */
 public enum AiProviderType {
 
+    STROLEAAI("storleadai", "Ollama"),
     /**
      * OpenAI GPT系列
      */

+ 56 - 0
storlead-ai-api/src/main/java/com/storlead/ai/service/WebSocketService.java

@@ -0,0 +1,56 @@
+package com.storlead.ai.service;
+
+import com.storlead.ai.websocket.AiChatWebSocketHandler;
+import com.storlead.ai.websocket.WebSocketMessage;
+import com.storlead.ai.websocket.WebSocketSessionManager;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import java.util.Map;
+
+/**
+ * @program: storlead-ai-platform
+ * @description:
+ * @author: chenkq
+ * @create: 2025-09-22 15:07
+ */
+@Service
+public class WebSocketService {
+
+    @Autowired
+    private AiChatWebSocketHandler webSocketHandler;
+
+    @Autowired
+    private WebSocketSessionManager sessionManager;
+
+    /**
+     * 获取连接数
+     */
+    public int getConnectionCount() {
+        return webSocketHandler.getConnectionCount();
+    }
+
+    /**
+     * 获取会话统计
+     */
+    public Map<String, Object> getSessionStats() {
+        return sessionManager.getSessionStats();
+    }
+
+    /**
+     * 广播消息
+     */
+    public void broadcast(String message) {
+        webSocketHandler.broadcast(WebSocketMessage.systemMessage(message));
+    }
+
+    /**
+     * 广播系统通知
+     */
+    public void broadcastNotification(String type, Object data) {
+        webSocketHandler.broadcast(WebSocketMessage.systemMessage(Map.of(
+                "event", type,
+                "data", data,
+                "timestamp", System.currentTimeMillis()
+        )));
+    }
+}

+ 428 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/AiChatWebSocketHandler.java

@@ -0,0 +1,428 @@
+package com.storlead.ai.websocket;
+import cn.hutool.core.util.StrUtil;
+import com.storlead.ai.core.AiProviderType;
+import com.storlead.ai.core.ChatRequest;
+import com.storlead.ai.service.AiChatService;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+import org.springframework.web.socket.*;
+import reactor.core.Disposable;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+
+/**
+ * 修复后的WebSocket处理器
+ *
+ * @author your-name
+ * @since 1.0.0
+ */
+@Component
+public class AiChatWebSocketHandler implements WebSocketHandler {
+
+    private static final Logger logger = LoggerFactory.getLogger(AiChatWebSocketHandler.class);
+
+    @Autowired
+    private AiChatService aiChatService;
+
+    @Autowired
+    private WebSocketConnectionManager connectionManager;
+
+    @Autowired
+    private WebSocketSessionManager sessionManager;
+
+    @Autowired
+    private ObjectMapper objectMapper;
+
+    private final Map<String, String> sessionProvider = new ConcurrentHashMap<>();
+    private final Map<String, String> sessionModel = new ConcurrentHashMap<>();
+    @Override
+    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
+        String sessionId = session.getId();
+        String userId = extractUserId(session);
+        String provider = extractProvider(session);
+        String model = extractModel(session);
+        sessionProvider.put(sessionId,provider);
+        sessionModel.put(sessionId,model);
+        connectionManager.addConnection(sessionId, session);
+        sessionManager.createSession(sessionId, userId);
+
+        logger.info("WebSocket连接建立: {} (用户: {})", sessionId, userId);
+
+        // 发送连接成功消息 - 修复方法调用
+        sendMessage(session, WebSocketMessage.systemMessage(Map.of(
+                "event", "connected",
+                "sessionId", sessionId,
+                "userId", userId,
+                "supportedProviders", aiChatService.getSupportedProviders().stream()
+                        .map(type -> Map.of(
+                                "code", type.getCode(),
+                                "name", type.getDisplayName(),
+                                "available", aiChatService.isProviderAvailable(type)
+                        ))
+                        .toList()
+        )));
+    }
+
+    @Override
+    public void handleMessage(WebSocketSession session, org.springframework.web.socket.WebSocketMessage<?> payload) throws Exception {
+        String sessionId = session.getId();
+
+        try {
+
+            String provider = sessionProvider.get(sessionId);
+            String model = sessionModel.get(sessionId);
+            String message = payload.getPayload().toString();
+
+//            String messageType = messageNode.get("type").asText();
+//            logger.debug("收到消息: {} - {}", sessionId, messageType);
+            handleChatRequest(session, message,provider,model);
+//            switch (messageType) {
+//                case "CHAT_REQUEST" -> handleChatRequest(session, message,provider,message);
+//                case "CANCEL_CHAT" -> handleCancelChat(session);
+//                case "PING" -> handlePing(session);
+//                case "PROVIDER_SWITCH" -> handleProviderSwitch(session, provider);
+//                case "CONNECTION_TEST" -> handleConnectionTest(session);
+//                default -> sendError(session, "未知消息类型: " + messageType);
+//            }
+
+        } catch (Exception e) {
+            logger.error("处理消息失败: {}", sessionId, e);
+            sendError(session, "消息处理失败: " + e.getMessage());
+        }
+    }
+
+
+    @Override
+    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
+        String sessionId = session.getId();
+        logger.error("WebSocket传输错误: {}", sessionId, exception);
+        cleanup(sessionId);
+    }
+
+    @Override
+    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
+        String sessionId = session.getId();
+        logger.info("WebSocket连接关闭: {} - {}", sessionId, closeStatus);
+        cleanup(sessionId);
+    }
+
+    @Override
+    public boolean supportsPartialMessages() {
+        return false;
+    }
+
+    /**
+     * 处理聊天请求
+     */
+    private void handleChatRequest(WebSocketSession session,String message,String provider,String model) {
+        String sessionId = session.getId();
+
+        // 检查是否正在聊天中
+        if (sessionManager.isChatting(sessionId)) {
+            sendError(session, "正在聊天中,请等待完成或先取消当前对话");
+            return;
+        }
+
+        try {
+            // 验证必要字段
+            if (StrUtil.isBlank(message)) {
+                sendError(session, "消息内容不能为空");
+                return;
+            }
+
+            // 构建聊天请求
+            ChatRequest.Builder requestBuilder = ChatRequest.builder().message(message);
+
+//            // 可选参数
+//            if (messageNode.has("systemPrompt") && !messageNode.get("systemPrompt").asText().trim().isEmpty()) {
+//                requestBuilder.systemPrompt(messageNode.get("systemPrompt").asText());
+//            }
+
+            if (StrUtil.isNotBlank(model)) {
+                requestBuilder.model(model);
+            }
+
+            ChatRequest chatRequest = requestBuilder.build();
+
+            // 验证AI提供商
+            AiProviderType providerType;
+            try {
+                providerType = AiProviderType.fromCode(provider);
+            } catch (IllegalArgumentException e) {
+                sendError(session, "不支持的AI提供商: " + provider);
+                return;
+            }
+
+            // 检查提供商是否可用
+            if (!aiChatService.isProviderAvailable(providerType)) {
+                sendError(session, providerType.getDisplayName() + " 服务当前不可用");
+                return;
+            }
+
+            // 开始流式聊天
+            startStreamChat(session, sessionId, providerType, chatRequest);
+
+        } catch (Exception e) {
+            logger.error("处理聊天请求失败: {}", sessionId, e);
+            sendError(session, "请求处理失败: " + e.getMessage());
+        }
+    }
+
+    /**
+     * 开始流式聊天
+     */
+    private void startStreamChat(WebSocketSession session, String sessionId,
+                                 AiProviderType providerType, ChatRequest chatRequest) {
+
+        logger.debug("开始流式聊天: {} - 提供商: {}", sessionId, providerType.getDisplayName());
+
+        // 订阅流式响应 - 修复方法调用
+        Disposable subscription = aiChatService.chatStream(providerType, chatRequest)
+                .doOnSubscribe(sub -> {
+                    logger.debug("流式聊天订阅成功: {}", sessionId);
+                })
+                .subscribe(
+                        // 处理每个数据块 - 修复方法调用
+                        chunk -> {
+                            if (chunk != null && !chunk.trim().isEmpty()) {
+                                sendMessage(session, WebSocketMessage.chatChunk(chunk));
+                            }
+                        },
+                        // 处理错误
+                        error -> {
+                            logger.error("流式聊天错误: {} - 提供商: {}", sessionId, providerType.getDisplayName(), error);
+                            sendError(session, "聊天失败: " + error.getMessage());
+                            sessionManager.cancelStreamSubscription(sessionId);
+                        },
+                        // 处理完成 - 修复方法调用
+                        () -> {
+                            logger.debug("流式聊天完成: {} - 提供商: {}", sessionId, providerType.getDisplayName());
+                            sendMessage(session, WebSocketMessage.chatComplete());
+                            sessionManager.cancelStreamSubscription(sessionId);
+                        }
+                );
+
+        // 保存订阅以便后续取消
+        sessionManager.setStreamSubscription(sessionId, subscription);
+    }
+
+    /**
+     * 处理取消聊天
+     */
+    private void handleCancelChat(WebSocketSession session) {
+        String sessionId = session.getId();
+
+        if (!sessionManager.isChatting(sessionId)) {
+            sendMessage(session, WebSocketMessage.systemMessage("当前没有进行中的聊天"));
+            return;
+        }
+
+        sessionManager.cancelStreamSubscription(sessionId);
+        sendMessage(session, WebSocketMessage.systemMessage("聊天已取消"));
+        logger.debug("用户取消聊天: {}", sessionId);
+    }
+
+    /**
+     * 处理提供商切换 - 修复方法调用
+     */
+    private void handleProviderSwitch(WebSocketSession session,String provider) {
+        try {
+            if (StrUtil.isBlank(provider)) {
+                sendError(session, "缺少provider参数");
+                return;
+            }
+
+            AiProviderType providerType = AiProviderType.fromCode(provider);
+
+            boolean available = aiChatService.isProviderAvailable(providerType);
+
+            // 修复方法调用
+            sendMessage(session, WebSocketMessage.systemMessage(Map.of(
+                    "event", "provider_status",
+                    "provider", provider,
+                    "available", available,
+                    "models", aiChatService.getSupportedModels(providerType)
+            )));
+
+        } catch (IllegalArgumentException e) {
+            sendError(session, "无效的提供商: " + provider);
+        } catch (Exception e) {
+            sendError(session, "提供商切换失败: " + e.getMessage());
+        }
+    }
+
+    /**
+     * 处理连接测试 - 修复方法调用
+     */
+    private void handleConnectionTest(WebSocketSession session) {
+        sendMessage(session, WebSocketMessage.systemMessage(Map.of(
+                "event", "connection_test",
+                "status", "ok",
+                "timestamp", System.currentTimeMillis(),
+                "message", "连接正常"
+        )));
+    }
+
+    /**
+     * 处理心跳包
+     */
+    private void handlePing(WebSocketSession session) {
+        sendMessage(session, WebSocketMessage.pong());
+        logger.debug("心跳响应: {}", session.getId());
+    }
+
+    /**
+     * 发送消息到WebSocket客户端
+     */
+    private void sendMessage(WebSocketSession session, WebSocketMessage message) {
+        if (session == null || !session.isOpen()) {
+            logger.warn("尝试向已关闭的WebSocket会话发送消息: {}",
+                    session != null ? session.getId() : "null");
+            return;
+        }
+
+        try {
+            String json = objectMapper.writeValueAsString(message);
+            session.sendMessage(new TextMessage(json));
+            logger.trace("发送WebSocket消息成功: {} - {}", session.getId(), message.getType());
+        } catch (IOException e) {
+            logger.error("发送WebSocket消息失败: {}", session.getId(), e);
+            // 如果发送失败,可能连接已断开,执行清理
+            cleanup(session.getId());
+        } catch (Exception e) {
+            logger.error("序列化WebSocket消息失败: {}", session.getId(), e);
+        }
+    }
+
+    /**
+     * 发送错误消息
+     */
+    private void sendError(WebSocketSession session, String errorMessage) {
+        logger.debug("发送错误消息: {} - {}", session.getId(), errorMessage);
+        sendMessage(session, WebSocketMessage.chatError(errorMessage));
+    }
+
+    /**
+     * 从WebSocket会话中提取用户ID
+     */
+    private String extractProvider(WebSocketSession session) {
+        try {
+            String query = session.getUri().getQuery();
+            if (query != null && query.contains("provider=")) {
+                String[] parts = query.split("provider=");
+                if (parts.length > 1) {
+                    String userId = parts[1].split("&")[0];
+                    return java.net.URLDecoder.decode(userId, "UTF-8");
+                }
+            }
+        }  catch (Exception e) {
+            logger.debug("not found provider : {}", session.getId(), e);
+        }
+        // 如果无法提取用户ID,使用会话ID生成一个
+        return "Storlead-Ai";
+    }
+
+    /**
+     * 从WebSocket会话中提取用户ID
+     */
+    private String extractModel(WebSocketSession session) {
+        try {
+            String query = session.getUri().getQuery();
+            if (query != null && query.contains("provider=")) {
+                String[] parts = query.split("provider=");
+                if (parts.length > 1) {
+                    String userId = parts[1].split("&")[0];
+                    return java.net.URLDecoder.decode(userId, "UTF-8");
+                }
+            }
+        } catch (Exception e) {
+            logger.debug("not found provider : {}", session.getId(), e);
+        }
+        // 如果无法提取用户ID,使用会话ID生成一个
+        return "all";
+    }
+
+
+
+    /**
+     * 从WebSocket会话中提取用户ID
+     */
+    private String extractUserId(WebSocketSession session) {
+        try {
+            String query = session.getUri().getQuery();
+            if (query != null && query.contains("userId=")) {
+                String[] parts = query.split("userId=");
+                if (parts.length > 1) {
+                    String userId = parts[1].split("&")[0];
+                    return java.net.URLDecoder.decode(userId, "UTF-8");
+                }
+            }
+        } catch (Exception e) {
+            logger.debug("提取用户ID失败: {}", session.getId(), e);
+        }
+
+        // 如果无法提取用户ID,使用会话ID生成一个
+        return "user_" + session.getId().substring(0, 8);
+    }
+
+    /**
+     * 清理会话资源
+     */
+    private void cleanup(String sessionId) {
+        try {
+            connectionManager.removeConnection(sessionId);
+            sessionManager.removeSession(sessionId);
+            sessionProvider.remove(sessionId);
+            sessionModel.remove(sessionId);
+            logger.debug("清理会话资源完成: {}", sessionId);
+        } catch (Exception e) {
+            logger.error("清理会话资源失败: {}", sessionId, e);
+        }
+    }
+
+    /**
+     * 广播消息给所有连接的客户端
+     */
+    public void broadcast(WebSocketMessage message) {
+        Map<String, WebSocketSession> connections = connectionManager.getAllConnections();
+        int successCount = 0;
+        int failCount = 0;
+
+        for (Map.Entry<String, WebSocketSession> entry : connections.entrySet()) {
+            try {
+                sendMessage(entry.getValue(), message);
+                successCount++;
+            } catch (Exception e) {
+                failCount++;
+                logger.warn("广播消息失败: {}", entry.getKey(), e);
+                // 清理失败的连接
+                cleanup(entry.getKey());
+            }
+        }
+
+        logger.debug("广播消息完成 - 成功: {}, 失败: {}", successCount, failCount);
+    }
+
+    /**
+     * 获取当前连接数
+     */
+    public int getConnectionCount() {
+        return connectionManager.getConnectionCount();
+    }
+
+    /**
+     * 获取会话统计信息
+     */
+    public Map<String, Object> getSessionStats() {
+        return Map.of(
+                "totalConnections", connectionManager.getConnectionCount(),
+                "sessionStats", sessionManager.getSessionStats()
+        );
+    }
+}

+ 19 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/MessageType.java

@@ -0,0 +1,19 @@
+package com.storlead.ai.websocket;
+
+
+/**
+ * WebSocket消息类型
+ */
+public enum MessageType {
+    // 客户端 -> 服务端
+    CHAT_REQUEST,       // 聊天请求
+    CANCEL_CHAT,        // 取消聊天
+    PING,              // 心跳包
+
+    // 服务端 -> 客户端
+    CHAT_CHUNK,        // 流式聊天数据块
+    CHAT_COMPLETE,     // 聊天完成
+    CHAT_ERROR,        // 聊天错误
+    SYSTEM_MESSAGE,    // 系统消息
+    PONG              // 心跳响应
+}

+ 41 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConfig.java

@@ -0,0 +1,41 @@
+package com.storlead.ai.websocket;
+
+import com.storlead.ai.websocket.AiChatWebSocketHandler;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.socket.config.annotation.*;
+
+/**
+ * WebSocket配置类
+ * 配置WebSocket端点和处理器
+ *
+ * @author your-name
+ * @since 1.0.0
+ */
+@Configuration
+@EnableWebSocket
+public class WebSocketConfig implements WebSocketConfigurer {
+
+    @Autowired
+    private AiChatWebSocketHandler aiChatWebSocketHandler;
+
+    String[] allowedOrigins = {
+            "http://localhost:18090",
+            "https://localhost:18090",
+            "http://127.0.0.1:18090",
+            "https://127.0.0.1:18090"
+    };
+
+    @Override
+    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
+        // 注册原生WebSocket处理器
+        registry.addHandler(aiChatWebSocketHandler, "/ws/ai/chat")
+                .setAllowedOrigins(allowedOrigins)  // 使用具体的域名列表
+                .withSockJS();           // 启用SockJS支持,提供备用传输
+
+        // 注册纯WebSocket处理器(不使用SockJS)
+        registry.addHandler(aiChatWebSocketHandler, "/ws/ai/chat/native")
+                .setAllowedOrigins(allowedOrigins)  // 使用具体的域名列表
+                .setAllowedOrigins("*");
+    }
+}

+ 82 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketConnectionManager.java

@@ -0,0 +1,82 @@
+package com.storlead.ai.websocket;
+
+/**
+ * @program: storlead-ai-platform
+ * @description:
+ * @author: chenkq
+ * @create: 2025-09-22 15:01
+ */
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+import org.springframework.web.socket.WebSocketSession;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * WebSocket连接管理器
+ * 简化版本,只管理连接本身
+ */
+@Component
+public class WebSocketConnectionManager {
+
+    private static final Logger logger = LoggerFactory.getLogger(WebSocketConnectionManager.class);
+
+    // 存储所有WebSocket连接
+    private final Map<String, WebSocketSession> connections = new ConcurrentHashMap<>();
+
+    /**
+     * 添加连接
+     */
+    public void addConnection(String sessionId, WebSocketSession session) {
+        connections.put(sessionId, session);
+        logger.info("添加WebSocket连接: {} (总连接数: {})", sessionId, connections.size());
+    }
+
+    /**
+     * 移除连接
+     */
+    public void removeConnection(String sessionId) {
+        WebSocketSession removed = connections.remove(sessionId);
+        if (removed != null) {
+            logger.info("移除WebSocket连接: {} (剩余连接数: {})", sessionId, connections.size());
+        }
+    }
+
+    /**
+     * 获取连接
+     */
+    public WebSocketSession getConnection(String sessionId) {
+        return connections.get(sessionId);
+    }
+
+    /**
+     * 获取所有连接
+     */
+    public Map<String, WebSocketSession> getAllConnections() {
+        return new ConcurrentHashMap<>(connections);
+    }
+
+    /**
+     * 获取连接数量
+     */
+    public int getConnectionCount() {
+        return connections.size();
+    }
+
+    /**
+     * 检查连接是否存在且有效
+     */
+    public boolean isConnectionValid(String sessionId) {
+        WebSocketSession session = connections.get(sessionId);
+        return session != null && session.isOpen();
+    }
+
+    /**
+     * 清理所有连接
+     */
+    public void clearAll() {
+        connections.clear();
+        logger.info("清理所有WebSocket连接");
+    }
+}

+ 134 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketMessage.java

@@ -0,0 +1,134 @@
+package com.storlead.ai.websocket;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.storlead.ai.websocket.MessageType;
+
+import java.time.LocalDateTime;
+
+/**
+ * WebSocket消息类型
+ */
+
+/**
+ * WebSocket消息
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class WebSocketMessage {
+
+    private MessageType type;
+    private String content;
+    private Object data;
+    private String error;
+    private LocalDateTime timestamp;
+
+    // 私有构造函数,强制使用静态工厂方法
+    private WebSocketMessage() {
+        this.timestamp = LocalDateTime.now();
+    }
+
+    private WebSocketMessage(MessageType type) {
+        this();
+        this.type = type;
+    }
+
+    private WebSocketMessage(MessageType type, String content) {
+        this();
+        this.type = type;
+        this.content = content;
+    }
+
+    private WebSocketMessage(MessageType type, Object data) {
+        this();
+        this.type = type;
+        this.data = data;
+    }
+
+    // 静态工厂方法 - 聊天相关
+    public static WebSocketMessage chatChunk(String content) {
+        return new WebSocketMessage(MessageType.CHAT_CHUNK, content);
+    }
+
+    public static WebSocketMessage chatComplete() {
+        return new WebSocketMessage(MessageType.CHAT_COMPLETE);
+    }
+
+    public static WebSocketMessage chatError(String errorMessage) {
+        WebSocketMessage msg = new WebSocketMessage(MessageType.CHAT_ERROR);
+        msg.setError(errorMessage);
+        return msg;
+    }
+
+    // 静态工厂方法 - 系统相关
+    public static WebSocketMessage systemMessage(String message) {
+        return new WebSocketMessage(MessageType.SYSTEM_MESSAGE, message);
+    }
+
+    public static WebSocketMessage systemMessage(Object data) {
+        return new WebSocketMessage(MessageType.SYSTEM_MESSAGE, data);
+    }
+
+    // 兼容旧代码的别名方法
+    public static WebSocketMessage systemInfo(Object data) {
+        return systemMessage(data);
+    }
+
+    public static WebSocketMessage providerStatus(Object data) {
+        return systemMessage(data);
+    }
+
+    public static WebSocketMessage pong() {
+        return new WebSocketMessage(MessageType.PONG, "pong");
+    }
+
+    // Getters and Setters
+    public MessageType getType() {
+        return type;
+    }
+
+    public void setType(MessageType type) {
+        this.type = type;
+    }
+
+    public String getContent() {
+        return content;
+    }
+
+    public void setContent(String content) {
+        this.content = content;
+    }
+
+    public Object getData() {
+        return data;
+    }
+
+    public void setData(Object data) {
+        this.data = data;
+    }
+
+    public String getError() {
+        return error;
+    }
+
+    public void setError(String error) {
+        this.error = error;
+    }
+
+    public LocalDateTime getTimestamp() {
+        return timestamp;
+    }
+
+    public void setTimestamp(LocalDateTime timestamp) {
+        this.timestamp = timestamp;
+    }
+
+    @Override
+    public String toString() {
+        return "WebSocketMessage{" +
+                "type=" + type +
+                ", content='" + content + '\'' +
+                ", data=" + data +
+                ", error='" + error + '\'' +
+                ", timestamp=" + timestamp +
+                '}';
+    }
+}

+ 140 - 0
storlead-ai-api/src/main/java/com/storlead/ai/websocket/WebSocketSessionManager.java

@@ -0,0 +1,140 @@
+package com.storlead.ai.websocket;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+import reactor.core.Disposable;
+
+import java.time.LocalDateTime;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+
+/**
+ * @program: 管理聊天会话状态和流式响应订阅
+ * @description: WebSocket会话管理器
+ * @author: chenkq
+ * @create: 2025-09-22 15:01
+ */
+
+
+@Component
+public class WebSocketSessionManager {
+
+    private static final Logger logger = LoggerFactory.getLogger(WebSocketSessionManager.class);
+
+    // 存储会话的流式响应订阅
+    private final Map<String, Disposable> streamSubscriptions = new ConcurrentHashMap<>();
+
+    // 存储会话的基本信息
+    private final Map<String, SessionInfo> sessionInfos = new ConcurrentHashMap<>();
+
+    /**
+     * 创建会话信息
+     */
+    public void createSession(String sessionId, String userId) {
+        SessionInfo info = new SessionInfo(sessionId, userId, LocalDateTime.now());
+        sessionInfos.put(sessionId, info);
+        logger.debug("创建会话信息: {} (用户: {})", sessionId, userId);
+    }
+
+    /**
+     * 设置流式订阅
+     */
+    public void setStreamSubscription(String sessionId, Disposable subscription) {
+        // 取消之前的订阅
+        cancelStreamSubscription(sessionId);
+
+        streamSubscriptions.put(sessionId, subscription);
+        updateSessionStatus(sessionId, true);
+        logger.debug("设置流式订阅: {}", sessionId);
+    }
+
+    /**
+     * 取消流式订阅
+     */
+    public void cancelStreamSubscription(String sessionId) {
+        Disposable subscription = streamSubscriptions.remove(sessionId);
+        if (subscription != null && !subscription.isDisposed()) {
+            subscription.dispose();
+            logger.debug("取消流式订阅: {}", sessionId);
+        }
+        updateSessionStatus(sessionId, false);
+    }
+
+    /**
+     * 更新会话状态
+     */
+    private void updateSessionStatus(String sessionId, boolean isChatting) {
+        SessionInfo info = sessionInfos.get(sessionId);
+        if (info != null) {
+            info.isChatting = isChatting;
+            info.lastActivity = LocalDateTime.now();
+        }
+    }
+
+    /**
+     * 移除会话
+     */
+    public void removeSession(String sessionId) {
+        cancelStreamSubscription(sessionId);
+        sessionInfos.remove(sessionId);
+        logger.debug("移除会话: {}", sessionId);
+    }
+
+    /**
+     * 获取会话信息
+     */
+    public SessionInfo getSessionInfo(String sessionId) {
+        return sessionInfos.get(sessionId);
+    }
+
+    /**
+     * 检查会话是否正在聊天
+     */
+    public boolean isChatting(String sessionId) {
+        SessionInfo info = sessionInfos.get(sessionId);
+        return info != null && info.isChatting;
+    }
+
+    /**
+     * 获取所有会话信息
+     */
+    public Map<String, SessionInfo> getAllSessionInfos() {
+        return new ConcurrentHashMap<>(sessionInfos);
+    }
+
+    /**
+     * 获取会话统计
+     */
+    public Map<String, Object> getSessionStats() {
+        int totalSessions = sessionInfos.size();
+        long chattingSessions = sessionInfos.values().stream()
+                .mapToLong(info -> info.isChatting ? 1 : 0)
+                .sum();
+
+        return Map.of(
+                "totalSessions", totalSessions,
+                "chattingSessions", chattingSessions,
+                "idleSessions", totalSessions - chattingSessions
+        );
+    }
+
+    /**
+     * 会话信息类
+     */
+    public static class SessionInfo {
+        public final String sessionId;
+        public final String userId;
+        public final LocalDateTime createdAt;
+        public LocalDateTime lastActivity;
+        public boolean isChatting = false;
+
+        public SessionInfo(String sessionId, String userId, LocalDateTime createdAt) {
+            this.sessionId = sessionId;
+            this.userId = userId;
+            this.createdAt = createdAt;
+            this.lastActivity = createdAt;
+        }
+    }
+}

二进制
storlead-ai-api/target/classes/com/storlead/ai/controller/WebSocketEndpointController.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/service/WebSocketService.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/service/impl/OpenAiService.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/AiChatWebSocketHandler.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/MessageType.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketConfig.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketConnectionManager.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketMessage.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketSessionManager$SessionInfo.class


二进制
storlead-ai-api/target/classes/com/storlead/ai/websocket/WebSocketSessionManager.class