diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml index f8a8a4c5d..e01a56f9b 100644 --- a/yudao-module-llm/yudao-module-llm-biz/pom.xml +++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml @@ -125,6 +125,10 @@ poi-scratchpad 5.2.3 + + org.springframework + spring-webflux + diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java index 6096a5850..f019cead0 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java @@ -2,8 +2,12 @@ package cn.iocoder.yudao.module.llm.controller.admin.conversation; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; +import cn.iocoder.yudao.framework.common.exception.ServiceException; +import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO; import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo; import com.alibaba.fastjson.JSON; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; import org.springframework.validation.annotation.Validated; @@ -15,8 +19,10 @@ import io.swagger.v3.oas.annotations.Operation; import javax.validation.constraints.*; import javax.validation.*; import javax.servlet.http.*; +import java.time.LocalDateTime; import java.util.*; import java.io.IOException; +import java.util.concurrent.CompletableFuture; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; @@ -32,11 +38,13 @@ import 
static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*; import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*; import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO; import cn.iocoder.yudao.module.llm.service.conversation.ConversationService; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @Tag(name = "管理后台 - 大模型对话记录") @RestController @RequestMapping("/llm/conversation") @Validated +@Slf4j public class ConversationController { @Resource @@ -101,6 +109,31 @@ public class ConversationController { public CommonResult chat(@Valid @RequestBody ChatReqVO chatReqVO) { return success(conversationService.chat(chatReqVO)); } + + /** + * 对话推理接口,使用 SSE 进行流式响应 + * @param chatReqVO 对话请求对象 + * @return SseEmitter 对象,用于流式发送响应 + */ + @PostMapping("/stream-chat") + public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) { + log.info("收到对话推理请求,请求参数: {}", chatReqVO); + SseEmitter emitter = new SseEmitter(); + try { + conversationService.chatStream(chatReqVO, emitter); + } catch (Exception e) { + log.error("处理对话推理请求时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + } + log.info("返回 SseEmitter 对象,准备进行流式响应"); + return emitter; + } + + @PostMapping("/text-to-image") @Operation(summary = "文字转图片接口") public CommonResult textToImage(@Valid @RequestBody TextToImageReqVo req) { diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java index 7c582bb70..9e1388ee0 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java +++ 
b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java @@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*; import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageParam; +import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO; import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; /** * Service interface for LLM conversation records @@ -64,4 +66,11 @@ public interface ConversationService { * @return */ JSONArray textToImage(TextToImageReqVo req); -} \ No newline at end of file + + /** + * Streaming variant of chat: the answer is delivered through the emitter instead of being returned. + * @param chatReqVO chat request, validated via {@code @Valid} + * @param emitter SSE emitter that receives the streamed response + */ + void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter); +} diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 4bbffedce..9ad223380 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -1,5 +1,7 @@ package cn.iocoder.yudao.module.llm.service.conversation; +import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONArray; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; @@ -32,11 +34,16 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; 
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; +import lombok.extern.slf4j.Slf4j; import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import java.io.IOException; import java.util.*; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -49,6 +56,7 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*; */ @Service @Validated +@Slf4j public class ConversationServiceImpl implements ConversationService { @Resource @@ -130,19 +138,11 @@ public class ConversationServiceImpl implements ConversationService { ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class); applicationService.updateApplication(applicationSaveReqVO); } - /*Long promptId = application.getPromptId(); - PromptTemplatesRespVO promptTemplates = promptTemplatesService.getPromptTemplates(promptId); - if(promptTemplates != null){ - chatReqVO.setSystemPrompt(promptTemplates.getTemplateText()); - }*/ + chatReqVO.setSystemPrompt(application.getPrompt()); } } - /* if (Objects.equals(1, chatReqVO.getModelType())){ - return publicModelChat(chatReqVO); - }else { - return privateModelChat(chatReqVO); - }*/ + return publicModelChat(chatReqVO); } @@ -273,9 +273,171 @@ public class ConversationServiceImpl implements ConversationService { dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt()); dataRefluxDataSaveReqVO.setResponse(modelCompletionsRespVO.getAnswer()); dataRefluxDataSaveReqVO.setSystem(modelCompletionsRespVO.getSystem()); + dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens()); + 
dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature()); dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO); return chatRespVO; } + + /** + * 处理对话请求,以流式方式返回结果 + * @param chatReqVO 对话请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + log.info("开始处理对话请求,请求参数: {}", chatReqVO); + // 检查系统提示信息,如果为空则尝试从应用中获取 + if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) { + if (chatReqVO.getApplicationId() != null) { + log.info("系统提示信息为空,尝试从应用中获取,应用 ID: {}", chatReqVO.getApplicationId()); + ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId()); + List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); + if (CollectionUtils.isEmpty(messageHistoryList)) { + log.info("聊天历史记录为空,更新应用聊天计数"); + application.setChatCount(application.getChatCount() + 1); + ApplicationSaveReqVO applicationSaveReqVO = BeanUtil.toBean(application, ApplicationSaveReqVO.class); + applicationService.updateApplication(applicationSaveReqVO); + } + chatReqVO.setSystemPrompt(application.getPrompt()); + log.info("已更新系统提示信息为: {}", chatReqVO.getSystemPrompt()); + } + } + // 调用公共模型聊天流式处理方法 + publicModelChatStream(chatReqVO, emitter); + } + + /** + * 公共模型聊天流式处理方法 + * @param chatReqVO 对话请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO); + // 检查 UUID 是否为空,若为空则生成一个 + if (StrUtil.isBlank(chatReqVO.getUuid())) { + log.info("UUID 为空,生成新的 UUID"); + chatReqVO.setUuid(UUID.randomUUID().toString()); + } + String model = null; + String selfModelUrl = ""; + // 根据模型类型获取模型信息 + if (Objects.equals(1, chatReqVO.getModelType())) { + log.info("使用预制模型,模型 ID: {}", chatReqVO.getModelId()); + BaseModelDO baseModelDO = baseModelService.getBaseModel(chatReqVO.getModelId()); + if 
(baseModelDO == null) { + log.error("预制模型不存在,模型 ID: {}", chatReqVO.getModelId()); + try { + emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + selfModelUrl = baseModelDO.getChatUrl(); + model = baseModelDO.getAigcModelName(); + log.info("获取到预制模型信息,模型名称: {}, 聊天 URL: {}", model, selfModelUrl); + } else if (Objects.equals(0, chatReqVO.getModelType())) { + log.info("使用自定义模型,模型 ID: {}", chatReqVO.getModelId()); + ModelServiceDO modelServiceDO = modelServiceMapper.selectById(chatReqVO.getModelId()); + if (modelServiceDO == null) { + log.error("自定义模型服务不存在,模型 ID: {}", chatReqVO.getModelId()); + try { + emitter.completeWithError(new RuntimeException("MODEL_SERVICE_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + model = modelServiceDO.getBaseModelName(); + selfModelUrl = modelServiceDO.getModelUrl(); + log.info("获取到自定义模型信息,模型名称: {}, 模型 URL: {}", model, selfModelUrl); + } else { + log.error("无效的模型类型,模型类型: {}", chatReqVO.getModelType()); + try { + emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + + List messages = new ArrayList<>(); + + // 如果知识库 ID 不为空,先调用知识库获取相关信息 + StringBuilder knowledgeBase = new StringBuilder(); + if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) { + log.info("知识库 ID 不为空,开始查询知识库,知识库 ID: {}", chatReqVO.getKnowledge()); + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); + queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge()); + List fileList = knowledgeDocumentsMapper.selectList(queryWrapper); + for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) { + Long id = knowledgeDocumentsDO.getId(); + KnowledgeRagEmbedQueryVO knowledgeRagEmbedQueryVO = new KnowledgeRagEmbedQueryVO(); + 
knowledgeRagEmbedQueryVO.setFile_id(id.toString()); + knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt()); + String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO)); + com.alibaba.fastjson.JSONArray jsonArray = JSON.parseArray(result); + if (jsonArray != null && !jsonArray.isEmpty()) { + JSONArray jsonArray1 = (JSONArray) jsonArray.get(0); + JSONObject jsonObject = (JSONObject) jsonArray1.get(0); + knowledgeBase.append(jsonObject.get("page_content")); + } + } + log.info("知识库查询完成,获取到的信息: {}", knowledgeBase.toString()); + } + String mess = chatReqVO.getSystemPrompt() + "" + knowledgeBase.toString() + ""; + // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中 + List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); + if (messageHistoryList != null && !messageHistoryList.isEmpty()) { + log.info("存在聊天历史记录,处理历史记录消息"); + for (String messageHistory : messageHistoryList) { + ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class); + if (modelCompletionsMessage.getRole().equals("system")) { + modelCompletionsMessage.setContent(mess); + stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage)); + } + messages.add(modelCompletionsMessage); + } + } else { + log.info("不存在聊天历史记录,创建新的系统消息"); + ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage(); + systemMessage.setRole("system"); + systemMessage.setContent(mess); + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage)); + messages.add(systemMessage); + } + + // 创建用户消息 + ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage(); 
+ message.setRole("user"); + message.setContent(chatReqVO.getPrompt()); + messages.add(message); + + // 构建模型补全请求对象 + ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO(); + modelCompletionsReqVO.setMessages(messages); + modelCompletionsReqVO.setModel(model); + log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO); + + // 调用模型服务进行流式处理 + modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter); + + // 将用户消息存入缓存 + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message)); + + // 保存数据回流信息 + DataRefluxDataSaveReqVO dataRefluxDataSaveReqVO = new DataRefluxDataSaveReqVO(); + dataRefluxDataSaveReqVO.setModelServiceId(chatReqVO.getModelId()); + dataRefluxDataSaveReqVO.setModelType(chatReqVO.getModelType()); + dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt()); + dataRefluxDataSaveReqVO.setSystem("助手"); + dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens()); + dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature()); + dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO); + log.info("数据回流信息保存完成"); + } + + /** * 私有模型聊天 * @param chatReqVO diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java index 017daec5a..1eff7fe69 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java @@ -1,5 +1,6 @@ package cn.iocoder.yudao.module.llm.service.http; +import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import 
cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; @@ -12,11 +13,14 @@ import com.alibaba.fastjson.JSONObject; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -129,6 +133,83 @@ public class ModelService { throw new RuntimeException("模型补全请求处理失败", e); } } + /** + * 模型补全流式处理方法 + * @param url 模型服务的 URL + * @param req 模型补全请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) { + log.info("开始处理模型补全请求,请求参数: {}", req); + // 检查模型是否为空,若为空则设置默认模型 + if (StrUtil.isBlank(req.getModel())) { + log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID); + req.setModel(DEFAULT_MODEL_ID); + } + // 记录请求信息 + log.info("请求参数: {}", JSON.toJSONString(req)); + try { + // 发起 HTTP POST 请求 + URL targetUrl; + if (StrUtil.isBlank(url)) { + log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions()); + targetUrl = new URL(llmBackendProperties.getModelCompletions()); + } else { + log.info("使用指定 URL: {}", url); + targetUrl = new URL(url); + } + HttpURLConnection connection = (HttpURLConnection) targetUrl.openConnection(); + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setDoOutput(true); + connection.getOutputStream().write(JSON.toJSONString(req).getBytes()); + + BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + String line; + StringBuilder responseBuilder = new StringBuilder(); + while ((line = reader.readLine()) != null) { + // 解析流式响应内容 + if 
(StrUtil.isNotBlank(line)) { + log.info("收到流式响应内容: {}", line); + ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class); + if (StrUtil.isBlank(chatCompletion.getDetail())) { + String respContent = chatCompletion.getChoices().get(0).getMessage().getContent(); + String patternString = "(.*?)"; + Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL); + Matcher matcher = pattern.matcher(respContent); + String answerContent = matcher.replaceAll(""); + // 流式发送数据 + try { + emitter.send(SseEmitter.event().data(answerContent)); + log.info("已发送流式响应数据: {}", answerContent); + } catch (Exception e) { + log.error("发送流式响应数据时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + return; + } + } + } + } + // 完成流式传输 + try { + emitter.complete(); + log.info("流式传输完成"); + } catch (Exception e) { + log.error("完成流式传输时发生异常", e); + } + } catch (Exception e) { + log.error("处理模型补全请求时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + } + } /** * 通过 GPU 类型获取对应的主机地址