From dff7904e39e6a46f7717f164f87ad292d5a3cb64 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Thu, 13 Mar 2025 16:56:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(llm):=20=E4=BC=98=E5=8C=96=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E5=91=BD=E4=B8=AD=E7=8E=87=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修改 executeHitRateTest 方法签名,使用 KnowledgeHitRateTestReqVO 作为参数 - 优化命中率测试逻辑,增加对 score 阈值的处理 - 调整 KnowledgeBaseDO 中 score 字段类型,从 Integer 改为 Double- 优化 hit rate 测试结果解析逻辑,增加错误处理 - 移除不必要的 DecimalFormat 使用,简化代码 --- .../vo/ParagraphHitRateListVO.java | 1 + .../vo/KnowledgeHitRateTestReqVO.java | 5 +++ .../knowledgebase/KnowledgeBaseDO.java | 2 +- .../llm/service/async/AsyncKnowledgeBase.java | 22 +++++++---- .../llm/service/http/RagHttpService.java | 37 ++++++++++++++----- .../vo/query/multiple/QueryMultipleReqVO.java | 2 + .../KnowledgeBaseServiceImpl.java | 36 +++++++++++------- 7 files changed, 73 insertions(+), 32 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ParagraphHitRateListVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ParagraphHitRateListVO.java index 95595e67d..4c4d7a0d8 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ParagraphHitRateListVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ParagraphHitRateListVO.java @@ -11,5 +11,6 @@ import java.util.List; public class ParagraphHitRateListVO { private String uuid; private String groupId; + private Boolean isExist; private List wordList; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/knowledgebase/vo/KnowledgeHitRateTestReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/knowledgebase/vo/KnowledgeHitRateTestReqVO.java index 4163e848e..5b23ee7b6 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/knowledgebase/vo/KnowledgeHitRateTestReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/knowledgebase/vo/KnowledgeHitRateTestReqVO.java @@ -26,4 +26,9 @@ public class KnowledgeHitRateTestReqVO { */ // @NotNull(message = "k值不能为空") private Integer k; + + /** + * Score阈值 + */ + private Double score; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/knowledgebase/KnowledgeBaseDO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/knowledgebase/KnowledgeBaseDO.java index 79aefc497..26e57133f 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/knowledgebase/KnowledgeBaseDO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/knowledgebase/KnowledgeBaseDO.java @@ -46,7 +46,7 @@ public class KnowledgeBaseDO extends BaseDO { /** * Score阈值 */ - private Integer score; + private Double score; /** * 知识长度 */ diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncKnowledgeBase.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncKnowledgeBase.java index d3cfc49f8..22c486038 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncKnowledgeBase.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncKnowledgeBase.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async; import cn.iocoder.yudao.framework.common.exception.ErrorCode; import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils; +import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO; import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO; import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO; import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumentsMapper; @@ -25,6 +26,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -142,11 +144,15 @@ public class AsyncKnowledgeBase { } - public List executeHitRateTest (String query, List fileIds, Integer k) { + public List executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO , List fileIds) { + List fileIdStr = fileIds.stream() + .map(Object::toString) + .collect(Collectors.toList()); QueryMultipleReqVO vo = new QueryMultipleReqVO(); - vo.setQuery(query); - vo.setFileIds(Collections.singletonList(String.valueOf(fileIds))); - vo.setK(k); + vo.setQuery(testReqVO.getQuery()); + vo.setFileIds(fileIdStr); + vo.setK(testReqVO.getK()); + vo.setScore(testReqVO.getScore()); List resultList = new ArrayList<>(); @@ -155,10 +161,10 @@ public class AsyncKnowledgeBase { KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO(); resultVO.setPageContent(pair.getDocument().getPageContent()); - DecimalFormat df = new DecimalFormat("0.00%"); - df.setRoundingMode(RoundingMode.HALF_UP); - String rateResult = df.format(pair.getHitRate()); - resultVO.setHitRate(rateResult); +// DecimalFormat df = new DecimalFormat("0.00%"); +// df.setRoundingMode(RoundingMode.HALF_UP); +// String rateResult = df.format(pair.getHitRate()); + resultVO.setHitRate(String.valueOf(pair.getHitRate())); resultVO.setDigest(pair.getDocument().getMetadata().getDigest()); long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId()); resultVO.setFileId(fileId); diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index 66c7c2781..e34a2cb7e 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -47,14 +47,12 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.*; -import java.math.RoundingMode; import java.net.URL; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardCopyOption; -import java.text.DecimalFormat; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -751,12 +749,29 @@ public class RagHttpService { .execute().body(); cn.hutool.core.lang.Console.log(result2); - log.info("请求参数: {}",JSON.toJSONString(jsonString)); + log.info("请求参数: {}",jsonString); log.info("请求结果: {}",JSON.toJSONString(result2)); - return parseHitRateTestResults(result2); + return parseHitRateTestResults(result2,vo.getScore()); } - private static List parseHitRateTestResults (String json) { + private static List parseHitRateTestResults (String json, Double score) { + boolean array= json.trim().startsWith("["); + // 先判断 JSON 是否是一个数组 + + if (!array){ + // 判断是否存在 detail 字段 + JSONObject jsonObject = JSON.parseObject(json); + if (jsonObject.containsKey("detail")) { + String detail = jsonObject.getString("detail"); + + if (detail.contains("No documents found for the given query")) { + throw exception(new ErrorCode(100_100_1, "未找到符合条件的文档,请检查查询条件!")); + } + return new ArrayList<>(); + } + } + + // 将 JSON 转换为 List // 解析 JSON 数组 JSONArray jsonArray = JSON.parseArray(json); @@ -775,11 +790,13 @@ public class RagHttpService { // 解析命中率 Double rate = pairArray.getDoubleValue(1); - // 创建 QueryResultPair 对象并添加到结果列表 - QueryResultPairVO pair = new QueryResultPairVO(); - pair.setDocument(document); - pair.setHitRate(rate); - results.add(pair); + if (rate >= score) { + QueryResultPairVO pair = new QueryResultPairVO(); + pair.setDocument(document); + pair.setHitRate(rate); + results.add(pair); + } + } // // 访问数据 diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/query/multiple/QueryMultipleReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/query/multiple/QueryMultipleReqVO.java index 2f14cf62e..1252f1366 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/query/multiple/QueryMultipleReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/query/multiple/QueryMultipleReqVO.java @@ -29,4 +29,6 @@ public class QueryMultipleReqVO { */ // @NotNull(message = "k值不能为空") private Integer k; + + private Double score; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/knowledgebase/KnowledgeBaseServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/knowledgebase/KnowledgeBaseServiceImpl.java index 77d3a7289..abff96751 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/knowledgebase/KnowledgeBaseServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/knowledgebase/KnowledgeBaseServiceImpl.java @@ -131,13 +131,13 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { // 更新或插入文档数据 List newDocuments = updateOrInsertDocuments(documents, updateReqVO.getId(), updateObj.getKnowledgeLength()); - Map knowledgeParameters = new HashMap<>(); - knowledgeParameters.put("chunkSize",updateReqVO.getChunkSize()); - knowledgeParameters.put("chunkOverlap",updateReqVO.getChunkOverlap()); + Map knowledgeParameters = new HashMap<>(); + knowledgeParameters.put("chunkSize", updateReqVO.getChunkSize()); + knowledgeParameters.put("chunkOverlap", updateReqVO.getChunkOverlap()); // 异步处理新增文档和删除的文档 List deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId()); - asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds,knowledgeParameters); + asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds, knowledgeParameters); } /** @@ -152,14 +152,15 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { // 异步处理删除的文档 List deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId); if (!CollectionUtils.isAnyEmpty(deleteIds)) { - asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds,new HashMap<>()); + asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds, new HashMap<>()); } } /** * 删除不需要保留的文档 + * * @param knowledgeBaseId 知识库 ID - * @param retainedIds 需要保留的文档 ID + * @param retainedIds 需要保留的文档 ID */ private void deleteUnretainedDocuments (Long knowledgeBaseId, List retainedIds) { LambdaQueryWrapperX deleteWrapper = new LambdaQueryWrapperX() @@ -172,7 +173,8 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { /** * 更新或插入文档数据 - * @param documents 需要更新的文档数据 + * + * @param documents 需要更新的文档数据 * @param knowledgeBaseId 知识库 ID * @param chunkSize * @return 更新或插入的文档数据 @@ -310,14 +312,22 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { if (baseDO == null) { throw exception(KNOWLEDGE_BASE_NOT_EXISTS); } - Integer topK=4; - if(baseDO.getTopK()==null||baseDO.getTopK()<=0){ + Integer topK = 4; + if (baseDO.getTopK() == null || baseDO.getTopK() <= 0) { testReqVO.setK(topK); - }else { - topK=baseDO.getTopK(); + } else { + topK = baseDO.getTopK(); testReqVO.setK(topK); } + Double score = 0.2; + if (baseDO.getScore() == null || baseDO.getTopK() <= 0.0|| baseDO.getScore() > 1) { + testReqVO.setScore(score); + } else { + score = baseDO.getScore(); + testReqVO.setScore(score); + } + // 根据知识库ID获取参数信息,关联文档 List documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper() .eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId)); @@ -330,9 +340,9 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { .map(KnowledgeDocumentsDO::getFileId) .collect(Collectors.toList()); - List result = asyncKnowledgeBase.executeHitRateTest(testReqVO.getQuery(), fileIds, testReqVO.getK()); + List result = asyncKnowledgeBase.executeHitRateTest(testReqVO, fileIds); - if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)){ + if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)) { return Collections.emptyList(); }