feat(llm): 优化知识库命中率测试功能

- 修改 executeHitRateTest 方法签名,使用 KnowledgeHitRateTestReqVO 作为参数
- 优化命中率测试逻辑,增加对 score 阈值的处理
- 调整 KnowledgeBaseDO 中 score 字段类型,从 Integer 改为 Double- 优化 hit rate 测试结果解析逻辑,增加错误处理
- 移除不必要的 DecimalFormat 使用,简化代码
This commit is contained in:
Liuyang 2025-03-13 16:56:27 +08:00
parent b29d9c5b0c
commit dff7904e39
7 changed files with 73 additions and 32 deletions

View File

@ -11,5 +11,6 @@ import java.util.List;
public class ParagraphHitRateListVO {
private String uuid;
private String groupId;
private Boolean isExist;
private List<ParagraphHitRateWordVO> wordList;
}

View File

@ -26,4 +26,9 @@ public class KnowledgeHitRateTestReqVO {
*/
// @NotNull(message = "k值不能为空")
private Integer k;
/**
* Score阈值
*/
private Double score;
}

View File

@ -46,7 +46,7 @@ public class KnowledgeBaseDO extends BaseDO {
/**
* Score阈值
*/
private Integer score;
private Double score;
/**
* 知识长度
*/

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumentsMapper;
@ -25,6 +26,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -142,11 +144,15 @@ public class AsyncKnowledgeBase {
}
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (String query, List<Long> fileIds, Integer k) {
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO , List<Long> fileIds) {
List<String> fileIdStr = fileIds.stream()
.map(Object::toString)
.collect(Collectors.toList());
QueryMultipleReqVO vo = new QueryMultipleReqVO();
vo.setQuery(query);
vo.setFileIds(Collections.singletonList(String.valueOf(fileIds)));
vo.setK(k);
vo.setQuery(testReqVO.getQuery());
vo.setFileIds(fileIdStr);
vo.setK(testReqVO.getK());
vo.setScore(testReqVO.getScore());
List<KnowledgeHitRateTestResultVO> resultList = new ArrayList<>();
@ -155,10 +161,10 @@ public class AsyncKnowledgeBase {
KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
resultVO.setPageContent(pair.getDocument().getPageContent());
DecimalFormat df = new DecimalFormat("0.00%");
df.setRoundingMode(RoundingMode.HALF_UP);
String rateResult = df.format(pair.getHitRate());
resultVO.setHitRate(rateResult);
// DecimalFormat df = new DecimalFormat("0.00%");
// df.setRoundingMode(RoundingMode.HALF_UP);
// String rateResult = df.format(pair.getHitRate());
resultVO.setHitRate(String.valueOf(pair.getHitRate()));
resultVO.setDigest(pair.getDocument().getMetadata().getDigest());
long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
resultVO.setFileId(fileId);

View File

@ -47,14 +47,12 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.*;
import java.math.RoundingMode;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -751,12 +749,29 @@ public class RagHttpService {
.execute().body();
cn.hutool.core.lang.Console.log(result2);
log.info("请求参数: {}",JSON.toJSONString(jsonString));
log.info("请求参数: {}",jsonString);
log.info("请求结果: {}",JSON.toJSONString(result2));
return parseHitRateTestResults(result2);
return parseHitRateTestResults(result2,vo.getScore());
}
private static List<QueryResultPairVO> parseHitRateTestResults (String json) {
private static List<QueryResultPairVO> parseHitRateTestResults (String json, Double score) {
boolean array= json.trim().startsWith("[");
// 先判断 JSON 是否是一个数组
if (!array){
// 判断是否存在 detail 字段
JSONObject jsonObject = JSON.parseObject(json);
if (jsonObject.containsKey("detail")) {
String detail = jsonObject.getString("detail");
if (detail.contains("No documents found for the given query")) {
throw exception(new ErrorCode(100_100_1, "未找到符合条件的文档,请检查查询条件!"));
}
return new ArrayList<>();
}
}
// JSON 转换为 List<QueryResultPair>
// 解析 JSON 数组
JSONArray jsonArray = JSON.parseArray(json);
@ -775,11 +790,13 @@ public class RagHttpService {
// 解析命中率
Double rate = pairArray.getDoubleValue(1);
// 创建 QueryResultPair 对象并添加到结果列表
QueryResultPairVO pair = new QueryResultPairVO();
pair.setDocument(document);
pair.setHitRate(rate);
results.add(pair);
if (rate >= score) {
QueryResultPairVO pair = new QueryResultPairVO();
pair.setDocument(document);
pair.setHitRate(rate);
results.add(pair);
}
}
// // 访问数据

View File

@ -29,4 +29,6 @@ public class QueryMultipleReqVO {
*/
// @NotNull(message = "k值不能为空")
private Integer k;
private Double score;
}

View File

@ -131,13 +131,13 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
// 更新或插入文档数据
List<KnowledgeDocumentsDO> newDocuments = updateOrInsertDocuments(documents, updateReqVO.getId(), updateObj.getKnowledgeLength());
Map<String,Integer> knowledgeParameters = new HashMap<>();
knowledgeParameters.put("chunkSize",updateReqVO.getChunkSize());
knowledgeParameters.put("chunkOverlap",updateReqVO.getChunkOverlap());
Map<String, Integer> knowledgeParameters = new HashMap<>();
knowledgeParameters.put("chunkSize", updateReqVO.getChunkSize());
knowledgeParameters.put("chunkOverlap", updateReqVO.getChunkOverlap());
// 异步处理新增文档和删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds,knowledgeParameters);
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds, knowledgeParameters);
}
/**
@ -152,14 +152,15 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
// 异步处理删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId);
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds,new HashMap<>());
asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds, new HashMap<>());
}
}
/**
* 删除不需要保留的文档
*
* @param knowledgeBaseId 知识库 ID
* @param retainedIds 需要保留的文档 ID
* @param retainedIds 需要保留的文档 ID
*/
private void deleteUnretainedDocuments (Long knowledgeBaseId, List<Long> retainedIds) {
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
@ -172,7 +173,8 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
/**
* 更新或插入文档数据
* @param documents 需要更新的文档数据
*
* @param documents 需要更新的文档数据
* @param knowledgeBaseId 知识库 ID
* @param chunkSize
* @return 更新或插入的文档数据
@ -310,14 +312,22 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
if (baseDO == null) {
throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
}
Integer topK=4;
if(baseDO.getTopK()==null||baseDO.getTopK()<=0){
Integer topK = 4;
if (baseDO.getTopK() == null || baseDO.getTopK() <= 0) {
testReqVO.setK(topK);
}else {
topK=baseDO.getTopK();
} else {
topK = baseDO.getTopK();
testReqVO.setK(topK);
}
Double score = 0.2;
if (baseDO.getScore() == null || baseDO.getTopK() <= 0.0|| baseDO.getScore() > 1) {
testReqVO.setScore(score);
} else {
score = baseDO.getScore();
testReqVO.setScore(score);
}
// 根据知识库ID获取参数信息关联文档
List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));
@ -330,9 +340,9 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
.map(KnowledgeDocumentsDO::getFileId)
.collect(Collectors.toList());
List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO.getQuery(), fileIds, testReqVO.getK());
List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO, fileIds);
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)){
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)) {
return Collections.emptyList();
}