|
|
@@ -0,0 +1,360 @@
|
|
|
+package com.storlead.ai.service.impl;
|
|
|
+
|
|
|
+import com.fasterxml.jackson.databind.JsonNode;
|
|
|
+import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
+import com.storlead.ai.config.properties.SparkAiProperties;
|
|
|
+import com.storlead.ai.core.AiProviderType;
|
|
|
+import com.storlead.ai.core.AiService;
|
|
|
+import com.storlead.ai.core.ChatRequest;
|
|
|
+import com.storlead.ai.core.ChatResponse;
|
|
|
+import com.storlead.ai.exception.AiServiceException;
|
|
|
+import com.storlead.ai.util.HttpClientUtil;
|
|
|
+import org.apache.hc.core5.net.URIBuilder;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
|
+import org.springframework.core.io.buffer.DataBuffer;
|
|
|
+import org.springframework.core.io.buffer.DataBufferUtils;
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
+import org.springframework.web.reactive.function.client.WebClient;
|
|
|
+import reactor.core.publisher.Flux;
|
|
|
+import reactor.core.publisher.Mono;
|
|
|
+
|
|
|
+import javax.annotation.PostConstruct;
|
|
|
+import javax.crypto.Mac;
|
|
|
+import javax.crypto.spec.SecretKeySpec;
|
|
|
+import java.net.URI;
|
|
|
+import java.net.URL;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.text.SimpleDateFormat;
|
|
|
+import java.util.*;
|
|
|
+
|
|
|
+/**
|
|
|
+ * @program: storlead-ai-platform
|
|
|
+ * @description:
|
|
|
+ * @author: chenkq
|
|
|
+ * @create: 2025-09-28 16:46
|
|
|
+ */
|
|
|
+@Service
|
|
|
+@ConditionalOnProperty(prefix = "ai.providers.sparkai", name = "enabled", havingValue = "true")
|
|
|
+public class SparkAiService implements AiService {
|
|
|
+
|
|
|
+ private static final Logger logger = LoggerFactory.getLogger(OpenAiService.class);
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private SparkAiProperties properties;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private HttpClientUtil httpClientUtil;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private ObjectMapper objectMapper;
|
|
|
+
|
|
|
+ private WebClient webClient;
|
|
|
+
|
|
|
+ @PostConstruct
|
|
|
+ public void init() {
|
|
|
+ this.webClient = httpClientUtil.createWebClient(
|
|
|
+ properties.getBaseUrl(),
|
|
|
+ properties.getTimeoutSeconds()
|
|
|
+ );
|
|
|
+ logger.info("OpenAI服务初始化完成 - baseUrl: {}, model: {}",
|
|
|
+ properties.getBaseUrl(), properties.getDefaultModel());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Mono<ChatResponse> chat(ChatRequest request) {
|
|
|
+ return Mono.fromCallable(() -> buildRequestBody(request))
|
|
|
+ .flatMap(requestBody -> {
|
|
|
+ return httpClientUtil.postForObject(
|
|
|
+ webClient,
|
|
|
+ "/v1/chat/completions",
|
|
|
+ requestBody,
|
|
|
+ JsonNode.class,
|
|
|
+ httpClientUtil.createBearerAuthHeaders(properties.getApiKey()),
|
|
|
+ properties.getMaxRetries()
|
|
|
+ );
|
|
|
+ })
|
|
|
+ .map(this::parseResponse)
|
|
|
+ .onErrorMap(this::handleException);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Flux<String> chatStream(ChatRequest request) {
|
|
|
+ logger.debug("开始构建流式请求");
|
|
|
+
|
|
|
+ Map<String, Object> requestBody = buildStreamRequestBody(request);
|
|
|
+
|
|
|
+ return webClient.post()
|
|
|
+ .uri("/v1/chat/completions")
|
|
|
+ .headers(httpHeaders -> {
|
|
|
+ httpHeaders.setBearerAuth(properties.getApiKey());
|
|
|
+ httpHeaders.set("Accept", "text/event-stream");
|
|
|
+ })
|
|
|
+ .bodyValue(requestBody)
|
|
|
+ .retrieve()
|
|
|
+ .bodyToFlux(DataBuffer.class)
|
|
|
+ .map(dataBuffer -> {
|
|
|
+ byte[] bytes = new byte[dataBuffer.readableByteCount()];
|
|
|
+ dataBuffer.read(bytes);
|
|
|
+ DataBufferUtils.release(dataBuffer);
|
|
|
+ return new String(bytes, StandardCharsets.UTF_8);
|
|
|
+ })
|
|
|
+ .flatMapIterable(data -> Arrays.asList(data.split("\n")))
|
|
|
+ .filter(line -> !line.trim().isEmpty())
|
|
|
+ .filter(line -> line.startsWith("data: ") && !line.equals("data: [DONE]"))
|
|
|
+ .map(line -> line.substring(6).trim())
|
|
|
+ .filter(this::isValidJson) // 添加JSON有效性检查
|
|
|
+ .mapNotNull(this::parseStreamChunk)
|
|
|
+ .onErrorMap(this::handleException);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 检查JSON是否有效且完整
|
|
|
+ */
|
|
|
+ private boolean isValidJson(String json) {
|
|
|
+ if (json.trim().isEmpty()) return false;
|
|
|
+ try {
|
|
|
+ objectMapper.readTree(json);
|
|
|
+ return true;
|
|
|
+ } catch (Exception e) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public AiProviderType getProviderType() {
|
|
|
+ return AiProviderType.OPENAI;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean isAvailable() {
|
|
|
+ try {
|
|
|
+ // 简单的健康检查 - 发送一个极简的请求
|
|
|
+ Map<String, Object> testRequest = Map.of(
|
|
|
+ "model", properties.getDefaultModel(),
|
|
|
+ "messages", List.of(Map.of("role", "user", "content", "test")),
|
|
|
+ "max_tokens", 1
|
|
|
+ );
|
|
|
+
|
|
|
+ return httpClientUtil.postForObject(
|
|
|
+ webClient,
|
|
|
+ "/v1/chat/completions",
|
|
|
+ testRequest,
|
|
|
+ JsonNode.class,
|
|
|
+ httpClientUtil.createBearerAuthHeaders(properties.getApiKey()),
|
|
|
+ 1
|
|
|
+ ).block() != null;
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.warn("OpenAI服务健康检查失败: {}", e.getMessage());
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public List<String> getSupportedModels() {
|
|
|
+ return properties.getSupportedModels();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String getDefaultModel() {
|
|
|
+ return properties.getDefaultModel();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 构建请求体
|
|
|
+ */
|
|
|
+ private Map<String, Object> buildRequestBody(ChatRequest request) {
|
|
|
+ Map<String, Object> requestBody = new HashMap<>();
|
|
|
+
|
|
|
+ // 基础参数
|
|
|
+ requestBody.put("model", getModelName(request));
|
|
|
+ requestBody.put("messages", buildMessages(request));
|
|
|
+
|
|
|
+ // 可选参数
|
|
|
+ Map<String, Object> params = request.getParameters();
|
|
|
+ if (params != null) {
|
|
|
+ requestBody.put("temperature", params.getOrDefault("temperature", properties.getDefaultTemperature()));
|
|
|
+ requestBody.put("max_tokens", params.getOrDefault("max_tokens", properties.getDefaultMaxTokens()));
|
|
|
+
|
|
|
+ if (params.containsKey("top_p")) {
|
|
|
+ requestBody.put("top_p", params.get("top_p"));
|
|
|
+ }
|
|
|
+ if (params.containsKey("frequency_penalty")) {
|
|
|
+ requestBody.put("frequency_penalty", params.get("frequency_penalty"));
|
|
|
+ }
|
|
|
+ if (params.containsKey("presence_penalty")) {
|
|
|
+ requestBody.put("presence_penalty", params.get("presence_penalty"));
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ requestBody.put("temperature", properties.getDefaultTemperature());
|
|
|
+ requestBody.put("max_tokens", properties.getDefaultMaxTokens());
|
|
|
+ }
|
|
|
+
|
|
|
+ return requestBody;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 构建流式请求体
|
|
|
+ */
|
|
|
+ private Map<String, Object> buildStreamRequestBody(ChatRequest request) {
|
|
|
+ Map<String, Object> requestBody = buildRequestBody(request);
|
|
|
+ requestBody.put("stream", true);
|
|
|
+ return requestBody;
|
|
|
+ }
|
|
|
+
|
|
|
+ public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
|
|
|
+ URL url = new URL(hostUrl);
|
|
|
+ // 时间
|
|
|
+ SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
|
|
|
+ format.setTimeZone(TimeZone.getTimeZone("GMT"));
|
|
|
+ String date = format.format(new Date());
|
|
|
+ // 拼接
|
|
|
+ String preStr = "host: " + url.getHost() + "\n" +
|
|
|
+ "date: " + date + "\n" +
|
|
|
+ "GET " + url.getPath() + " HTTP/1.1";
|
|
|
+ // System.err.println(preStr);
|
|
|
+ // SHA256加密
|
|
|
+ Mac mac = Mac.getInstance("hmacsha256");
|
|
|
+ SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
|
|
|
+ mac.init(spec);
|
|
|
+
|
|
|
+ byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
|
|
|
+ // Base64加密
|
|
|
+ String sha = Base64.getEncoder().encodeToString(hexDigits);
|
|
|
+ // System.err.println(sha);
|
|
|
+ // 拼接
|
|
|
+ String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
|
|
|
+ // 拼接地址
|
|
|
+ URI uri = new URIBuilder()
|
|
|
+ .setScheme("https")
|
|
|
+ .setHost(url.getHost())
|
|
|
+ .setPath(url.getPath())
|
|
|
+ .addParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
|
|
|
+ .addParameter("date", date)
|
|
|
+ .addParameter("host", url.getHost())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ String finalUrl = uri.toString();
|
|
|
+
|
|
|
+ return finalUrl;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 构建消息列表
|
|
|
+ */
|
|
|
+ private List<Map<String, String>> buildMessages(ChatRequest request) {
|
|
|
+ List<Map<String, String>> messages = new java.util.ArrayList<>();
|
|
|
+
|
|
|
+ // 系统提示词
|
|
|
+ if (request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()) {
|
|
|
+ messages.add(Map.of("role", "system", "content", request.getSystemPrompt()));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 用户消息
|
|
|
+ messages.add(Map.of("role", "user", "content", request.getMessage()));
|
|
|
+
|
|
|
+ return messages;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取模型名称
|
|
|
+ */
|
|
|
+ private String getModelName(ChatRequest request) {
|
|
|
+ if (request.getModel() != null && !request.getModel().trim().isEmpty()) {
|
|
|
+ return request.getModel();
|
|
|
+ }
|
|
|
+ return properties.getDefaultModel();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 解析响应
|
|
|
+ */
|
|
|
+ private ChatResponse parseResponse(JsonNode responseJson) {
|
|
|
+ try {
|
|
|
+ if (responseJson.has("error")) {
|
|
|
+ JsonNode error = responseJson.get("error");
|
|
|
+ String errorMessage = error.get("message").asText();
|
|
|
+ String errorCode = error.has("code") ? error.get("code").asText() : "OPENAI_ERROR";
|
|
|
+ throw AiServiceException.apiError(AiProviderType.OPENAI, errorMessage);
|
|
|
+ }
|
|
|
+
|
|
|
+ JsonNode choices = responseJson.get("choices");
|
|
|
+ if (choices == null || choices.isEmpty()) {
|
|
|
+ throw AiServiceException.apiError(AiProviderType.OPENAI, "响应中没有选择项");
|
|
|
+ }
|
|
|
+
|
|
|
+ JsonNode firstChoice = choices.get(0);
|
|
|
+ JsonNode message = firstChoice.get("message");
|
|
|
+ String content = message.get("content").asText();
|
|
|
+
|
|
|
+ String model = responseJson.has("model") ? responseJson.get("model").asText() : properties.getDefaultModel();
|
|
|
+
|
|
|
+ // 构建元数据
|
|
|
+ Map<String, Object> metadata = new HashMap<>();
|
|
|
+ if (responseJson.has("usage")) {
|
|
|
+ metadata.put("usage", responseJson.get("usage"));
|
|
|
+ }
|
|
|
+ if (firstChoice.has("finish_reason")) {
|
|
|
+ metadata.put("finishReason", firstChoice.get("finish_reason").asText());
|
|
|
+ }
|
|
|
+
|
|
|
+ ChatResponse response = ChatResponse.success(content, model);
|
|
|
+ response.setProviderType(AiProviderType.STROLEAAI);
|
|
|
+ response.setMetadata(metadata);
|
|
|
+
|
|
|
+ return response;
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("解析OpenAI响应失败", e);
|
|
|
+ throw AiServiceException.apiError(AiProviderType.OPENAI, "响应解析失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 解析流式响应块
|
|
|
+ */
|
|
|
+ private String parseStreamChunk(String chunk) {
|
|
|
+ try {
|
|
|
+ JsonNode chunkJson = objectMapper.readTree(chunk);
|
|
|
+
|
|
|
+ if (chunkJson.has("choices")) {
|
|
|
+ JsonNode choices = chunkJson.get("choices");
|
|
|
+ if (!choices.isEmpty()) {
|
|
|
+ JsonNode firstChoice = choices.get(0);
|
|
|
+ if (firstChoice.has("delta")) {
|
|
|
+ JsonNode delta = firstChoice.get("delta");
|
|
|
+ if (delta.has("content")) {
|
|
|
+ String content = delta.get("content").asText();
|
|
|
+ // 只返回纯文本内容,不包含任何JSON结构
|
|
|
+ return content;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null; // 如果没有content,返回null会被mapNotNull过滤掉
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.warn("解析流式响应块失败: {}", chunk, e);
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 异常处理
|
|
|
+ */
|
|
|
+ private Throwable handleException(Throwable throwable) {
|
|
|
+ String errorMessage = httpClientUtil.parseWebClientError(throwable);
|
|
|
+ logger.error("OpenAI API调用失败: {}", errorMessage, throwable);
|
|
|
+
|
|
|
+ if (errorMessage.contains("401")) {
|
|
|
+ return AiServiceException.authenticationError(AiProviderType.OPENAI);
|
|
|
+ } else if (errorMessage.contains("429")) {
|
|
|
+ return AiServiceException.rateLimitError(AiProviderType.OPENAI);
|
|
|
+ } else if (errorMessage.contains("quota")) {
|
|
|
+ return AiServiceException.quotaExceededError(AiProviderType.OPENAI);
|
|
|
+ } else {
|
|
|
+ return AiServiceException.networkError(AiProviderType.OPENAI, throwable);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|