From 1d7d615af4ce655ffb8b4af012848f53cc6fa5a1 Mon Sep 17 00:00:00 2001 From: sunxiqing <2240398334@qq.com> Date: Fri, 14 Mar 2025 23:24:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor(llm):=20=E4=BC=98=E5=8C=96=E8=81=8A?= =?UTF-8?q?=E5=A4=A9=E9=80=BB=E8=BE=91=E5=92=8C=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -调整系统提示和知识库的处理顺序 - 优化聊天 --- .../conversation/ConversationServiceImpl.java | 143 +++++++++--------- 1 file changed, 75 insertions(+), 68 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 435a16c7d..15a6cc474 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -64,6 +64,16 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*; @Slf4j public class ConversationServiceImpl implements ConversationService { + public static final String PROMPT = "Use the following context as your learned knowledge, inside XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n"; + // 聊天会话历史记录缓存Key + private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history"; + // 聊天会话历史记录缓存时间 + private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L; + /** + * 知识库文档缓存Key + */ + private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents"; + private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L; @Resource private ConversationMapper conversationMapper; @Resource @@ -84,24 +94,11 @@ public class ConversationServiceImpl implements ConversationService { private KnowledgeDocumentsMapper knowledgeDocumentsMapper; @Resource private LLMBackendProperties llmBackendProperties; - @Resource private KnowledgeBaseService knowledgeBaseService; - // 聊天会话历史记录缓存Key - private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history"; - // 聊天会话历史记录缓存时间 - private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L; - - /** - * 知识库文档缓存Key - */ - private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents"; - private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L; - public static final String PROMPT = "Use the following context as your learned knowledge, inside XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n"; - @Override - public Integer createConversation (ConversationSaveReqVO createReqVO) { + public Integer createConversation(ConversationSaveReqVO createReqVO) { // 插入 ConversationDO conversation = BeanUtils.toBean(createReqVO, ConversationDO.class); conversationMapper.insert(conversation); @@ -110,7 +107,7 @@ public class ConversationServiceImpl implements ConversationService { } @Override - public void updateConversation (ConversationSaveReqVO updateReqVO) { + public void updateConversation(ConversationSaveReqVO updateReqVO) { // 校验存在 validateConversationExists(updateReqVO.getId()); // 更新 @@ -119,31 +116,31 @@ public class ConversationServiceImpl implements ConversationService { } @Override - public void deleteConversation (Integer id) { + public void deleteConversation(Integer id) { // 校验存在 validateConversationExists(id); // 删除 conversationMapper.deleteById(id); } - private void validateConversationExists (Integer id) { + private void validateConversationExists(Integer id) { if (conversationMapper.selectById(id) == null) { throw exception(CONVERSATION_NOT_EXISTS); } } @Override - public ConversationDO getConversation (Integer id) { + public ConversationDO getConversation(Integer id) { return conversationMapper.selectById(id); } @Override - public PageResult getConversationPage (ConversationPageReqVO pageReqVO) { + public PageResult getConversationPage(ConversationPageReqVO pageReqVO) { return conversationMapper.selectPage(pageReqVO); } @Override - public ChatRespVO chat (ChatReqVO chatReqVO) { + public ChatRespVO chat(ChatReqVO chatReqVO) { if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) { if (chatReqVO.getApplicationId() != null) { ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId()); @@ -162,7 +159,7 @@ public class ConversationServiceImpl implements ConversationService { } @Override - public JSONArray textToImage (TextToImageReqVo req) { + public JSONArray textToImage(TextToImageReqVo req) { TextToImageRespVo textToImageRespVo = modelService.textToImage(req); return textToImageRespVo.getData(); } @@ -174,7 +171,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO * @return */ - public ChatRespVO publicModelChat (ChatReqVO chatReqVO) { + public ChatRespVO publicModelChat(ChatReqVO chatReqVO) { if (StringUtils.isBlank(chatReqVO.getUuid())) { // 如果没有uuid,就生成一个 chatReqVO.setUuid(UUID.randomUUID().toString()); @@ -300,7 +297,7 @@ public class ConversationServiceImpl implements ConversationService { * @param emitter SseEmitter 对象,用于流式发送响应 */ @Override - public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) { + public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) { log.info("开始处理对话请求,请求参数: {}", chatReqVO); // 检查系统提示信息,如果为空则尝试从应用中获取 if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) { @@ -317,11 +314,11 @@ public class ConversationServiceImpl implements ConversationService { // 知识库ID - if (application.getModelServiceId()!= null){ + if (application.getModelServiceId() != null) { chatReqVO.setKnowledge(application.getModelServiceId()); } - if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())){ + if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())) { chatReqVO.setSystemPrompt(application.getPrompt()); if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) { log.info("应用中未找到系统提示信息,使用默认提示信息"); @@ -343,7 +340,7 @@ public class ConversationServiceImpl implements ConversationService { * @return */ @Override - public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) { + public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) { String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid); List redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1); log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0); @@ -374,7 +371,7 @@ public class ConversationServiceImpl implements ConversationService { /** * 验证命中率数据有效性 */ - private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) { + private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) { if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) { log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet()); return false; @@ -384,7 +381,7 @@ public class ConversationServiceImpl implements ConversationService { String actualGroupId = jsonObj.getString("groupId"); return Objects.equals(expectedUuid, actualUuid) - && Objects.equals(expectedGroupId, actualGroupId); + && Objects.equals(expectedGroupId, actualGroupId); } @@ -394,7 +391,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO 对话请求对象 * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) { + public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) { log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO); // 检查 UUID 是否为空,若为空则生成一个 String uuid = chatReqVO.getUuid(); @@ -461,7 +458,7 @@ public class ConversationServiceImpl implements ConversationService { // 处理 knowledgeBaseString if (StringUtils.isNotBlank(knowledgeBaseString)) { knowledgeBaseString = "" + knowledgeBaseString + ""; - }else { + } else { knowledgeBaseString = "" + ""; } @@ -470,28 +467,38 @@ public class ConversationServiceImpl implements ConversationService { ? PROMPT : chatReqVO.getSystemPrompt() + " \n " + PROMPT; } - String mess = systemPrompt + " \n "+knowledgeBaseString; + String mess = systemPrompt + " \n " + knowledgeBaseString; + + if (chatReqVO.getKnowledge() != null) { + log.info("不存在聊天历史记录,创建新的系统消息"); + ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage(); + systemMessage.setRole("system"); + systemMessage.setContent(mess); + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage)); + messages.add(systemMessage); + } else { + // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中 + List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); + if (messageHistoryList != null && !messageHistoryList.isEmpty()) { + log.info("存在聊天历史记录,处理历史记录消息"); + for (String messageHistory : messageHistoryList) { + ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class); + if ("system".equals(modelCompletionsMessage.getRole())) { + modelCompletionsMessage.setContent(mess); + stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage)); + } + messages.add(modelCompletionsMessage); + } + } else { + log.info("不存在聊天历史记录,创建新的系统消息"); + ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage(); + systemMessage.setRole("system"); + systemMessage.setContent(mess); + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage)); + messages.add(systemMessage); + } + } - // // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中 - // List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); - // if (messageHistoryList != null && !messageHistoryList.isEmpty()) { - // log.info("存在聊天历史记录,处理历史记录消息"); - // for (String messageHistory : messageHistoryList) { - // ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class); - // if ("system".equals(modelCompletionsMessage.getRole())) { - // modelCompletionsMessage.setContent(mess); - // stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage)); - // } - // messages.add(modelCompletionsMessage); - // } - // } else { - log.info("不存在聊天历史记录,创建新的系统消息"); - ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage(); - systemMessage.setRole("system"); - systemMessage.setContent(mess); - stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage)); - messages.add(systemMessage); - // } // 创建用户消息 ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage(); @@ -509,7 +516,7 @@ public class ConversationServiceImpl implements ConversationService { // log.info("构建模型补全请求对象,请求参数1: {}", modelCompletionsReqVO); // 调用模型服务进行流式处理 - ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(),knowledgeBaseString); + ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(), knowledgeBaseString); if (modelCompletionsRespVO == null) { throw exception(MODEL_COMPLETIONS_ERROR); } @@ -533,7 +540,7 @@ public class ConversationServiceImpl implements ConversationService { } @NotNull - private StringBuilder getKnowledgeBase (ChatReqVO chatReqVO) { + private StringBuilder getKnowledgeBase(ChatReqVO chatReqVO) { final String LOG_PREFIX = "[KnowledgeBase]"; StringBuilder knowledgeBase = new StringBuilder(); @@ -557,12 +564,12 @@ public class ConversationServiceImpl implements ConversationService { paragraphHitRateListVO.setUuid(chatReqVO.getUuid()); paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId()); - KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO(); + KnowledgeHitRateTestReqVO testReqVO = new KnowledgeHitRateTestReqVO(); testReqVO.setKnowledgeId(chatReqVO.getKnowledge()); testReqVO.setQuery(chatReqVO.getPrompt()); List result = knowledgeBaseService.executeHitRateTest(testReqVO); - knowledgeBase = handlerResult(result, paragraphHitRateListVO); + knowledgeBase = handlerResult(result, paragraphHitRateListVO); log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length()); @@ -575,8 +582,8 @@ public class ConversationServiceImpl implements ConversationService { return knowledgeBase; } - private StringBuilder handlerResult (List result, ParagraphHitRateListVO paragraphHitRateListVO) { - if (CollectionUtils.isEmpty(result)){ + private StringBuilder handlerResult(List result, ParagraphHitRateListVO paragraphHitRateListVO) { + if (CollectionUtils.isEmpty(result)) { return new StringBuilder(); } @@ -591,8 +598,8 @@ public class ConversationServiceImpl implements ConversationService { return knowledgeBase; } - private void saveRedis (List result, ParagraphHitRateListVO paragraphHitRateListVO) { - if (CollectionUtils.isEmpty(result)){ + private void saveRedis(List result, ParagraphHitRateListVO paragraphHitRateListVO) { + if (CollectionUtils.isEmpty(result)) { return; } List words = new ArrayList<>(); @@ -605,13 +612,13 @@ public class ConversationServiceImpl implements ConversationService { List finalWords = words; groupedByFileId.forEach((fileId, list) -> { System.out.println("File ID: " + fileId); - list.forEach(i->{ + list.forEach(i -> { ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO(); // 设置文档名称 rateWordVO.setDocumentName(i.getFileName()); // 设置段落命中率 - List paragraphHitRate=new ArrayList<>(); + List paragraphHitRate = new ArrayList<>(); for (KnowledgeHitRateTestResultVO i1 : list) { ParagraphHitRateVO rateVO = new ParagraphHitRateVO(); rateVO.setParagraph(i1.getPageContent()); @@ -649,9 +656,9 @@ public class ConversationServiceImpl implements ConversationService { /** * 处理单个知识库文档的检索逻辑 */ - private ParagraphHitRateWordVO processDocument (KnowledgeDocumentsDO document, - ChatReqVO chatReqVO, - StringBuilder knowledgeBase) { + private ParagraphHitRateWordVO processDocument(KnowledgeDocumentsDO document, + ChatReqVO chatReqVO, + StringBuilder knowledgeBase) { ParagraphHitRateWordVO paragraphHitRateListVO = new ParagraphHitRateWordVO(); try { log.info("{} 处理文档: {}[ID:{}]", "[KnowledgeBase]", document.getDocumentName(), document.getId()); @@ -688,9 +695,9 @@ public class ConversationServiceImpl implements ConversationService { /** * 解析向量检索响应结果 */ - private List parseEmbeddingResponse (String response, - StringBuilder knowledgeBase, - String uuid) { + private List parseEmbeddingResponse(String response, + StringBuilder knowledgeBase, + String uuid) { if (StringUtils.isBlank(response)) { log.warn("{} 收到空响应", "[KnowledgeBase]"); return Collections.emptyList(); @@ -756,7 +763,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO * @return */ - private ChatRespVO privateModelChat (ChatReqVO chatReqVO) { + private ChatRespVO privateModelChat(ChatReqVO chatReqVO) { if (StringUtils.isBlank(chatReqVO.getUuid())) { // 如果没有uuid,就生成一个 chatReqVO.setUuid(UUID.randomUUID().toString());