feat(llm): 优化知识库命中率测试功能
- 修改 executeHitRateTest 方法签名,使用 KnowledgeHitRateTestReqVO 作为参数 - 优化命中率测试逻辑,增加对 score 阈值的处理 - 调整 KnowledgeBaseDO 中 score 字段类型,从 Integer 改为 Double - 优化 hit rate 测试结果解析逻辑,增加错误处理 - 移除不必要的 DecimalFormat 使用,简化代码
This commit is contained in:
parent
b29d9c5b0c
commit
dff7904e39
@ -11,5 +11,6 @@ import java.util.List;
|
||||
public class ParagraphHitRateListVO {
|
||||
private String uuid;
|
||||
private String groupId;
|
||||
private Boolean isExist;
|
||||
private List<ParagraphHitRateWordVO> wordList;
|
||||
}
|
||||
|
@ -26,4 +26,9 @@ public class KnowledgeHitRateTestReqVO {
|
||||
*/
|
||||
// @NotNull(message = "k值不能为空")
|
||||
private Integer k;
|
||||
|
||||
/**
|
||||
* Score阈值
|
||||
*/
|
||||
private Double score;
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ public class KnowledgeBaseDO extends BaseDO {
|
||||
/**
|
||||
* Score阈值
|
||||
*/
|
||||
private Integer score;
|
||||
private Double score;
|
||||
/**
|
||||
* 知识长度
|
||||
*/
|
||||
|
@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
|
||||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumentsMapper;
|
||||
@ -25,6 +26,7 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
|
||||
@ -142,11 +144,15 @@ public class AsyncKnowledgeBase {
|
||||
|
||||
}
|
||||
|
||||
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (String query, List<Long> fileIds, Integer k) {
|
||||
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO , List<Long> fileIds) {
|
||||
List<String> fileIdStr = fileIds.stream()
|
||||
.map(Object::toString)
|
||||
.collect(Collectors.toList());
|
||||
QueryMultipleReqVO vo = new QueryMultipleReqVO();
|
||||
vo.setQuery(query);
|
||||
vo.setFileIds(Collections.singletonList(String.valueOf(fileIds)));
|
||||
vo.setK(k);
|
||||
vo.setQuery(testReqVO.getQuery());
|
||||
vo.setFileIds(fileIdStr);
|
||||
vo.setK(testReqVO.getK());
|
||||
vo.setScore(testReqVO.getScore());
|
||||
|
||||
List<KnowledgeHitRateTestResultVO> resultList = new ArrayList<>();
|
||||
|
||||
@ -155,10 +161,10 @@ public class AsyncKnowledgeBase {
|
||||
KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
|
||||
resultVO.setPageContent(pair.getDocument().getPageContent());
|
||||
|
||||
DecimalFormat df = new DecimalFormat("0.00%");
|
||||
df.setRoundingMode(RoundingMode.HALF_UP);
|
||||
String rateResult = df.format(pair.getHitRate());
|
||||
resultVO.setHitRate(rateResult);
|
||||
// DecimalFormat df = new DecimalFormat("0.00%");
|
||||
// df.setRoundingMode(RoundingMode.HALF_UP);
|
||||
// String rateResult = df.format(pair.getHitRate());
|
||||
resultVO.setHitRate(String.valueOf(pair.getHitRate()));
|
||||
resultVO.setDigest(pair.getDocument().getMetadata().getDigest());
|
||||
long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
|
||||
resultVO.setFileId(fileId);
|
||||
|
@ -47,14 +47,12 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.*;
|
||||
import java.math.RoundingMode;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -751,12 +749,29 @@ public class RagHttpService {
|
||||
.execute().body();
|
||||
cn.hutool.core.lang.Console.log(result2);
|
||||
|
||||
log.info("请求参数: {}",JSON.toJSONString(jsonString));
|
||||
log.info("请求参数: {}",jsonString);
|
||||
log.info("请求结果: {}",JSON.toJSONString(result2));
|
||||
return parseHitRateTestResults(result2);
|
||||
return parseHitRateTestResults(result2,vo.getScore());
|
||||
}
|
||||
|
||||
private static List<QueryResultPairVO> parseHitRateTestResults (String json) {
|
||||
private static List<QueryResultPairVO> parseHitRateTestResults (String json, Double score) {
|
||||
boolean array= json.trim().startsWith("[");
|
||||
// 先判断 JSON 是否是一个数组
|
||||
|
||||
if (!array){
|
||||
// 判断是否存在 detail 字段
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
if (jsonObject.containsKey("detail")) {
|
||||
String detail = jsonObject.getString("detail");
|
||||
|
||||
if (detail.contains("No documents found for the given query")) {
|
||||
throw exception(new ErrorCode(100_100_1, "未找到符合条件的文档,请检查查询条件!"));
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 将 JSON 转换为 List<QueryResultPair>
|
||||
// 解析 JSON 数组
|
||||
JSONArray jsonArray = JSON.parseArray(json);
|
||||
@ -775,11 +790,13 @@ public class RagHttpService {
|
||||
// 解析命中率
|
||||
Double rate = pairArray.getDoubleValue(1);
|
||||
|
||||
// 创建 QueryResultPair 对象并添加到结果列表
|
||||
QueryResultPairVO pair = new QueryResultPairVO();
|
||||
pair.setDocument(document);
|
||||
pair.setHitRate(rate);
|
||||
results.add(pair);
|
||||
if (score == null || rate >= score) { // a null threshold means "no filtering" rather than an unboxing NullPointerException
|
||||
QueryResultPairVO pair = new QueryResultPairVO();
|
||||
pair.setDocument(document);
|
||||
pair.setHitRate(rate);
|
||||
results.add(pair);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// // 访问数据
|
||||
|
@ -29,4 +29,6 @@ public class QueryMultipleReqVO {
|
||||
*/
|
||||
// @NotNull(message = "k值不能为空")
|
||||
private Integer k;
|
||||
|
||||
private Double score;
|
||||
}
|
||||
|
@ -131,13 +131,13 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
// 更新或插入文档数据
|
||||
List<KnowledgeDocumentsDO> newDocuments = updateOrInsertDocuments(documents, updateReqVO.getId(), updateObj.getKnowledgeLength());
|
||||
|
||||
Map<String,Integer> knowledgeParameters = new HashMap<>();
|
||||
knowledgeParameters.put("chunkSize",updateReqVO.getChunkSize());
|
||||
knowledgeParameters.put("chunkOverlap",updateReqVO.getChunkOverlap());
|
||||
Map<String, Integer> knowledgeParameters = new HashMap<>();
|
||||
knowledgeParameters.put("chunkSize", updateReqVO.getChunkSize());
|
||||
knowledgeParameters.put("chunkOverlap", updateReqVO.getChunkOverlap());
|
||||
|
||||
// 异步处理新增文档和删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
|
||||
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds,knowledgeParameters);
|
||||
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds, knowledgeParameters);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -152,14 +152,15 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
// 异步处理删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId);
|
||||
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
|
||||
asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds,new HashMap<>());
|
||||
asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds, new HashMap<>());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除不需要保留的文档
|
||||
*
|
||||
* @param knowledgeBaseId 知识库 ID
|
||||
* @param retainedIds 需要保留的文档 ID
|
||||
* @param retainedIds 需要保留的文档 ID
|
||||
*/
|
||||
private void deleteUnretainedDocuments (Long knowledgeBaseId, List<Long> retainedIds) {
|
||||
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
|
||||
@ -172,7 +173,8 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
|
||||
/**
|
||||
* 更新或插入文档数据
|
||||
* @param documents 需要更新的文档数据
|
||||
*
|
||||
* @param documents 需要更新的文档数据
|
||||
* @param knowledgeBaseId 知识库 ID
|
||||
* @param chunkSize
|
||||
* @return 更新或插入的文档数据
|
||||
@ -310,14 +312,22 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
if (baseDO == null) {
|
||||
throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
|
||||
}
|
||||
Integer topK=4;
|
||||
if(baseDO.getTopK()==null||baseDO.getTopK()<=0){
|
||||
Integer topK = 4;
|
||||
if (baseDO.getTopK() == null || baseDO.getTopK() <= 0) {
|
||||
testReqVO.setK(topK);
|
||||
}else {
|
||||
topK=baseDO.getTopK();
|
||||
} else {
|
||||
topK = baseDO.getTopK();
|
||||
testReqVO.setK(topK);
|
||||
}
|
||||
|
||||
Double score = 0.2;
|
||||
if (baseDO.getScore() == null || baseDO.getScore() <= 0.0 || baseDO.getScore() > 1) { // validate the score itself (not topK); use the default when it is missing or outside (0, 1]
|
||||
testReqVO.setScore(score);
|
||||
} else {
|
||||
score = baseDO.getScore();
|
||||
testReqVO.setScore(score);
|
||||
}
|
||||
|
||||
// 根据知识库ID获取参数信息,关联文档
|
||||
List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));
|
||||
@ -330,9 +340,9 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
.map(KnowledgeDocumentsDO::getFileId)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO.getQuery(), fileIds, testReqVO.getK());
|
||||
List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO, fileIds);
|
||||
|
||||
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)){
|
||||
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user