diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 3204ec00b..ee5bbfca3 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -13,6 +13,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSa import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*; import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO; import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO; +import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO; +import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO; +import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO; import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO; import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO; import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO; @@ -26,6 +29,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.BaseModelService; import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService; import cn.iocoder.yudao.module.llm.service.http.ModelService; import cn.iocoder.yudao.module.llm.service.http.vo.*; +import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService; import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService; import com.alibaba.excel.util.StringUtils; import com.alibaba.fastjson.JSON; @@ -45,6 +49,7 @@ import javax.servlet.http.HttpServletResponse; import java.math.RoundingMode; import java.text.DecimalFormat; import java.util.*; +import java.util.stream.Collectors; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*; @@ -80,6 +85,9 @@ public class ConversationServiceImpl implements ConversationService { @Resource private LLMBackendProperties llmBackendProperties; + @Resource + private KnowledgeBaseService knowledgeBaseService; + // 聊天会话历史记录缓存Key private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history"; // 聊天会话历史记录缓存时间 @@ -439,14 +447,16 @@ public class ConversationServiceImpl implements ConversationService { // 处理 knowledgeBaseString if (StringUtils.isNotBlank(knowledgeBaseString)) { knowledgeBaseString = "" + knowledgeBaseString + ""; + }else { + knowledgeBaseString = "" + ""; } // 处理 systemPrompt systemPrompt = StringUtils.isBlank(chatReqVO.getSystemPrompt()) ? PROMPT - : chatReqVO.getSystemPrompt() + "\n" + PROMPT; + : chatReqVO.getSystemPrompt() + " \n " + PROMPT; } - String mess = systemPrompt + knowledgeBaseString; + String mess = systemPrompt + " \n "+knowledgeBaseString; // // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中 // List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); @@ -532,30 +542,14 @@ public class ConversationServiceImpl implements ConversationService { ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO(); paragraphHitRateListVO.setUuid(chatReqVO.getUuid()); paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId()); - List words = new ArrayList<>(); - // 2. 遍历处理每个文档 - for (KnowledgeDocumentsDO document : documentList) { - ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase); - if (rateWordVO != null) { - words.add(rateWordVO); - } - } - if (CollectionUtils.isEmpty(words)) { - paragraphHitRateListVO.setWordList(Collections.emptyList()); - paragraphHitRateListVO.setGroupId(""); - } else { - paragraphHitRateListVO.setWordList(words); - } + KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO(); + testReqVO.setKnowledgeId(chatReqVO.getKnowledge()); + testReqVO.setQuery(chatReqVO.getPrompt()); - // 请求结果添加到 Redis,查询段落命中率 - String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid()); - stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO)); + List result = knowledgeBaseService.executeHitRateTest(testReqVO); + knowledgeBase = handlerResult(result, paragraphHitRateListVO); - List paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1); - if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) { - log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList); - } log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length()); } catch (Exception e) { @@ -567,6 +561,74 @@ public class ConversationServiceImpl implements ConversationService { return knowledgeBase; } + private StringBuilder handlerResult (List result, ParagraphHitRateListVO paragraphHitRateListVO) { + if (CollectionUtils.isEmpty(result)){ + return new StringBuilder(); + } + + // 1: 存储到redis + saveRedis(result, paragraphHitRateListVO); + + // 2: 组成返回数据 + StringBuilder knowledgeBase = new StringBuilder(); + result.forEach(item -> { + knowledgeBase.append(item.getPageContent()); + }); + return knowledgeBase; + } + + private void saveRedis (List result, ParagraphHitRateListVO paragraphHitRateListVO) { + if (CollectionUtils.isEmpty(result)){ + return; + } + List words = new ArrayList<>(); + + // 按照fileId分组,存到Map中 + Map> groupedByFileId = result.stream() + .collect(Collectors.groupingBy(KnowledgeHitRateTestResultVO::getFileId)); + + // 遍历Map,查看分组结果 + groupedByFileId.forEach((fileId, list) -> { + System.out.println("File ID: " + fileId); + list.forEach(i->{ + ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO(); + // 设置文档名称 + rateWordVO.setDocumentName(i.getFileName()); + + // 设置段落命中率 + List paragraphHitRate=new ArrayList<>(); + for (KnowledgeHitRateTestResultVO i1 : list) { + ParagraphHitRateVO rateVO = new ParagraphHitRateVO(); + rateVO.setParagraph(i1.getPageContent()); + rateVO.setHitRate(i1.getHitRate()); + rateVO.setWordCount(i1.getPageContent().length()); + paragraphHitRate.add(rateVO); + } + + rateWordVO.setParagraphHitRate(paragraphHitRate); + + words.add(rateWordVO); + }); + }); + + if (CollectionUtils.isEmpty(words)) { + paragraphHitRateListVO.setWordList(Collections.emptyList()); + paragraphHitRateListVO.setIsExist(false); + } else { + paragraphHitRateListVO.setWordList(words); + paragraphHitRateListVO.setIsExist(true); + } + + // 请求结果添加到 Redis,查询段落命中率 + String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, paragraphHitRateListVO.getUuid()); + stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO)); + + List paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1); + if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) { + log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList); + } + } + /** * 处理单个知识库文档的检索逻辑 */