refactor(llm): 重构知识库处理逻辑

- 新增 KnowledgeBaseService 接口并注入到 ConversationServiceImpl
- 优化知识库字符串处理逻辑,增加空字符串处理
- 重构系统提示和知识库字符串的组合方式
- 新增知识库命中率测试相关功能
- 优化知识库数据结构,支持段落命中率计算
This commit is contained in:
Liuyang 2025-03-13 15:51:10 +08:00
parent 9e82ebdf5a
commit b29d9c5b0c

View File

@ -13,6 +13,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSa
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
@ -26,6 +29,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.BaseModelService;
import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService;
import cn.iocoder.yudao.module.llm.service.http.ModelService;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
import com.alibaba.excel.util.StringUtils;
import com.alibaba.fastjson.JSON;
@ -45,6 +49,7 @@ import javax.servlet.http.HttpServletResponse;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.*;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@ -80,6 +85,9 @@ public class ConversationServiceImpl implements ConversationService {
@Resource
private LLMBackendProperties llmBackendProperties;
@Resource
private KnowledgeBaseService knowledgeBaseService;
// 聊天会话历史记录缓存Key
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
// 聊天会话历史记录缓存时间
@ -439,14 +447,16 @@ public class ConversationServiceImpl implements ConversationService {
// 处理 knowledgeBaseString
if (StringUtils.isNotBlank(knowledgeBaseString)) {
knowledgeBaseString = "<context>" + knowledgeBaseString + "</context>";
}else {
knowledgeBaseString = "<context>" + "</context>";
}
// 处理 systemPrompt
systemPrompt = StringUtils.isBlank(chatReqVO.getSystemPrompt())
? PROMPT
: chatReqVO.getSystemPrompt() + "\n" + PROMPT;
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
}
String mess = systemPrompt + knowledgeBaseString;
String mess = systemPrompt + " \n "+knowledgeBaseString;
// // 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
@ -532,30 +542,14 @@ public class ConversationServiceImpl implements ConversationService {
ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO();
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
List<ParagraphHitRateWordVO> words = new ArrayList<>();
// 2. 遍历处理每个文档
for (KnowledgeDocumentsDO document : documentList) {
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
if (rateWordVO != null) {
words.add(rateWordVO);
}
}
if (CollectionUtils.isEmpty(words)) {
paragraphHitRateListVO.setWordList(Collections.emptyList());
paragraphHitRateListVO.setGroupId("");
} else {
paragraphHitRateListVO.setWordList(words);
}
KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
testReqVO.setQuery(chatReqVO.getPrompt());
// 请求结果添加到 Redis查询段落命中率
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid());
stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
}
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
} catch (Exception e) {
@ -567,6 +561,74 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
private StringBuilder handlerResult (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)){
return new StringBuilder();
}
// 1: 存储到redis
saveRedis(result, paragraphHitRateListVO);
// 2: 组成返回数据
StringBuilder knowledgeBase = new StringBuilder();
result.forEach(item -> {
knowledgeBase.append(item.getPageContent());
});
return knowledgeBase;
}
private void saveRedis (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)){
return;
}
List<ParagraphHitRateWordVO> words = new ArrayList<>();
// 按照fileId分组存到Map中
Map<Long, List<KnowledgeHitRateTestResultVO>> groupedByFileId = result.stream()
.collect(Collectors.groupingBy(KnowledgeHitRateTestResultVO::getFileId));
// 遍历Map查看分组结果
groupedByFileId.forEach((fileId, list) -> {
System.out.println("File ID: " + fileId);
list.forEach(i->{
ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
// 设置文档名称
rateWordVO.setDocumentName(i.getFileName());
// 设置段落命中率
List<ParagraphHitRateVO> paragraphHitRate=new ArrayList<>();
for (KnowledgeHitRateTestResultVO i1 : list) {
ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
rateVO.setParagraph(i1.getPageContent());
rateVO.setHitRate(i1.getHitRate());
rateVO.setWordCount(i1.getPageContent().length());
paragraphHitRate.add(rateVO);
}
rateWordVO.setParagraphHitRate(paragraphHitRate);
words.add(rateWordVO);
});
});
if (CollectionUtils.isEmpty(words)) {
paragraphHitRateListVO.setWordList(Collections.emptyList());
paragraphHitRateListVO.setIsExist(false);
} else {
paragraphHitRateListVO.setWordList(words);
paragraphHitRateListVO.setIsExist(true);
}
// 请求结果添加到 Redis查询段落命中率
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, paragraphHitRateListVO.getUuid());
stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
}
}
/**
* 处理单个知识库文档的检索逻辑
*/