diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml
index f8a8a4c5d..e01a56f9b 100644
--- a/yudao-module-llm/yudao-module-llm-biz/pom.xml
+++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml
@@ -125,6 +125,10 @@
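             <artifactId>poi-scratchpad</artifactId>
             <version>5.2.3</version>
         </dependency>
+        <dependency>
+            <groupId>org.springframework</groupId>
+            <artifactId>spring-webflux</artifactId>
+        </dependency>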
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java
index 6096a5850..f019cead0 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java
@@ -2,8 +2,12 @@ package cn.iocoder.yudao.module.llm.controller.admin.conversation;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
+import cn.iocoder.yudao.framework.common.exception.ServiceException;
+import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
import com.alibaba.fastjson.JSON;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import org.springframework.validation.annotation.Validated;
@@ -15,8 +19,10 @@ import io.swagger.v3.oas.annotations.Operation;
import javax.validation.constraints.*;
import javax.validation.*;
import javax.servlet.http.*;
+import java.time.LocalDateTime;
import java.util.*;
import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
@@ -32,11 +38,13 @@ import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Tag(name = "管理后台 - 大模型对话记录")
@RestController
@RequestMapping("/llm/conversation")
@Validated
+@Slf4j
public class ConversationController {
@Resource
@@ -101,6 +109,31 @@ public class ConversationController {
public CommonResult chat(@Valid @RequestBody ChatReqVO chatReqVO) {
return success(conversationService.chat(chatReqVO));
}
+
+    /**
+     * Chat inference endpoint that answers through Server-Sent Events (SSE).
+     *
+     * <p>Creates an {@link SseEmitter}, passes it together with the request to
+     * {@code conversationService.chatStream}, and returns the emitter. If the service call
+     * throws, the emitter is completed with that error instead of letting the exception
+     * propagate to the caller.
+     *
+     * <p>NOTE(review): the service call runs on the request thread before the emitter is
+     * returned, so events produced during that call are only delivered once this method
+     * returns. Confirm whether the call should be moved to a worker thread (taking care of
+     * any thread-bound request context) so that output reaches the client while it is produced.
+     *
+     * @param chatReqVO chat request payload
+     * @return the emitter used to push response events to the client
+     */
+    @PostMapping("/stream-chat")
+    @Operation(summary = "对话推理接口(SSE 流式响应)")
+    public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) {
+        log.info("收到对话推理请求,请求参数: {}", chatReqVO);
+        SseEmitter emitter = new SseEmitter();
+        try {
+            conversationService.chatStream(chatReqVO, emitter);
+        } catch (Exception e) {
+            log.error("处理对话推理请求时发生异常", e);
+            // Report the failure through the emitter; it may already be completed, hence the guard.
+            try {
+                emitter.completeWithError(e);
+            } catch (Exception ex) {
+                log.error("无法完成 SseEmitter 错误处理", ex);
+            }
+        }
+        log.info("返回 SseEmitter 对象,准备进行流式响应");
+        return emitter;
+    }
+
+
@PostMapping("/text-to-image")
@Operation(summary = "文字转图片接口")
public CommonResult textToImage(@Valid @RequestBody TextToImageReqVo req) {
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java
index 7c582bb70..9e1388ee0 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java
@@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
+import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 大模型对话记录 Service 接口
@@ -64,4 +66,11 @@ public interface ConversationService {
* @return
*/
JSONArray textToImage(TextToImageReqVo req);
-}
\ No newline at end of file
+
+    /**
+     * Runs a chat request and streams the reply through the given emitter.
+     * @param chatReqVO chat request payload
+     * @param emitter SSE emitter that receives the streamed response
+     */
+    void chatStream(@Valid ChatReqVO chatReqVO, SseEmitter emitter);
+}
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
index 4bbffedce..9ad223380 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
@@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.llm.service.conversation;
+import cn.hutool.core.bean.BeanUtil;
+import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
@@ -32,11 +34,16 @@ import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
+import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
+import java.io.IOException;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -49,6 +56,7 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
*/
@Service
@Validated
+@Slf4j
public class ConversationServiceImpl implements ConversationService {
@Resource
@@ -130,19 +138,11 @@ public class ConversationServiceImpl implements ConversationService {
ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class);
applicationService.updateApplication(applicationSaveReqVO);
}
- /*Long promptId = application.getPromptId();
- PromptTemplatesRespVO promptTemplates = promptTemplatesService.getPromptTemplates(promptId);
- if(promptTemplates != null){
- chatReqVO.setSystemPrompt(promptTemplates.getTemplateText());
- }*/
+
chatReqVO.setSystemPrompt(application.getPrompt());
}
}
- /* if (Objects.equals(1, chatReqVO.getModelType())){
- return publicModelChat(chatReqVO);
- }else {
- return privateModelChat(chatReqVO);
- }*/
+
return publicModelChat(chatReqVO);
}
@@ -273,9 +273,171 @@ public class ConversationServiceImpl implements ConversationService {
dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt());
dataRefluxDataSaveReqVO.setResponse(modelCompletionsRespVO.getAnswer());
dataRefluxDataSaveReqVO.setSystem(modelCompletionsRespVO.getSystem());
+ dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens());
+ dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature());
dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO);
return chatRespVO;
}
+
+    /**
+     * Handles a chat request and streams the answer through the given emitter.
+     *
+     * <p>When the request carries no system prompt but names an application, the application's
+     * prompt is used as the system prompt. If no chat history is cached for the request's UUID,
+     * the application's chat counter is incremented and persisted first. The request is then
+     * delegated to {@link #publicModelChatStream(ChatReqVO, SseEmitter)}.
+     *
+     * @param chatReqVO chat request; its system prompt may be filled in by this method
+     * @param emitter   SSE emitter that receives the streamed response
+     */
+    @Override
+    public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
+        log.info("开始处理对话请求,请求参数: {}", chatReqVO);
+        // Fall back to the application's prompt only when the caller supplied none.
+        if (StrUtil.isEmpty(chatReqVO.getSystemPrompt()) && chatReqVO.getApplicationId() != null) {
+            log.info("系统提示信息为空,尝试从应用中获取,应用 ID: {}", chatReqVO.getApplicationId());
+            ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
+            List<String> messageHistoryList = stringRedisTemplate.opsForList()
+                    .range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
+            if (CollectionUtils.isEmpty(messageHistoryList)) {
+                // No cached history for this UUID: count it as a new chat for the application.
+                log.info("聊天历史记录为空,更新应用聊天计数");
+                application.setChatCount(application.getChatCount() + 1);
+                // Same bean helper as chat() uses, to keep both code paths consistent.
+                ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class);
+                applicationService.updateApplication(applicationSaveReqVO);
+            }
+            chatReqVO.setSystemPrompt(application.getPrompt());
+            log.info("已更新系统提示信息为: {}", chatReqVO.getSystemPrompt());
+        }
+        publicModelChatStream(chatReqVO, emitter);
+    }
+
+    /**
+     * Builds the completion request for a chat and streams the model's answer to the client.
+     *
+     * <p>Steps, in order:
+     * <ol>
+     *   <li>generate a UUID for the conversation when the request has none;</li>
+     *   <li>resolve model name and URL from the base model (model type 1) or the model service
+     *       (model type 0); any other type, or a missing record, completes the emitter with an
+     *       error and returns;</li>
+     *   <li>when a knowledge base id is set, query the embed endpoint once per document of that
+     *       knowledge base and collect the {@code page_content} of the first hit of each;</li>
+     *   <li>load the cached history from Redis, refreshing the stored system message, or create
+     *       and cache a new system message when there is no history;</li>
+     *   <li>call the model service, then cache the user message and save a data-reflux record.</li>
+     * </ol>
+     *
+     * <p>NOTE(review): only the user message is appended to the cached history after streaming;
+     * the assistant's reply is not stored here and the reflux record carries no response text.
+     * Confirm whether that is intended.
+     *
+     * @param chatReqVO chat request; its UUID is filled in when blank
+     * @param emitter   SSE emitter that receives the streamed response
+     */
+    public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
+        log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
+        if (StrUtil.isBlank(chatReqVO.getUuid())) {
+            log.info("UUID 为空,生成新的 UUID");
+            chatReqVO.setUuid(UUID.randomUUID().toString());
+        }
+
+        // Resolve the target model; every failure path reports through the emitter and stops.
+        String model;
+        String selfModelUrl;
+        if (Objects.equals(1, chatReqVO.getModelType())) {
+            log.info("使用预制模型,模型 ID: {}", chatReqVO.getModelId());
+            BaseModelDO baseModelDO = baseModelService.getBaseModel(chatReqVO.getModelId());
+            if (baseModelDO == null) {
+                log.error("预制模型不存在,模型 ID: {}", chatReqVO.getModelId());
+                completeEmitterWithError(emitter, "BASE_MODEL_NOT_EXISTS");
+                return;
+            }
+            selfModelUrl = baseModelDO.getChatUrl();
+            model = baseModelDO.getAigcModelName();
+            log.info("获取到预制模型信息,模型名称: {}, 聊天 URL: {}", model, selfModelUrl);
+        } else if (Objects.equals(0, chatReqVO.getModelType())) {
+            log.info("使用自定义模型,模型 ID: {}", chatReqVO.getModelId());
+            ModelServiceDO modelServiceDO = modelServiceMapper.selectById(chatReqVO.getModelId());
+            if (modelServiceDO == null) {
+                log.error("自定义模型服务不存在,模型 ID: {}", chatReqVO.getModelId());
+                completeEmitterWithError(emitter, "MODEL_SERVICE_NOT_EXISTS");
+                return;
+            }
+            model = modelServiceDO.getBaseModelName();
+            selfModelUrl = modelServiceDO.getModelUrl();
+            log.info("获取到自定义模型信息,模型名称: {}, 模型 URL: {}", model, selfModelUrl);
+        } else {
+            log.error("无效的模型类型,模型类型: {}", chatReqVO.getModelType());
+            completeEmitterWithError(emitter, "BASE_MODEL_NOT_EXISTS");
+            return;
+        }
+
+        List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
+
+        // Collect knowledge snippets: one embed query per document of the knowledge base.
+        StringBuilder knowledgeBase = new StringBuilder();
+        if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) {
+            log.info("知识库 ID 不为空,开始查询知识库,知识库 ID: {}", chatReqVO.getKnowledge());
+            LambdaQueryWrapper<KnowledgeDocumentsDO> queryWrapper = new LambdaQueryWrapper<>();
+            queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge());
+            List<KnowledgeDocumentsDO> fileList = knowledgeDocumentsMapper.selectList(queryWrapper);
+            for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) {
+                KnowledgeRagEmbedQueryVO knowledgeRagEmbedQueryVO = new KnowledgeRagEmbedQueryVO();
+                knowledgeRagEmbedQueryVO.setFile_id(knowledgeDocumentsDO.getId().toString());
+                knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt());
+                String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO));
+                // The response is parsed by fastjson, so its nested elements are fastjson types.
+                // The simple name JSONArray in this file is hutool's; casting to it would throw
+                // ClassCastException, hence the typed fastjson accessors below.
+                com.alibaba.fastjson.JSONArray hits = JSON.parseArray(result);
+                if (hits == null || hits.isEmpty()) {
+                    continue;
+                }
+                com.alibaba.fastjson.JSONArray firstHit = hits.getJSONArray(0);
+                if (firstHit == null || firstHit.isEmpty()) {
+                    continue;
+                }
+                knowledgeBase.append(firstHit.getJSONObject(0).get("page_content"));
+            }
+            log.info("知识库查询完成,获取到的信息: {}", knowledgeBase.toString());
+        }
+        String mess = chatReqVO.getSystemPrompt() + "" + knowledgeBase.toString() + "";
+
+        // Load the cached history; the system message always carries the current prompt + knowledge.
+        String historyKey = CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid();
+        List<String> messageHistoryList = stringRedisTemplate.opsForList().range(historyKey, 0, -1);
+        if (!CollectionUtils.isEmpty(messageHistoryList)) {
+            log.info("存在聊天历史记录,处理历史记录消息");
+            for (String messageHistory : messageHistoryList) {
+                ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
+                if ("system".equals(modelCompletionsMessage.getRole())) {
+                    modelCompletionsMessage.setContent(mess);
+                    // Overwrites list index 0, where the system message is pushed on first use.
+                    stringRedisTemplate.opsForList().set(historyKey, 0, JsonUtils.toJsonString(modelCompletionsMessage));
+                }
+                messages.add(modelCompletionsMessage);
+            }
+        } else {
+            log.info("不存在聊天历史记录,创建新的系统消息");
+            ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
+            systemMessage.setRole("system");
+            systemMessage.setContent(mess);
+            stringRedisTemplate.opsForList().rightPush(historyKey, JsonUtils.toJsonString(systemMessage));
+            messages.add(systemMessage);
+        }
+
+        // Current user turn.
+        ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
+        message.setRole("user");
+        message.setContent(chatReqVO.getPrompt());
+        messages.add(message);
+
+        ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
+        modelCompletionsReqVO.setMessages(messages);
+        modelCompletionsReqVO.setModel(model);
+        log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO);
+
+        // Streams the model output to the client through the emitter.
+        modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter);
+
+        // Cache the user message only after the model call, as the original flow does.
+        stringRedisTemplate.opsForList().rightPush(historyKey, JsonUtils.toJsonString(message));
+
+        // Persist the data-reflux record for this request.
+        DataRefluxDataSaveReqVO dataRefluxDataSaveReqVO = new DataRefluxDataSaveReqVO();
+        dataRefluxDataSaveReqVO.setModelServiceId(chatReqVO.getModelId());
+        dataRefluxDataSaveReqVO.setModelType(chatReqVO.getModelType());
+        dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt());
+        dataRefluxDataSaveReqVO.setSystem("助手");
+        dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens());
+        dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature());
+        dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO);
+        log.info("数据回流信息保存完成");
+    }
+
+    /**
+     * Completes the emitter with a {@link RuntimeException} carrying the given message and logs,
+     * rather than rethrows, any failure to do so.
+     *
+     * @param emitter emitter to fail
+     * @param message error message reported to the client
+     */
+    private void completeEmitterWithError(SseEmitter emitter, String message) {
+        try {
+            emitter.completeWithError(new RuntimeException(message));
+        } catch (Exception e) {
+            log.error("无法完成 SseEmitter 错误处理", e);
+        }
+    }
+
+
/**
* 私有模型聊天
* @param chatReqVO
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java
index 017daec5a..1eff7fe69 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java
@@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.llm.service.http;
+import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
@@ -12,11 +13,14 @@ import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.net.HttpURLConnection;
+import java.net.URL;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -129,6 +133,83 @@ public class ModelService {
throw new RuntimeException("模型补全请求处理失败", e);
}
}
+    /**
+     * Sends a completion request to the model service and forwards its response to the client
+     * over SSE, one event per non-blank response line.
+     *
+     * <p>Each non-blank line is parsed as a {@code ChatCompletion}. Lines whose {@code detail}
+     * is blank have the content of their first choice sent through the emitter; lines with a
+     * non-blank {@code detail} are skipped. When the response ends the emitter is completed.
+     * Any failure completes the emitter with the error instead of throwing. The request body is
+     * written, and the response read, as UTF-8, and all connection resources are released on
+     * every exit path.
+     *
+     * @param url     model service URL; when blank the configured default completions URL is used
+     * @param req     completion request; its model is set to the default model when blank
+     * @param emitter SSE emitter that receives the streamed response
+     */
+    public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) {
+        log.info("开始处理模型补全请求,请求参数: {}", req);
+        if (StrUtil.isBlank(req.getModel())) {
+            log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID);
+            req.setModel(DEFAULT_MODEL_ID);
+        }
+        log.info("请求参数: {}", JSON.toJSONString(req));
+
+        HttpURLConnection connection = null;
+        try {
+            URL targetUrl;
+            if (StrUtil.isBlank(url)) {
+                log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions());
+                targetUrl = new URL(llmBackendProperties.getModelCompletions());
+            } else {
+                log.info("使用指定 URL: {}", url);
+                targetUrl = new URL(url);
+            }
+            connection = (HttpURLConnection) targetUrl.openConnection();
+            connection.setRequestMethod("POST");
+            connection.setRequestProperty("Content-Type", "application/json");
+            connection.setDoOutput(true);
+            // Explicit UTF-8: the platform default charset would corrupt non-ASCII prompts.
+            try (java.io.OutputStream out = connection.getOutputStream()) {
+                out.write(JSON.toJSONString(req).getBytes(java.nio.charset.StandardCharsets.UTF_8));
+            }
+
+            // Compiled once for the whole response rather than once per line.
+            // NOTE(review): as written this pattern only matches empty strings, so the
+            // replacement below leaves the content unchanged; it looks like delimiter text
+            // around the group is missing. Confirm the intended expression.
+            String patternString = "(.*?)";
+            Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL);
+
+            try (BufferedReader reader = new BufferedReader(
+                    new InputStreamReader(connection.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
+                String line;
+                while ((line = reader.readLine()) != null) {
+                    if (StrUtil.isBlank(line)) {
+                        continue;
+                    }
+                    log.info("收到流式响应内容: {}", line);
+                    ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class);
+                    if (StrUtil.isNotBlank(chatCompletion.getDetail())) {
+                        // Lines carrying a detail are not forwarded to the client.
+                        continue;
+                    }
+                    String respContent = chatCompletion.getChoices().get(0).getMessage().getContent();
+                    String answerContent = pattern.matcher(respContent).replaceAll("");
+                    try {
+                        emitter.send(SseEmitter.event().data(answerContent));
+                        log.info("已发送流式响应数据: {}", answerContent);
+                    } catch (Exception e) {
+                        // Sending failed: stop reading and fail the emitter.
+                        log.error("发送流式响应数据时发生异常", e);
+                        try {
+                            emitter.completeWithError(e);
+                        } catch (Exception ex) {
+                            log.error("无法完成 SseEmitter 错误处理", ex);
+                        }
+                        return;
+                    }
+                }
+            }
+
+            try {
+                emitter.complete();
+                log.info("流式传输完成");
+            } catch (Exception e) {
+                log.error("完成流式传输时发生异常", e);
+            }
+        } catch (Exception e) {
+            log.error("处理模型补全请求时发生异常", e);
+            try {
+                emitter.completeWithError(e);
+            } catch (Exception ex) {
+                log.error("无法完成 SseEmitter 错误处理", ex);
+            }
+        } finally {
+            if (connection != null) {
+                connection.disconnect();
+            }
+        }
+    }
/**
* 通过 GPU 类型获取对应的主机地址