Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
e513685b87
@ -103,6 +103,12 @@ public interface ErrorCodeConstants {
|
||||
|
||||
// NOTE(review): the constant name says NOT_EXISTS, but the message text reports that the knowledge base name already exists.
ErrorCode KNOWLEDGE_BASE_NAME_NOT_EXISTS = new ErrorCode(10040, "知识库名称已存在");
|
||||
|
||||
// Chunk-size validation error. The literal 10040_1 uses a digit separator, so the code value is 100401.
ErrorCode CHUNK_SIZE_MUST_BE_GREATER_THAN_ZERO = new ErrorCode(10040_1, "分块大小必须大于 0");
|
||||
|
||||
// Chunk-overlap lower-bound validation error (code value 100402).
ErrorCode CHUNK_OVERLAP_MUST_BE_GREATER_THAN_OR_EQUAL_TO_ZERO = new ErrorCode(10040_2, "分块重叠必须大于或等于 0");
|
||||
|
||||
// Chunk-overlap upper-bound validation error (code value 100403).
ErrorCode CHUNK_OVERLAP_MUST_BE_LESS_THAN_CHUNK_SIZE = new ErrorCode(10040_3, "分块重叠必须小于分块大小");
|
||||
|
||||
// NOTE(review): as above, the name says NOT_EXISTS while the message reports that the application name already exists.
ErrorCode APPLICATION_NAME_NOT_EXISTS = new ErrorCode(10041, "应用中心名称已存在");
|
||||
|
||||
// NOTE(review): the identifier looks like a typo for MODEL_SERVICE_NAME; the message reports that the model name already exists.
ErrorCode MODEL_SERVIC_ENAME_NOT_EXISTS = new ErrorCode(10043, "模型名称已存在");
|
||||
|
@ -109,27 +109,8 @@ public class ConversationController {
|
||||
@PostMapping("/stream-chat")
|
||||
public SseEmitter streamChat (@Valid @RequestBody ChatReqVO chatReqVO, HttpServletResponse response) {
|
||||
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
|
||||
SseEmitter emitter = new SseEmitter(120_000L);
|
||||
// ExecutorService executor = Executors.newSingleThreadExecutor();
|
||||
// try {
|
||||
// executor.execute(() -> {
|
||||
// try {
|
||||
// conversationService.chatStream(chatReqVO, emitter, response);
|
||||
// } catch (Exception e) {
|
||||
// emitter.completeWithError(e);
|
||||
// } finally {
|
||||
// executor.shutdown();
|
||||
// }
|
||||
// });
|
||||
// } catch (Exception e) {
|
||||
// log.error("处理对话推理请求时发生异常", e);
|
||||
// try {
|
||||
// emitter.completeWithError(e);
|
||||
// } catch (Exception ex) {
|
||||
// log.error("无法完成 SseEmitter 错误处理", ex);
|
||||
// }
|
||||
// }
|
||||
// log.info("返回 SseEmitter 对象,准备进行流式响应");
|
||||
SseEmitter emitter = new SseEmitter(120_0000L);
|
||||
|
||||
// 异步处理,避免阻塞主线程
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
|
@ -11,5 +11,6 @@ import java.util.List;
|
||||
// Paragraph hit-rate record for one chat request; ConversationServiceImpl serializes it to JSON and pushes it to a Redis list.
public class ParagraphHitRateListVO {
|
||||
// Conversation uuid copied from the chat request; used as the suffix of the Redis key.
private String uuid;
|
||||
// Group id copied from the chat request.
private String groupId;
|
||||
// Set to true when at least one hit-rate entry was collected, false otherwise.
private Boolean isExist;
|
||||
// Hit-rate entries (document name plus its paragraph hit rates).
private List<ParagraphHitRateWordVO> wordList;
|
||||
}
|
||||
|
@ -6,9 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
|
||||
import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
@ -43,6 +41,20 @@ public class KnowledgeBaseController {
|
||||
return success(knowledgeBaseService.createKnowledgeBase(createReqVO));
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行知识库命中测试
|
||||
*
|
||||
* @param testReqVO 命中测试请求参数
|
||||
* @return 命中测试结果
|
||||
*/
|
||||
@PostMapping("/hit-test")
|
||||
@Operation(summary = "执行知识库命中测试")
|
||||
public CommonResult<List<KnowledgeHitRateTestResultVO>> executeHitRateTest(
|
||||
@Valid @RequestBody KnowledgeHitRateTestReqVO testReqVO) {
|
||||
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
|
||||
return success(result);
|
||||
}
|
||||
|
||||
@PutMapping("/update")
|
||||
@Operation(summary = "更新知识库")
|
||||
// @PreAuthorize("@ss.hasPermission('llm:knowledge-base:update')")
|
||||
|
@ -52,4 +52,15 @@ public class KnowledgeBaseRespVO {
|
||||
@Schema(description = "文件引用上传")
|
||||
private List<KnowledgeDocumentsRespVO> knowledgeDocuments;
|
||||
|
||||
}
|
||||
/**
|
||||
* 分块大小
|
||||
*/
|
||||
@Schema(description = "分块大小")
|
||||
private Integer chunkSize;
|
||||
|
||||
/**
|
||||
* 分块重叠
|
||||
*/
|
||||
@Schema(description = "分块重叠,")
|
||||
private Integer chunkOverlap;
|
||||
}
|
||||
|
@ -35,8 +35,15 @@ public class KnowledgeBaseSaveReqVO {
|
||||
/**
|
||||
* 分块大小
|
||||
*/
|
||||
@Schema(description = "分块大小")
|
||||
private Integer chunkSize;
|
||||
|
||||
/**
|
||||
* 分块重叠
|
||||
*/
|
||||
@Schema(description = "分块重叠,")
|
||||
private Integer chunkOverlap;
|
||||
|
||||
@Schema(description = "文件引用")
|
||||
private String knowledgeFile;
|
||||
|
||||
|
@ -0,0 +1,34 @@
|
||||
package cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
/**
|
||||
* @Description Request parameters for a knowledge-base hit-rate test
|
||||
*/
|
||||
@Data
|
||||
public class KnowledgeHitRateTestReqVO {
|
||||
/**
|
||||
* Query text; must not be null
|
||||
*/
|
||||
@NotNull(message = "查询内容不能为空")
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* Knowledge base ID; must not be null
|
||||
*/
|
||||
@NotNull(message = "知识库ID不能为空")
|
||||
private Long knowledgeId;
|
||||
|
||||
/**
|
||||
* Number of results to return (the k value)
|
||||
*/
|
||||
// Validation is commented out, so k may be null: @NotNull(message = "k值不能为空")
|
||||
private Integer k;
|
||||
|
||||
/**
|
||||
* Score threshold; carries no validation annotation, so it may be null
|
||||
*/
|
||||
private Double score;
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* @Description Result entry returned by a knowledge-base hit-rate test
|
||||
*/
|
||||
@Data
|
||||
public class KnowledgeHitRateTestResultVO {
|
||||
/**
|
||||
* Page content of the matched chunk
|
||||
*/
|
||||
private String pageContent;
|
||||
|
||||
/**
|
||||
* Hit rate, stored as the string form of the numeric rate
|
||||
*/
|
||||
private String hitRate;
|
||||
|
||||
/**
|
||||
* Digest (summary) information taken from the document metadata
|
||||
*/
|
||||
private String digest;
|
||||
|
||||
/**
|
||||
* File ID
|
||||
*/
|
||||
private Long fileId;
|
||||
|
||||
/**
|
||||
* File name
|
||||
*/
|
||||
private String fileName;
|
||||
}
|
@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
|
||||
/**
|
||||
@ -45,7 +46,7 @@ public class KnowledgeBaseDO extends BaseDO {
|
||||
/**
|
||||
* Score阈值
|
||||
*/
|
||||
private Integer score;
|
||||
private Double score;
|
||||
/**
|
||||
* 知识长度
|
||||
*/
|
||||
@ -55,4 +56,13 @@ public class KnowledgeBaseDO extends BaseDO {
|
||||
*/
|
||||
private String knowledgeFile;
|
||||
|
||||
/**
|
||||
* 分块大小
|
||||
*/
|
||||
private Integer chunkSize;
|
||||
|
||||
/**
|
||||
* 分块重叠
|
||||
*/
|
||||
private Integer chunkOverlap;
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.llm.service.async;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
|
||||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumentsMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.KnowledgeStatusEnum;
|
||||
@ -9,17 +11,22 @@ import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties
|
||||
import cn.iocoder.yudao.module.llm.service.http.RagHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.KnowledgeRagEmbedReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.RegUploadReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URL;
|
||||
import java.math.RoundingMode;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
|
||||
@ -38,8 +45,8 @@ public class AsyncKnowledgeBase {
|
||||
|
||||
|
||||
// 向向量知识库创建文件
|
||||
// @Async
|
||||
public void createKnowledgeBase(List<KnowledgeDocumentsDO> knowledgeList, List<Long> ids) {
|
||||
// @Async
|
||||
public void createKnowledgeBase (List<KnowledgeDocumentsDO> knowledgeList, List<Long> ids, Map<String, Integer> knowledgeParameters) {
|
||||
log.info("开始执行 createKnowledgeBase 方法。knowledgeList 大小: {}, ids 大小: {}", knowledgeList.size(), ids.size());
|
||||
|
||||
// 如果提供了 ids,则删除现有的知识库文档
|
||||
@ -80,7 +87,7 @@ public class AsyncKnowledgeBase {
|
||||
if (lastIndex != -1) {
|
||||
String extension = knowledge.getDocumentName().substring(lastIndex + 1).toLowerCase();
|
||||
log.info("文档扩展名: {}", extension);
|
||||
knowledgeEmbed(knowledge, knowledge.getKnowledgeBaseId());
|
||||
knowledgeEmbed(knowledge, knowledge.getKnowledgeBaseId(), knowledgeParameters);
|
||||
} else {
|
||||
log.warn("文档无扩展名,跳过处理,文档 ID: {}", knowledge.getId());
|
||||
}
|
||||
@ -119,13 +126,15 @@ public class AsyncKnowledgeBase {
|
||||
* @param knowledge 文件
|
||||
* @param id 知识库id
|
||||
*/
|
||||
public void knowledgeEmbed (KnowledgeDocumentsDO knowledge, Long id) {
|
||||
public void knowledgeEmbed (KnowledgeDocumentsDO knowledge, Long id, Map<String, Integer> knowledgeParameters) {
|
||||
|
||||
// 创建知识向量
|
||||
KnowledgeRagEmbedReqVO ragEmbedReqVo = new KnowledgeRagEmbedReqVO()
|
||||
.setFileId(String.valueOf(knowledge.getId()))
|
||||
.setFileName(knowledge.getDocumentName())
|
||||
.setFileUrl(knowledge.getFileUrl());
|
||||
.setFileUrl(knowledge.getFileUrl())
|
||||
.setChunkSize(knowledgeParameters.get("chunkSize"))
|
||||
.setChunkOverlap(knowledgeParameters.get("chunkOverlap"));
|
||||
|
||||
try {
|
||||
ragHttpService.knowledgeEmbed(ragEmbedReqVo, id);
|
||||
@ -134,4 +143,42 @@ public class AsyncKnowledgeBase {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO , List<Long> fileIds) {
|
||||
List<String> fileIdStr = fileIds.stream()
|
||||
.map(Object::toString)
|
||||
.collect(Collectors.toList());
|
||||
QueryMultipleReqVO vo = new QueryMultipleReqVO();
|
||||
vo.setQuery(testReqVO.getQuery());
|
||||
vo.setFileIds(fileIdStr);
|
||||
vo.setK(testReqVO.getK());
|
||||
vo.setScore(testReqVO.getScore());
|
||||
|
||||
List<KnowledgeHitRateTestResultVO> resultList = new ArrayList<>();
|
||||
|
||||
List<QueryResultPairVO> result = ragHttpService.executeHitRateTest(vo);
|
||||
for (QueryResultPairVO pair : result) {
|
||||
KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
|
||||
resultVO.setPageContent(pair.getDocument().getPageContent());
|
||||
|
||||
// DecimalFormat df = new DecimalFormat("0.00%");
|
||||
// df.setRoundingMode(RoundingMode.HALF_UP);
|
||||
// String rateResult = df.format(pair.getHitRate());
|
||||
resultVO.setHitRate(String.valueOf(pair.getHitRate()));
|
||||
resultVO.setDigest(pair.getDocument().getMetadata().getDigest());
|
||||
long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
|
||||
resultVO.setFileId(fileId);
|
||||
|
||||
// 根据 fileId 查找文件名
|
||||
KnowledgeDocumentsDO documents = knowledgeDocumentsMapper.selectOne(KnowledgeDocumentsDO::getFileId, fileId);
|
||||
if (documents!=null && StringUtils.isNotBlank(documents.getDocumentName())){
|
||||
resultVO.setFileName(documents.getDocumentName());
|
||||
}else {
|
||||
resultVO.setFileName("未知文件");
|
||||
}
|
||||
|
||||
resultList.add(resultVO);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSa
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
|
||||
@ -26,6 +29,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.BaseModelService;
|
||||
import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.ModelService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
|
||||
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
|
||||
import com.alibaba.excel.util.StringUtils;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
@ -45,6 +49,7 @@ import javax.servlet.http.HttpServletResponse;
|
||||
import java.math.RoundingMode;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
|
||||
@ -80,6 +85,9 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
@Resource
|
||||
private LLMBackendProperties llmBackendProperties;
|
||||
|
||||
@Resource
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
// 聊天会话历史记录缓存Key
|
||||
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
|
||||
// 聊天会话历史记录缓存时间
|
||||
@ -439,14 +447,16 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
// 处理 knowledgeBaseString
|
||||
if (StringUtils.isNotBlank(knowledgeBaseString)) {
|
||||
knowledgeBaseString = "<context>" + knowledgeBaseString + "</context>";
|
||||
}else {
|
||||
knowledgeBaseString = "<context>" + "</context>";
|
||||
}
|
||||
|
||||
// 处理 systemPrompt
|
||||
systemPrompt = StringUtils.isBlank(chatReqVO.getSystemPrompt())
|
||||
? PROMPT
|
||||
: chatReqVO.getSystemPrompt() + "\n" + PROMPT;
|
||||
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
|
||||
}
|
||||
String mess = systemPrompt + knowledgeBaseString;
|
||||
String mess = systemPrompt + " \n "+knowledgeBaseString;
|
||||
|
||||
// // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
@ -532,30 +542,14 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO();
|
||||
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
|
||||
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
|
||||
List<ParagraphHitRateWordVO> words = new ArrayList<>();
|
||||
// 2. 遍历处理每个文档
|
||||
for (KnowledgeDocumentsDO document : documentList) {
|
||||
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
|
||||
if (rateWordVO != null) {
|
||||
words.add(rateWordVO);
|
||||
}
|
||||
|
||||
}
|
||||
if (CollectionUtils.isEmpty(words)) {
|
||||
paragraphHitRateListVO.setWordList(Collections.emptyList());
|
||||
paragraphHitRateListVO.setGroupId("");
|
||||
} else {
|
||||
paragraphHitRateListVO.setWordList(words);
|
||||
}
|
||||
KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
|
||||
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
|
||||
testReqVO.setQuery(chatReqVO.getPrompt());
|
||||
|
||||
// 请求结果添加到 Redis,查询段落命中率
|
||||
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid());
|
||||
stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
|
||||
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
|
||||
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
|
||||
|
||||
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
|
||||
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
|
||||
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
|
||||
}
|
||||
|
||||
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
|
||||
} catch (Exception e) {
|
||||
@ -567,6 +561,74 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
return knowledgeBase;
|
||||
}
|
||||
|
||||
private StringBuilder handlerResult (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)){
|
||||
return new StringBuilder();
|
||||
}
|
||||
|
||||
// 1: 存储到redis
|
||||
saveRedis(result, paragraphHitRateListVO);
|
||||
|
||||
// 2: 组成返回数据
|
||||
StringBuilder knowledgeBase = new StringBuilder();
|
||||
result.forEach(item -> {
|
||||
knowledgeBase.append(item.getPageContent());
|
||||
});
|
||||
return knowledgeBase;
|
||||
}
|
||||
|
||||
private void saveRedis (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)){
|
||||
return;
|
||||
}
|
||||
List<ParagraphHitRateWordVO> words = new ArrayList<>();
|
||||
|
||||
// 按照fileId分组,存到Map中
|
||||
Map<Long, List<KnowledgeHitRateTestResultVO>> groupedByFileId = result.stream()
|
||||
.collect(Collectors.groupingBy(KnowledgeHitRateTestResultVO::getFileId));
|
||||
|
||||
// 遍历Map,查看分组结果
|
||||
groupedByFileId.forEach((fileId, list) -> {
|
||||
System.out.println("File ID: " + fileId);
|
||||
list.forEach(i->{
|
||||
ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
|
||||
// 设置文档名称
|
||||
rateWordVO.setDocumentName(i.getFileName());
|
||||
|
||||
// 设置段落命中率
|
||||
List<ParagraphHitRateVO> paragraphHitRate=new ArrayList<>();
|
||||
for (KnowledgeHitRateTestResultVO i1 : list) {
|
||||
ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
|
||||
rateVO.setParagraph(i1.getPageContent());
|
||||
rateVO.setHitRate(i1.getHitRate());
|
||||
rateVO.setWordCount(i1.getPageContent().length());
|
||||
paragraphHitRate.add(rateVO);
|
||||
}
|
||||
|
||||
rateWordVO.setParagraphHitRate(paragraphHitRate);
|
||||
|
||||
words.add(rateWordVO);
|
||||
});
|
||||
});
|
||||
|
||||
if (CollectionUtils.isEmpty(words)) {
|
||||
paragraphHitRateListVO.setWordList(Collections.emptyList());
|
||||
paragraphHitRateListVO.setIsExist(false);
|
||||
} else {
|
||||
paragraphHitRateListVO.setWordList(words);
|
||||
paragraphHitRateListVO.setIsExist(true);
|
||||
}
|
||||
|
||||
// 请求结果添加到 Redis,查询段落命中率
|
||||
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, paragraphHitRateListVO.getUuid());
|
||||
stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
|
||||
|
||||
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
|
||||
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
|
||||
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理单个知识库文档的检索逻辑
|
||||
*/
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http;
|
||||
|
||||
|
||||
import cn.hutool.http.Header;
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
|
||||
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
|
||||
@ -11,6 +12,9 @@ import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumen
|
||||
import cn.iocoder.yudao.module.llm.enums.KnowledgeStatusEnum;
|
||||
import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.DocumentInfoVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONException;
|
||||
@ -49,6 +53,7 @@ import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -390,9 +395,10 @@ public class RagHttpService {
|
||||
String fileName = reqVO.getFileName();
|
||||
String fileUrl = reqVO.getFileUrl();
|
||||
Integer chunkSize = Optional.ofNullable(reqVO.getChunkSize()).orElse(1500);
|
||||
Integer chunkOverlap = Optional.ofNullable(reqVO.getChunkOverlap()).orElse(300);
|
||||
String mediaType = getMediaType(fileName);
|
||||
|
||||
log.info("文件ID: {}, 文件名: {}, 文件URL: {}, 文件类型: {}, 分块大小:{}", fileId, fileName, fileUrl, mediaType,chunkSize);
|
||||
log.info("文件ID: {}, 文件名: {}, 文件URL: {}, 文件类型: {}, 分块大小:{}, 分块重叠:{}", fileId, fileName, fileUrl, mediaType, chunkSize, chunkOverlap);
|
||||
|
||||
// 获取知识库文档
|
||||
log.info("开始获取知识库文档,知识库ID: {}, 文件ID: {}", id, fileId);
|
||||
@ -417,15 +423,15 @@ public class RagHttpService {
|
||||
Path tempFilePath = downloadFileToTemp(fileUrl, fileName);
|
||||
log.info("文件已下载到临时目录: {}", tempFilePath);
|
||||
|
||||
// String fileSuffix = getFileSuffix(fileName);
|
||||
// if ("doc".equals(fileSuffix)) {
|
||||
// log.info("正在处理 doc 文件");
|
||||
// try {
|
||||
// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
|
||||
// } catch (Exception e) {
|
||||
// throw new RuntimeException(e);
|
||||
// }
|
||||
// }
|
||||
// String fileSuffix = getFileSuffix(fileName);
|
||||
// if ("doc".equals(fileSuffix)) {
|
||||
// log.info("正在处理 doc 文件");
|
||||
// try {
|
||||
// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
|
||||
// } catch (Exception e) {
|
||||
// throw new RuntimeException(e);
|
||||
// }
|
||||
// }
|
||||
|
||||
// if ("md".equals(fileSuffix)) {
|
||||
// log.info("正在处理 md 文件");
|
||||
@ -450,6 +456,7 @@ public class RagHttpService {
|
||||
.setType(MultipartBody.FORM)
|
||||
.addFormDataPart("file_id", fileId)
|
||||
.addFormDataPart("chunk_size", String.valueOf(chunkSize))
|
||||
.addFormDataPart("chunk_overlap", String.valueOf(chunkOverlap))
|
||||
.addFormDataPart("file", fileName,
|
||||
RequestBody.create(tempFilePath.toFile(), MediaType.parse(mediaType))
|
||||
)
|
||||
@ -728,4 +735,101 @@ public class RagHttpService {
|
||||
private KnowledgeDocumentsDO getKnowledgeDocuments (String fileId) {
|
||||
return knowledgeDocumentsMapper.selectById(fileId);
|
||||
}
|
||||
|
||||
public List<QueryResultPairVO> executeHitRateTest (QueryMultipleReqVO vo) {
|
||||
|
||||
String jsonString = JSON.toJSONString(vo);
|
||||
String url = llmBackendProperties.getRagQueryMultiple();
|
||||
//链式构建请求
|
||||
String result2 = HttpRequest.post(url)
|
||||
.header(Header.ACCEPT, "application/json")
|
||||
.header(Header.CONTENT_TYPE, "application/json")
|
||||
.body(jsonString)
|
||||
.timeout(20000)
|
||||
.execute().body();
|
||||
cn.hutool.core.lang.Console.log(result2);
|
||||
|
||||
log.info("请求参数: {}",jsonString);
|
||||
log.info("请求结果: {}",JSON.toJSONString(result2));
|
||||
return parseHitRateTestResults(result2,vo.getScore());
|
||||
}
|
||||
|
||||
private static List<QueryResultPairVO> parseHitRateTestResults (String json, Double score) {
|
||||
boolean array= json.trim().startsWith("[");
|
||||
// 先判断 JSON 是否是一个数组
|
||||
|
||||
if (!array){
|
||||
// 判断是否存在 detail 字段
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
if (jsonObject.containsKey("detail")) {
|
||||
String detail = jsonObject.getString("detail");
|
||||
|
||||
if (detail.contains("No documents found for the given query")) {
|
||||
throw exception(new ErrorCode(100_100_1, "未找到符合条件的文档,请检查查询条件!"));
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 将 JSON 转换为 List<QueryResultPair>
|
||||
// 解析 JSON 数组
|
||||
JSONArray jsonArray = JSON.parseArray(json);
|
||||
|
||||
// 创建结果列表
|
||||
List<QueryResultPairVO> results = new ArrayList<>();
|
||||
|
||||
// 遍历 JSON 数组
|
||||
for (int i = 0; i < jsonArray.size(); i++) {
|
||||
JSONArray pairArray = jsonArray.getJSONArray(i);
|
||||
|
||||
// 解析文档信息
|
||||
JSONObject documentJson = pairArray.getJSONObject(0);
|
||||
DocumentInfoVO document = JSON.parseObject(documentJson.toJSONString(), DocumentInfoVO.class);
|
||||
|
||||
// 解析命中率
|
||||
Double rate = pairArray.getDoubleValue(1);
|
||||
|
||||
if (rate >= score) {
|
||||
QueryResultPairVO pair = new QueryResultPairVO();
|
||||
pair.setDocument(document);
|
||||
pair.setHitRate(rate);
|
||||
results.add(pair);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// // 访问数据
|
||||
// for (QueryResultPairVO pair : results) {
|
||||
// System.out.println("Page Content: " + pair.getDocument().getPageContent());
|
||||
// System.out.println("Hit Rate: " + pair.getHitRate());
|
||||
// System.out.println("File ID: " + pair.getDocument().getMetadata().getFileId());
|
||||
// System.out.println("----------------------");
|
||||
// }
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
public static void main (String[] args) {
|
||||
List<String> ids = new ArrayList<>();
|
||||
ids.add("1111");
|
||||
ids.add("1234");
|
||||
QueryMultipleReqVO vo = new QueryMultipleReqVO();
|
||||
vo.setQuery("可乐鸡翅怎么做");
|
||||
vo.setFileIds(ids);
|
||||
vo.setK(4);
|
||||
String jsonString = JSON.toJSONString(vo);
|
||||
String url = "http://192.168.18.66:8123/query_multiple";
|
||||
//链式构建请求
|
||||
String result2 = HttpRequest.post(url)
|
||||
.header(Header.ACCEPT, "application/json")
|
||||
.header(Header.CONTENT_TYPE, "application/json")
|
||||
.body(jsonString)
|
||||
.timeout(20000)
|
||||
.execute().body();
|
||||
cn.hutool.core.lang.Console.log(result2);
|
||||
// extracted(result2);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -46,4 +46,9 @@ public class KnowledgeRagEmbedReqVO {
|
||||
* 分块大小
|
||||
*/
|
||||
private Integer chunkSize;
|
||||
|
||||
/**
|
||||
* 分块重叠
|
||||
*/
|
||||
private Integer chunkOverlap;
|
||||
}
|
||||
|
@ -0,0 +1,29 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* @Description Document information returned inside each query result pair
|
||||
*/
|
||||
@Data
|
||||
public class DocumentInfoVO {
|
||||
/**
|
||||
* Document ID (may be null)
|
||||
*/
|
||||
private String id;
|
||||
|
||||
/**
|
||||
* Metadata
|
||||
*/
|
||||
private MetadataVO metadata;
|
||||
|
||||
/**
|
||||
* Page content
|
||||
*/
|
||||
private String pageContent;
|
||||
|
||||
/**
|
||||
* Document type
|
||||
*/
|
||||
private String type;
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* @Description Document metadata
|
||||
*/
|
||||
@Data
|
||||
public class MetadataVO {
|
||||
/**
|
||||
* File ID (kept as a string; callers parse it to a long)
|
||||
*/
|
||||
private String fileId;
|
||||
|
||||
/**
|
||||
* User ID
|
||||
*/
|
||||
private String userId;
|
||||
|
||||
/**
|
||||
* File digest
|
||||
*/
|
||||
private String digest;
|
||||
|
||||
/**
|
||||
* Source path of the file
|
||||
*/
|
||||
private String source;
|
||||
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
|
||||
|
||||
import com.alibaba.fastjson.annotation.JSONField;
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @Description Multi-file knowledge-base query request
|
||||
*/
|
||||
@Data
|
||||
public class QueryMultipleReqVO {
|
||||
/**
|
||||
* Query text
|
||||
*/
|
||||
@NotNull(message = "查询内容不能为空")
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* List of file IDs; serialized by fastjson under the name "file_ids"
|
||||
*/
|
||||
@NotNull(message = "文件ID列表不能为空")
|
||||
@JSONField(name = "file_ids")
|
||||
private List<String> fileIds;
|
||||
|
||||
/**
|
||||
* Number of results to return (the k value)
|
||||
*/
|
||||
// Validation is commented out, so k may be null: @NotNull(message = "k值不能为空")
|
||||
private Integer k;
|
||||
|
||||
// Score threshold; RagHttpService passes it to the response parser, which keeps only pairs whose hit rate reaches it.
private Double score;
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* @Description Query result pair (document information plus its hit rate)
|
||||
*/
|
||||
@Data
|
||||
public class QueryResultPairVO {
|
||||
/**
|
||||
* Document information
|
||||
*/
|
||||
private DocumentInfoVO document;
|
||||
|
||||
/**
|
||||
* Hit rate
|
||||
*/
|
||||
private Double hitRate;
|
||||
}
|
@ -1,9 +1,7 @@
|
||||
package cn.iocoder.yudao.module.llm.service.knowledgebase;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
|
||||
|
||||
import javax.validation.Valid;
|
||||
@ -70,4 +68,11 @@ public interface KnowledgeBaseService {
|
||||
* @param updateReqVO 更新信息
|
||||
*/
|
||||
void updateKnowledgeBaseInfo (@Valid KnowledgeBaseSaveReqVO updateReqVO);
|
||||
|
||||
/**
|
||||
* Runs a knowledge-base hit test
|
||||
* @param testReqVO hit-test request (validated via @Valid)
|
||||
* @return the list of hit-test results
|
||||
*/
|
||||
List<KnowledgeHitRateTestResultVO> executeHitRateTest (@Valid KnowledgeHitRateTestReqVO testReqVO);
|
||||
}
|
||||
|
@ -5,9 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgedocuments.vo.KnowledgeDocumentsRespVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.knowledgedocuments.vo.KnowledgeDocumentsSaveReqVO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
|
||||
@ -17,24 +15,17 @@ import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumen
|
||||
import cn.iocoder.yudao.module.llm.service.application.ApplicationService;
|
||||
import cn.iocoder.yudao.module.llm.service.async.AsyncKnowledgeBase;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import kong.unirest.Unirest;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.annotation.Tainted;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.KNOWLEDGE_BASE_NAME_NOT_EXISTS;
|
||||
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.KNOWLEDGE_BASE_NOT_EXISTS;
|
||||
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
|
||||
|
||||
/**
|
||||
* 知识库 Service 实现类
|
||||
@ -66,68 +57,139 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
}
|
||||
|
||||
@Override
|
||||
// @Transactional(rollbackFor = Exception.class)
|
||||
// @Transactional(rollbackFor = Exception.class)
|
||||
public void updateKnowledgeBase (KnowledgeBaseSaveReqVO updateReqVO) {
|
||||
// 1. 校验知识库是否存在
|
||||
validateKnowledgeParam(updateReqVO);
|
||||
|
||||
// 2. 更新知识库主表基础信息
|
||||
KnowledgeBaseDO updateObj = BeanUtils.toBean(updateReqVO, KnowledgeBaseDO.class);
|
||||
knowledgeBaseMapper.updateById(updateObj);
|
||||
|
||||
// 3. 处理附表(知识文档)数据
|
||||
handleKnowledgeDocuments(updateReqVO, updateObj);
|
||||
}
|
||||
|
||||
/**
|
||||
* 校验知识库参数
|
||||
*
|
||||
* @param updateReqVO 更新知识库参数
|
||||
*/
|
||||
private void validateKnowledgeParam (KnowledgeBaseSaveReqVO updateReqVO) {
|
||||
// 1. 校验知识库是否存在
|
||||
validateKnowledgeBaseExists(updateReqVO.getId());
|
||||
|
||||
// 2. 校验知识库名称是否重复
|
||||
validateKnowledgeBaseNameExists(updateReqVO);
|
||||
|
||||
// 3. 更新知识库主表
|
||||
KnowledgeBaseDO updateObj = BeanUtils.toBean(updateReqVO, KnowledgeBaseDO.class);
|
||||
knowledgeBaseMapper.updateById(updateObj);
|
||||
// 3. 校验分块大小和分块重叠是否正确
|
||||
validateChunkParameters(updateReqVO.getChunkSize(), updateReqVO.getChunkOverlap());
|
||||
}
|
||||
|
||||
// Unirest.config().reset();
|
||||
// Unirest.config()
|
||||
// .socketTimeout(86400000)
|
||||
// .connectTimeout(100000)
|
||||
// .concurrency(10, 5)
|
||||
// .setDefaultHeader("Accept", "application/json");
|
||||
/**
|
||||
* 校验分块大小和分块重叠是否合法
|
||||
*
|
||||
* @param chunkSize 分块大小
|
||||
* @param chunkOverlap 分块重叠
|
||||
* @throws IllegalArgumentException 如果校验不通过
|
||||
*/
|
||||
private void validateChunkParameters (int chunkSize, int chunkOverlap) {
|
||||
if (chunkSize < 1) {
|
||||
throw exception(CHUNK_SIZE_MUST_BE_GREATER_THAN_ZERO);
|
||||
}
|
||||
if (chunkOverlap < 0) {
|
||||
throw exception(CHUNK_OVERLAP_MUST_BE_GREATER_THAN_OR_EQUAL_TO_ZERO);
|
||||
}
|
||||
if (chunkOverlap >= chunkSize) {
|
||||
throw exception(CHUNK_OVERLAP_MUST_BE_LESS_THAN_CHUNK_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 处理附表(知识文档)数据
|
||||
if (!CollectionUtils.isAnyEmpty(updateReqVO.getKnowledgeDocuments())) {
|
||||
// 4.1 获取需要保留的文档 ID
|
||||
List<Long> retainedIds = updateReqVO.getKnowledgeDocuments().stream()
|
||||
.map(KnowledgeDocumentsSaveReqVO::getId)
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// 4.2 删除不需要保留的文档
|
||||
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId());
|
||||
if (!CollectionUtils.isAnyEmpty(retainedIds)) {
|
||||
deleteWrapper.notIn(KnowledgeDocumentsDO::getId, retainedIds);
|
||||
}
|
||||
knowledgeDocumentsMapper.delete(deleteWrapper);
|
||||
|
||||
// 4.3 更新或插入文档数据
|
||||
List<KnowledgeDocumentsDO> newDocuments = new ArrayList<>();
|
||||
updateReqVO.getKnowledgeDocuments().forEach(doc -> {
|
||||
KnowledgeDocumentsDO docDO = BeanUtils.toBean(doc, KnowledgeDocumentsDO.class);
|
||||
docDO.setKnowledgeBaseId(updateReqVO.getId());
|
||||
docDO.setChunkSize(updateObj.getKnowledgeLength());
|
||||
if (doc.getId() == null) {
|
||||
newDocuments.add(docDO); // 收集新增文档
|
||||
}
|
||||
knowledgeDocumentsMapper.insertOrUpdate(docDO); // 更新或插入文档
|
||||
});
|
||||
|
||||
// 4.4 异步处理新增文档和删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
|
||||
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds);
|
||||
} else {
|
||||
// 5. 如果传入的文档列表为空,则删除所有关联文档
|
||||
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId()));
|
||||
|
||||
// 5.1 异步处理删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
|
||||
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
|
||||
asyncKnowledgeBase.createKnowledgeBase(null, deleteIds);
|
||||
}
|
||||
/**
|
||||
* 处理知识文档数据
|
||||
*
|
||||
* @param updateReqVO 更新知识库参数
|
||||
* @param updateObj 更新知识库对象
|
||||
*/
|
||||
private void handleKnowledgeDocuments (KnowledgeBaseSaveReqVO updateReqVO, KnowledgeBaseDO updateObj) {
|
||||
List<KnowledgeDocumentsSaveReqVO> documents = updateReqVO.getKnowledgeDocuments();
|
||||
if (CollectionUtils.isAnyEmpty(documents)) {
|
||||
// 如果传入的文档列表为空,则删除所有关联文档
|
||||
deleteAllDocuments(updateReqVO.getId());
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取需要保留的文档 ID
|
||||
List<Long> retainedIds = documents.stream()
|
||||
.map(KnowledgeDocumentsSaveReqVO::getId)
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// 删除不需要保留的文档
|
||||
deleteUnretainedDocuments(updateReqVO.getId(), retainedIds);
|
||||
|
||||
// 更新或插入文档数据
|
||||
List<KnowledgeDocumentsDO> newDocuments = updateOrInsertDocuments(documents, updateReqVO.getId(), updateObj.getKnowledgeLength());
|
||||
|
||||
Map<String, Integer> knowledgeParameters = new HashMap<>();
|
||||
knowledgeParameters.put("chunkSize", updateReqVO.getChunkSize());
|
||||
knowledgeParameters.put("chunkOverlap", updateReqVO.getChunkOverlap());
|
||||
|
||||
// 异步处理新增文档和删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
|
||||
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds, knowledgeParameters);
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除所有关联的文档
|
||||
*
|
||||
* @param knowledgeBaseId 知识库 ID
|
||||
*/
|
||||
private void deleteAllDocuments (Long knowledgeBaseId) {
|
||||
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeBaseId));
|
||||
|
||||
// 异步处理删除的文档
|
||||
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId);
|
||||
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
|
||||
asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds, new HashMap<>());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除不需要保留的文档
|
||||
*
|
||||
* @param knowledgeBaseId 知识库 ID
|
||||
* @param retainedIds 需要保留的文档 ID
|
||||
*/
|
||||
private void deleteUnretainedDocuments (Long knowledgeBaseId, List<Long> retainedIds) {
|
||||
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeBaseId);
|
||||
if (!CollectionUtils.isAnyEmpty(retainedIds)) {
|
||||
deleteWrapper.notIn(KnowledgeDocumentsDO::getId, retainedIds);
|
||||
}
|
||||
knowledgeDocumentsMapper.delete(deleteWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新或插入文档数据
|
||||
*
|
||||
* @param documents 需要更新的文档数据
|
||||
* @param knowledgeBaseId 知识库 ID
|
||||
* @param chunkSize
|
||||
* @return 更新或插入的文档数据
|
||||
*/
|
||||
private List<KnowledgeDocumentsDO> updateOrInsertDocuments (List<KnowledgeDocumentsSaveReqVO> documents, Long knowledgeBaseId, Integer chunkSize) {
|
||||
List<KnowledgeDocumentsDO> newDocuments = new ArrayList<>();
|
||||
documents.forEach(doc -> {
|
||||
KnowledgeDocumentsDO docDO = BeanUtils.toBean(doc, KnowledgeDocumentsDO.class);
|
||||
docDO.setKnowledgeBaseId(knowledgeBaseId);
|
||||
if (doc.getId() == null) {
|
||||
newDocuments.add(docDO); // 收集新增文档
|
||||
}
|
||||
knowledgeDocumentsMapper.insertOrUpdate(docDO); // 更新或插入文档
|
||||
});
|
||||
return newDocuments;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -236,6 +298,57 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
|
||||
knowledgeBaseMapper.updateById(knowledgeBaseDO);
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行知识库命中测试
|
||||
*
|
||||
* @param testReqVO 测试信息
|
||||
* @return 返回结果
|
||||
*/
|
||||
@Override
|
||||
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO) {
|
||||
Long knowledgeId = testReqVO.getKnowledgeId();
|
||||
|
||||
KnowledgeBaseDO baseDO = knowledgeBaseMapper.selectOne(KnowledgeBaseDO::getId, knowledgeId);
|
||||
if (baseDO == null) {
|
||||
throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
|
||||
}
|
||||
Integer topK = 4;
|
||||
if (baseDO.getTopK() == null || baseDO.getTopK() <= 0) {
|
||||
testReqVO.setK(topK);
|
||||
} else {
|
||||
topK = baseDO.getTopK();
|
||||
testReqVO.setK(topK);
|
||||
}
|
||||
|
||||
Double score = 0.2;
|
||||
if (baseDO.getScore() == null || baseDO.getTopK() <= 0.0|| baseDO.getScore() > 1) {
|
||||
testReqVO.setScore(score);
|
||||
} else {
|
||||
score = baseDO.getScore();
|
||||
testReqVO.setScore(score);
|
||||
}
|
||||
|
||||
// 根据知识库ID获取参数信息,关联文档
|
||||
List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
|
||||
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));
|
||||
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(documentsDOS)) {
|
||||
throw exception(KNOWLEDGE_DOCUMENTS_NOT_EXISTS);
|
||||
}
|
||||
|
||||
// 获取fileId列表
|
||||
List<Long> fileIds = documentsDOS.stream()
|
||||
.map(KnowledgeDocumentsDO::getFileId)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO, fileIds);
|
||||
|
||||
if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 校验知识库是否存在
|
||||
*
|
||||
|
Loading…
x
Reference in New Issue
Block a user