diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
index 3204ec00b..ee5bbfca3 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
@@ -13,6 +13,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSa
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO;
+import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
+import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
+import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
@@ -26,6 +29,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.BaseModelService;
import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService;
import cn.iocoder.yudao.module.llm.service.http.ModelService;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
+import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
import com.alibaba.excel.util.StringUtils;
import com.alibaba.fastjson.JSON;
@@ -45,6 +49,7 @@ import javax.servlet.http.HttpServletResponse;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.*;
+import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@@ -80,6 +85,9 @@ public class ConversationServiceImpl implements ConversationService {
@Resource
private LLMBackendProperties llmBackendProperties;
+ @Resource
+ private KnowledgeBaseService knowledgeBaseService;
+
// 聊天会话历史记录缓存Key
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
// 聊天会话历史记录缓存时间
@@ -439,14 +447,16 @@ public class ConversationServiceImpl implements ConversationService {
// 处理 knowledgeBaseString
if (StringUtils.isNotBlank(knowledgeBaseString)) {
knowledgeBaseString = "" + knowledgeBaseString + "";
+ }else {
+ knowledgeBaseString = "" + "";
}
// 处理 systemPrompt
systemPrompt = StringUtils.isBlank(chatReqVO.getSystemPrompt())
? PROMPT
- : chatReqVO.getSystemPrompt() + "\n" + PROMPT;
+ : chatReqVO.getSystemPrompt() + " \n " + PROMPT;
}
- String mess = systemPrompt + knowledgeBaseString;
+ String mess = systemPrompt + " \n "+knowledgeBaseString;
// // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
// List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
@@ -532,30 +542,14 @@ public class ConversationServiceImpl implements ConversationService {
ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO();
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
- List words = new ArrayList<>();
- // 2. 遍历处理每个文档
- for (KnowledgeDocumentsDO document : documentList) {
- ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
- if (rateWordVO != null) {
- words.add(rateWordVO);
- }
- }
- if (CollectionUtils.isEmpty(words)) {
- paragraphHitRateListVO.setWordList(Collections.emptyList());
- paragraphHitRateListVO.setGroupId("");
- } else {
- paragraphHitRateListVO.setWordList(words);
- }
+ KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
+ testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
+ testReqVO.setQuery(chatReqVO.getPrompt());
- // 请求结果添加到 Redis,查询段落命中率
- String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid());
- stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
+ List result = knowledgeBaseService.executeHitRateTest(testReqVO);
+ knowledgeBase = handlerResult(result, paragraphHitRateListVO);
- List paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
- if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
- log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
- }
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
} catch (Exception e) {
@@ -567,6 +561,74 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
+ private StringBuilder handlerResult (List result, ParagraphHitRateListVO paragraphHitRateListVO) {
+ if (CollectionUtils.isEmpty(result)){
+ return new StringBuilder();
+ }
+
+ // 1: 存储到redis
+ saveRedis(result, paragraphHitRateListVO);
+
+ // 2: 组成返回数据
+ StringBuilder knowledgeBase = new StringBuilder();
+ result.forEach(item -> {
+ knowledgeBase.append(item.getPageContent());
+ });
+ return knowledgeBase;
+ }
+
+ private void saveRedis (List result, ParagraphHitRateListVO paragraphHitRateListVO) {
+ if (CollectionUtils.isEmpty(result)){
+ return;
+ }
+ List words = new ArrayList<>();
+
+ // 按照fileId分组,存到Map中
+ Map> groupedByFileId = result.stream()
+ .collect(Collectors.groupingBy(KnowledgeHitRateTestResultVO::getFileId));
+
+ // 遍历Map,查看分组结果
+ groupedByFileId.forEach((fileId, list) -> {
+ System.out.println("File ID: " + fileId);
+ list.forEach(i->{
+ ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
+ // 设置文档名称
+ rateWordVO.setDocumentName(i.getFileName());
+
+ // 设置段落命中率
+ List paragraphHitRate=new ArrayList<>();
+ for (KnowledgeHitRateTestResultVO i1 : list) {
+ ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
+ rateVO.setParagraph(i1.getPageContent());
+ rateVO.setHitRate(i1.getHitRate());
+ rateVO.setWordCount(i1.getPageContent().length());
+ paragraphHitRate.add(rateVO);
+ }
+
+ rateWordVO.setParagraphHitRate(paragraphHitRate);
+
+ words.add(rateWordVO);
+ });
+ });
+
+ if (CollectionUtils.isEmpty(words)) {
+ paragraphHitRateListVO.setWordList(Collections.emptyList());
+ paragraphHitRateListVO.setIsExist(false);
+ } else {
+ paragraphHitRateListVO.setWordList(words);
+ paragraphHitRateListVO.setIsExist(true);
+ }
+
+ // 请求结果添加到 Redis,查询段落命中率
+ String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, paragraphHitRateListVO.getUuid());
+ stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
+
+ List paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
+ if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
+ log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
+ }
+ }
+
/**
* 处理单个知识库文档的检索逻辑
*/