refactor(llm): 优化知识库命中率测试功能

- 添加文件名字段并格式化命中率显示
- 增加知识库存在性检查和默认 topK 值设置
- 优化日志输出,记录请求参数和结果
- 统一数据类型:将命中率从 Double 改为 String
This commit is contained in:
Liuyang 2025-03-13 15:23:40 +08:00
parent 8cb60e82a8
commit 9e82ebdf5a
4 changed files with 44 additions and 5 deletions

View File

@ -15,7 +15,7 @@ public class KnowledgeHitRateTestResultVO {
/**
* 命中率
*/
private Double hitRate;
private String hitRate;
/**
* 摘要信息
@ -26,4 +26,9 @@ public class KnowledgeHitRateTestResultVO {
* 文件ID
*/
private Long fileId;
/**
* 文件名称
*/
private String fileName;
}

View File

@ -12,12 +12,15 @@ import cn.iocoder.yudao.module.llm.service.http.vo.KnowledgeRagEmbedReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.RegUploadReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -151,9 +154,23 @@ public class AsyncKnowledgeBase {
for (QueryResultPairVO pair : result) {
KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
resultVO.setPageContent(pair.getDocument().getPageContent());
resultVO.setHitRate(pair.getHitRate());
DecimalFormat df = new DecimalFormat("0.00%");
df.setRoundingMode(RoundingMode.HALF_UP);
String rateResult = df.format(pair.getHitRate());
resultVO.setHitRate(rateResult);
resultVO.setDigest(pair.getDocument().getMetadata().getDigest());
resultVO.setFileId(Long.parseLong(pair.getDocument().getMetadata().getFileId()));
long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
resultVO.setFileId(fileId);
// 根据 fileId 查找文件名
KnowledgeDocumentsDO documents = knowledgeDocumentsMapper.selectOne(KnowledgeDocumentsDO::getFileId, fileId);
if (documents!=null && StringUtils.isNotBlank(documents.getDocumentName())){
resultVO.setFileName(documents.getDocumentName());
}else {
resultVO.setFileName("未知文件");
}
resultList.add(resultVO);
}
return resultList;

View File

@ -47,12 +47,14 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.*;
import java.math.RoundingMode;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -748,6 +750,9 @@ public class RagHttpService {
.timeout(20000)
.execute().body();
cn.hutool.core.lang.Console.log(result2);
log.info("请求参数: {}",JSON.toJSONString(jsonString));
log.info("请求结果: {}",JSON.toJSONString(result2));
return parseHitRateTestResults(result2);
}
@ -768,12 +773,12 @@ public class RagHttpService {
DocumentInfoVO document = JSON.parseObject(documentJson.toJSONString(), DocumentInfoVO.class);
// 解析命中率
double hitRate = pairArray.getDoubleValue(1);
Double rate = pairArray.getDoubleValue(1);
// 创建 QueryResultPair 对象并添加到结果列表
QueryResultPairVO pair = new QueryResultPairVO();
pair.setDocument(document);
pair.setHitRate(hitRate);
pair.setHitRate(rate);
results.add(pair);
}

View File

@ -306,6 +306,18 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO) {
Long knowledgeId = testReqVO.getKnowledgeId();
KnowledgeBaseDO baseDO = knowledgeBaseMapper.selectOne(KnowledgeBaseDO::getId, knowledgeId);
if (baseDO == null) {
throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
}
Integer topK=4;
if(baseDO.getTopK()==null||baseDO.getTopK()<=0){
testReqVO.setK(topK);
}else {
topK=baseDO.getTopK();
testReqVO.setK(topK);
}
// 根据知识库ID获取参数信息关联文档
List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));