refactor(llm): 优化知识库命中率测试功能
- 添加文件名字段并格式化命中率显示 - 增加知识库存在性检查和默认 topK 值设置 - 优化日志输出,记录请求参数和结果 - 统一数据类型:将命中率从 Double 改为 String
This commit is contained in:
parent
8cb60e82a8
commit
9e82ebdf5a
@ -15,7 +15,7 @@ public class KnowledgeHitRateTestResultVO {
|
||||
/**
|
||||
* 命中率
|
||||
*/
|
||||
private Double hitRate;
|
||||
private String hitRate;
|
||||
|
||||
/**
|
||||
* 摘要信息
|
||||
@ -26,4 +26,9 @@ public class KnowledgeHitRateTestResultVO {
|
||||
* 文件ID
|
||||
*/
|
||||
private Long fileId;
|
||||
|
||||
/**
|
||||
* 文件名称
|
||||
*/
|
||||
private String fileName;
|
||||
}
|
||||
|
@ -12,12 +12,15 @@ import cn.iocoder.yudao.module.llm.service.http.vo.KnowledgeRagEmbedReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.RegUploadReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.IOException;
|
||||
import java.math.RoundingMode;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@ -151,9 +154,23 @@ public class AsyncKnowledgeBase {
|
||||
for (QueryResultPairVO pair : result) {
|
||||
KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
|
||||
resultVO.setPageContent(pair.getDocument().getPageContent());
|
||||
resultVO.setHitRate(pair.getHitRate());
|
||||
|
||||
DecimalFormat df = new DecimalFormat("0.00%");
|
||||
df.setRoundingMode(RoundingMode.HALF_UP);
|
||||
String rateResult = df.format(pair.getHitRate());
|
||||
resultVO.setHitRate(rateResult);
|
||||
resultVO.setDigest(pair.getDocument().getMetadata().getDigest());
|
||||
resultVO.setFileId(Long.parseLong(pair.getDocument().getMetadata().getFileId()));
|
||||
long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
|
||||
resultVO.setFileId(fileId);
|
||||
|
||||
// 根据 fileId 查找文件名
|
||||
KnowledgeDocumentsDO documents = knowledgeDocumentsMapper.selectOne(KnowledgeDocumentsDO::getFileId, fileId);
|
||||
if (documents!=null && StringUtils.isNotBlank(documents.getDocumentName())){
|
||||
resultVO.setFileName(documents.getDocumentName());
|
||||
}else {
|
||||
resultVO.setFileName("未知文件");
|
||||
}
|
||||
|
||||
resultList.add(resultVO);
|
||||
}
|
||||
return resultList;
|
||||
|
@ -47,12 +47,14 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.*;
|
||||
import java.math.RoundingMode;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -748,6 +750,9 @@ public class RagHttpService {
|
||||
.timeout(20000)
|
||||
.execute().body();
|
||||
cn.hutool.core.lang.Console.log(result2);
|
||||
|
||||
log.info("请求参数: {}",JSON.toJSONString(jsonString));
|
||||
log.info("请求结果: {}",JSON.toJSONString(result2));
|
||||
return parseHitRateTestResults(result2);
|
||||
}
|
||||
|
||||
@ -768,12 +773,12 @@ public class RagHttpService {
|
||||
DocumentInfoVO document = JSON.parseObject(documentJson.toJSONString(), DocumentInfoVO.class);
|
||||
|
||||
// 解析命中率
|
||||
double hitRate = pairArray.getDoubleValue(1);
|
||||
Double rate = pairArray.getDoubleValue(1);
|
||||
|
||||
// 创建 QueryResultPair 对象并添加到结果列表
|
||||
QueryResultPairVO pair = new QueryResultPairVO();
|
||||
pair.setDocument(document);
|
||||
pair.setHitRate(hitRate);
|
||||
pair.setHitRate(rate);
|
||||
results.add(pair);
|
||||
}
|
||||
|
||||
|
@ -306,6 +306,18 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO) {
|
||||
Long knowledgeId = testReqVO.getKnowledgeId();
|
||||
|
||||
KnowledgeBaseDO baseDO = knowledgeBaseMapper.selectOne(KnowledgeBaseDO::getId, knowledgeId);
|
||||
if (baseDO == null) {
|
||||
throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
|
||||
}
|
||||
Integer topK=4;
|
||||
if(baseDO.getTopK()==null||baseDO.getTopK()<=0){
|
||||
testReqVO.setK(topK);
|
||||
}else {
|
||||
topK=baseDO.getTopK();
|
||||
testReqVO.setK(topK);
|
||||
}
|
||||
|
||||
// 根据知识库ID获取参数信息,关联文档
|
||||
List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));
|
||||
|
Loading…
x
Reference in New Issue
Block a user