Merge remote-tracking branch 'origin/master'

This commit is contained in:
leon 2025-03-13 17:07:06 +08:00
commit e513685b87
19 changed files with 678 additions and 134 deletions

View File

@ -103,6 +103,12 @@ public interface ErrorCodeConstants {
ErrorCode KNOWLEDGE_BASE_NAME_NOT_EXISTS = new ErrorCode(10040, "知识库名称已存在");
ErrorCode CHUNK_SIZE_MUST_BE_GREATER_THAN_ZERO = new ErrorCode(10040_1, "分块大小必须大于 0");
ErrorCode CHUNK_OVERLAP_MUST_BE_GREATER_THAN_OR_EQUAL_TO_ZERO = new ErrorCode(10040_2, "分块重叠必须大于或等于 0");
ErrorCode CHUNK_OVERLAP_MUST_BE_LESS_THAN_CHUNK_SIZE = new ErrorCode(10040_3, "分块重叠必须小于分块大小");
ErrorCode APPLICATION_NAME_NOT_EXISTS = new ErrorCode(10041, "应用中心名称已存在");
ErrorCode MODEL_SERVIC_ENAME_NOT_EXISTS = new ErrorCode(10043, "模型名称已存在");

View File

@ -109,27 +109,8 @@ public class ConversationController {
@PostMapping("/stream-chat")
public SseEmitter streamChat (@Valid @RequestBody ChatReqVO chatReqVO, HttpServletResponse response) {
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
SseEmitter emitter = new SseEmitter(120_000L);
// ExecutorService executor = Executors.newSingleThreadExecutor();
// try {
// executor.execute(() -> {
// try {
// conversationService.chatStream(chatReqVO, emitter, response);
// } catch (Exception e) {
// emitter.completeWithError(e);
// } finally {
// executor.shutdown();
// }
// });
// } catch (Exception e) {
// log.error("处理对话推理请求时发生异常", e);
// try {
// emitter.completeWithError(e);
// } catch (Exception ex) {
// log.error("无法完成 SseEmitter 错误处理", ex);
// }
// }
// log.info("返回 SseEmitter 对象,准备进行流式响应");
SseEmitter emitter = new SseEmitter(120_0000L);
// 异步处理避免阻塞主线程
CompletableFuture.runAsync(() -> {
try {

View File

@ -11,5 +11,6 @@ import java.util.List;
public class ParagraphHitRateListVO {
private String uuid;
private String groupId;
private Boolean isExist;
private List<ParagraphHitRateWordVO> wordList;
}

View File

@ -6,9 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
import io.swagger.v3.oas.annotations.Operation;
@ -43,6 +41,20 @@ public class KnowledgeBaseController {
return success(knowledgeBaseService.createKnowledgeBase(createReqVO));
}
/**
 * Runs a hit-rate test against a knowledge base.
 *
 * @param testReqVO hit-test request parameters
 * @return the hit-test results wrapped in a success response
 */
@PostMapping("/hit-test")
@Operation(summary = "执行知识库命中测试")
public CommonResult<List<KnowledgeHitRateTestResultVO>> executeHitRateTest(
        @Valid @RequestBody KnowledgeHitRateTestReqVO testReqVO) {
    // Delegate to the service and wrap its answer directly.
    return success(knowledgeBaseService.executeHitRateTest(testReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新知识库")
// @PreAuthorize("@ss.hasPermission('llm:knowledge-base:update')")

View File

@ -52,4 +52,15 @@ public class KnowledgeBaseRespVO {
@Schema(description = "文件引用上传")
private List<KnowledgeDocumentsRespVO> knowledgeDocuments;
}
/**
* 分块大小
*/
@Schema(description = "分块大小")
private Integer chunkSize;
/**
* 分块重叠
*/
@Schema(description = "分块重叠,")
private Integer chunkOverlap;
}

View File

@ -35,8 +35,15 @@ public class KnowledgeBaseSaveReqVO {
/**
* 分块大小
*/
@Schema(description = "分块大小")
private Integer chunkSize;
/**
* 分块重叠
*/
@Schema(description = "分块重叠,")
private Integer chunkOverlap;
@Schema(description = "文件引用")
private String knowledgeFile;

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo;
import lombok.Data;
import javax.validation.constraints.NotNull;
/**
* Request parameters for a knowledge-base hit-rate test.
*/
@Data
public class KnowledgeHitRateTestReqVO {
/**
* Query text to search for. Must not be null.
*/
@NotNull(message = "查询内容不能为空")
private String query;
/**
* ID of the knowledge base to test. Must not be null.
*/
@NotNull(message = "知识库ID不能为空")
private Long knowledgeId;
/**
* Number of results to return (the "k" value). No validation constraint is active on this field.
*/
// Validation currently disabled: @NotNull(message = "k值不能为空")
private Integer k;
/**
* Score threshold.
*/
private Double score;
}

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo;
import lombok.Data;
/**
* One result row of a knowledge-base hit-rate test.
*/
@Data
public class KnowledgeHitRateTestResultVO {
/**
* Page content of the matched document chunk.
*/
private String pageContent;
/**
* Hit rate, carried as a string.
*/
private String hitRate;
/**
* Digest (summary) information.
*/
private String digest;
/**
* File ID.
*/
private Long fileId;
/**
* File name.
*/
private String fileName;
}

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
/**
@ -45,7 +46,7 @@ public class KnowledgeBaseDO extends BaseDO {
/**
* Score阈值
*/
private Integer score;
private Double score;
/**
* 知识长度
*/
@ -55,4 +56,13 @@ public class KnowledgeBaseDO extends BaseDO {
*/
private String knowledgeFile;
/**
* 分块大小
*/
private Integer chunkSize;
/**
* 分块重叠
*/
private Integer chunkOverlap;
}

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.llm.service.async;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumentsMapper;
import cn.iocoder.yudao.module.llm.enums.KnowledgeStatusEnum;
@ -9,17 +11,22 @@ import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties
import cn.iocoder.yudao.module.llm.service.http.RagHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.KnowledgeRagEmbedReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.RegUploadReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -38,8 +45,8 @@ public class AsyncKnowledgeBase {
// 向向量知识库创建文件
// @Async
public void createKnowledgeBase(List<KnowledgeDocumentsDO> knowledgeList, List<Long> ids) {
// @Async
public void createKnowledgeBase (List<KnowledgeDocumentsDO> knowledgeList, List<Long> ids, Map<String, Integer> knowledgeParameters) {
log.info("开始执行 createKnowledgeBase 方法。knowledgeList 大小: {}, ids 大小: {}", knowledgeList.size(), ids.size());
// 如果提供了 ids则删除现有的知识库文档
@ -80,7 +87,7 @@ public class AsyncKnowledgeBase {
if (lastIndex != -1) {
String extension = knowledge.getDocumentName().substring(lastIndex + 1).toLowerCase();
log.info("文档扩展名: {}", extension);
knowledgeEmbed(knowledge, knowledge.getKnowledgeBaseId());
knowledgeEmbed(knowledge, knowledge.getKnowledgeBaseId(), knowledgeParameters);
} else {
log.warn("文档无扩展名,跳过处理,文档 ID: {}", knowledge.getId());
}
@ -119,13 +126,15 @@ public class AsyncKnowledgeBase {
* @param knowledge 文件
* @param id 知识库id
*/
public void knowledgeEmbed (KnowledgeDocumentsDO knowledge, Long id) {
public void knowledgeEmbed (KnowledgeDocumentsDO knowledge, Long id, Map<String, Integer> knowledgeParameters) {
// 创建知识向量
KnowledgeRagEmbedReqVO ragEmbedReqVo = new KnowledgeRagEmbedReqVO()
.setFileId(String.valueOf(knowledge.getId()))
.setFileName(knowledge.getDocumentName())
.setFileUrl(knowledge.getFileUrl());
.setFileUrl(knowledge.getFileUrl())
.setChunkSize(knowledgeParameters.get("chunkSize"))
.setChunkOverlap(knowledgeParameters.get("chunkOverlap"));
try {
ragHttpService.knowledgeEmbed(ragEmbedReqVo, id);
@ -134,4 +143,42 @@ public class AsyncKnowledgeBase {
}
}
/**
 * Runs a hit-rate query against the RAG service and converts every returned
 * (document, hit rate) pair into a {@link KnowledgeHitRateTestResultVO}.
 *
 * <p>File names are resolved from the database once per distinct file id, because
 * several result rows commonly originate from the same file.
 *
 * @param testReqVO request carrying the query text, k and score threshold
 * @param fileIds   ids of the files to search in; sent to the RAG service as strings
 * @return one result per pair returned by the RAG service, in the same order
 */
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO , List<Long> fileIds) {
    List<String> fileIdStr = fileIds.stream()
            .map(Object::toString)
            .collect(Collectors.toList());

    QueryMultipleReqVO vo = new QueryMultipleReqVO();
    vo.setQuery(testReqVO.getQuery());
    vo.setFileIds(fileIdStr);
    vo.setK(testReqVO.getK());
    vo.setScore(testReqVO.getScore());

    List<QueryResultPairVO> pairs = ragHttpService.executeHitRateTest(vo);
    List<KnowledgeHitRateTestResultVO> resultList = new ArrayList<>(pairs.size());
    // Cache of fileId -> display name so each file is looked up at most once.
    Map<Long, String> fileNameCache = new java.util.HashMap<>();

    for (QueryResultPairVO pair : pairs) {
        KnowledgeHitRateTestResultVO resultVO = new KnowledgeHitRateTestResultVO();
        resultVO.setPageContent(pair.getDocument().getPageContent());
        resultVO.setHitRate(String.valueOf(pair.getHitRate()));
        resultVO.setDigest(pair.getDocument().getMetadata().getDigest());

        long fileId = Long.parseLong(pair.getDocument().getMetadata().getFileId());
        resultVO.setFileId(fileId);
        resultVO.setFileName(fileNameCache.computeIfAbsent(fileId, this::lookupDocumentNameByFileId));

        resultList.add(resultVO);
    }
    return resultList;
}

/**
 * Looks up the document name for a file id, falling back to a placeholder
 * when no document row exists or its name is blank.
 */
private String lookupDocumentNameByFileId (Long fileId) {
    KnowledgeDocumentsDO documents = knowledgeDocumentsMapper.selectOne(KnowledgeDocumentsDO::getFileId, fileId);
    if (documents != null && StringUtils.isNotBlank(documents.getDocumentName())) {
        return documents.getDocumentName();
    }
    return "未知文件";
}
}

View File

@ -13,6 +13,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSa
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeHitRateTestResultVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
@ -26,6 +29,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.BaseModelService;
import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService;
import cn.iocoder.yudao.module.llm.service.http.ModelService;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.knowledgebase.KnowledgeBaseService;
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
import com.alibaba.excel.util.StringUtils;
import com.alibaba.fastjson.JSON;
@ -45,6 +49,7 @@ import javax.servlet.http.HttpServletResponse;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.*;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@ -80,6 +85,9 @@ public class ConversationServiceImpl implements ConversationService {
@Resource
private LLMBackendProperties llmBackendProperties;
@Resource
private KnowledgeBaseService knowledgeBaseService;
// 聊天会话历史记录缓存Key
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
// 聊天会话历史记录缓存时间
@ -439,14 +447,16 @@ public class ConversationServiceImpl implements ConversationService {
// 处理 knowledgeBaseString
if (StringUtils.isNotBlank(knowledgeBaseString)) {
knowledgeBaseString = "<context>" + knowledgeBaseString + "</context>";
}else {
knowledgeBaseString = "<context>" + "</context>";
}
// 处理 systemPrompt
systemPrompt = StringUtils.isBlank(chatReqVO.getSystemPrompt())
? PROMPT
: chatReqVO.getSystemPrompt() + "\n" + PROMPT;
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
}
String mess = systemPrompt + knowledgeBaseString;
String mess = systemPrompt + " \n "+knowledgeBaseString;
// // 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
@ -532,30 +542,14 @@ public class ConversationServiceImpl implements ConversationService {
ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO();
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
List<ParagraphHitRateWordVO> words = new ArrayList<>();
// 2. 遍历处理每个文档
for (KnowledgeDocumentsDO document : documentList) {
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
if (rateWordVO != null) {
words.add(rateWordVO);
}
}
if (CollectionUtils.isEmpty(words)) {
paragraphHitRateListVO.setWordList(Collections.emptyList());
paragraphHitRateListVO.setGroupId("");
} else {
paragraphHitRateListVO.setWordList(words);
}
KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
testReqVO.setQuery(chatReqVO.getPrompt());
// 请求结果添加到 Redis查询段落命中率
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid());
stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
}
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
} catch (Exception e) {
@ -567,6 +561,74 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
/**
 * Stores the hit-test results in Redis and concatenates their page contents.
 *
 * @param result                 hit-test results; may be null or empty
 * @param paragraphHitRateListVO holder that is filled and pushed to Redis by {@code saveRedis}
 * @return the page contents of all results appended in order; empty when there are no results
 */
private StringBuilder handlerResult (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
    StringBuilder merged = new StringBuilder();
    if (CollectionUtils.isEmpty(result)) {
        return merged;
    }
    // Step 1: persist the results to Redis.
    saveRedis(result, paragraphHitRateListVO);
    // Step 2: build the returned text.
    for (KnowledgeHitRateTestResultVO hit : result) {
        merged.append(hit.getPageContent());
    }
    return merged;
}
/**
 * Groups the hit-test results by file, fills {@code paragraphHitRateListVO} with one
 * word entry per file, and appends the holder as JSON to the Redis list keyed by its uuid.
 *
 * <p>Each file produces exactly one {@link ParagraphHitRateWordVO} containing all of that
 * file's paragraphs. Files keep the order in which they first appear in {@code result}.
 *
 * @param result                 hit-test results; nothing is stored when null or empty
 * @param paragraphHitRateListVO holder to fill and serialize; its uuid forms the Redis key
 */
private void saveRedis (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
    if (CollectionUtils.isEmpty(result)) {
        return;
    }

    // Group by file id; LinkedHashMap keeps first-seen order of the files.
    Map<Long, List<KnowledgeHitRateTestResultVO>> groupedByFileId = result.stream()
            .collect(Collectors.groupingBy(KnowledgeHitRateTestResultVO::getFileId, LinkedHashMap::new, Collectors.toList()));

    List<ParagraphHitRateWordVO> words = new ArrayList<>(groupedByFileId.size());
    for (List<KnowledgeHitRateTestResultVO> hits : groupedByFileId.values()) {
        ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
        // Every hit in a group belongs to the same file, so the first one supplies the name.
        rateWordVO.setDocumentName(hits.get(0).getFileName());

        List<ParagraphHitRateVO> paragraphHitRate = new ArrayList<>(hits.size());
        for (KnowledgeHitRateTestResultVO hit : hits) {
            String content = hit.getPageContent();
            ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
            rateVO.setParagraph(content);
            rateVO.setHitRate(hit.getHitRate());
            rateVO.setWordCount(content == null ? 0 : content.length());
            paragraphHitRate.add(rateVO);
        }
        rateWordVO.setParagraphHitRate(paragraphHitRate);
        words.add(rateWordVO);
    }

    if (CollectionUtils.isEmpty(words)) {
        paragraphHitRateListVO.setWordList(Collections.emptyList());
        paragraphHitRateListVO.setIsExist(false);
    } else {
        paragraphHitRateListVO.setWordList(words);
        paragraphHitRateListVO.setIsExist(true);
    }

    // Append this request's result to the Redis list used for paragraph hit-rate queries.
    String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, paragraphHitRateListVO.getUuid());
    stringRedisTemplate.opsForList().rightPush(redisKey, JSON.toJSONString(paragraphHitRateListVO));

    List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
    if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
        log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
    }
}
/**
* 处理单个知识库文档的检索逻辑
*/

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.llm.service.http;
import cn.hutool.http.Header;
import cn.hutool.http.HttpRequest;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
@ -11,6 +12,9 @@ import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumen
import cn.iocoder.yudao.module.llm.enums.KnowledgeStatusEnum;
import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.DocumentInfoVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryMultipleReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.query.multiple.QueryResultPairVO;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
@ -49,6 +53,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -390,9 +395,10 @@ public class RagHttpService {
String fileName = reqVO.getFileName();
String fileUrl = reqVO.getFileUrl();
Integer chunkSize = Optional.ofNullable(reqVO.getChunkSize()).orElse(1500);
Integer chunkOverlap = Optional.ofNullable(reqVO.getChunkOverlap()).orElse(300);
String mediaType = getMediaType(fileName);
log.info("文件ID: {}, 文件名: {}, 文件URL: {}, 文件类型: {}, 分块大小:{}", fileId, fileName, fileUrl, mediaType,chunkSize);
log.info("文件ID: {}, 文件名: {}, 文件URL: {}, 文件类型: {}, 分块大小:{}, 分块重叠:{}", fileId, fileName, fileUrl, mediaType, chunkSize, chunkOverlap);
// 获取知识库文档
log.info("开始获取知识库文档知识库ID: {}, 文件ID: {}", id, fileId);
@ -417,15 +423,15 @@ public class RagHttpService {
Path tempFilePath = downloadFileToTemp(fileUrl, fileName);
log.info("文件已下载到临时目录: {}", tempFilePath);
// String fileSuffix = getFileSuffix(fileName);
// if ("doc".equals(fileSuffix)) {
// log.info("正在处理 doc 文件");
// try {
// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
// } catch (Exception e) {
// throw new RuntimeException(e);
// }
// }
// String fileSuffix = getFileSuffix(fileName);
// if ("doc".equals(fileSuffix)) {
// log.info("正在处理 doc 文件");
// try {
// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
// } catch (Exception e) {
// throw new RuntimeException(e);
// }
// }
// if ("md".equals(fileSuffix)) {
// log.info("正在处理 md 文件");
@ -450,6 +456,7 @@ public class RagHttpService {
.setType(MultipartBody.FORM)
.addFormDataPart("file_id", fileId)
.addFormDataPart("chunk_size", String.valueOf(chunkSize))
.addFormDataPart("chunk_overlap", String.valueOf(chunkOverlap))
.addFormDataPart("file", fileName,
RequestBody.create(tempFilePath.toFile(), MediaType.parse(mediaType))
)
@ -728,4 +735,101 @@ public class RagHttpService {
private KnowledgeDocumentsDO getKnowledgeDocuments (String fileId) {
return knowledgeDocumentsMapper.selectById(fileId);
}
/**
 * Posts a multi-file query to the RAG backend and parses the response.
 *
 * @param vo query request; serialized to JSON as the request body, and its score
 *           is used as the threshold when parsing the response
 * @return the (document, hit rate) pairs accepted by {@code parseHitRateTestResults}
 */
public List<QueryResultPairVO> executeHitRateTest (QueryMultipleReqVO vo) {
    String jsonString = JSON.toJSONString(vo);
    String url = llmBackendProperties.getRagQueryMultiple();
    // Log before sending so the request is recorded even when the call fails.
    log.info("请求参数: {}", jsonString);

    String responseBody = HttpRequest.post(url)
            .header(Header.ACCEPT, "application/json")
            .header(Header.CONTENT_TYPE, "application/json")
            .body(jsonString)
            .timeout(20000)
            .execute().body();
    // The body is already a JSON string; log it as-is instead of re-encoding it.
    log.info("请求结果: {}", responseBody);

    return parseHitRateTestResults(responseBody, vo.getScore());
}
/**
 * Parses the response of the multi-file query endpoint.
 *
 * <p>The code expects a JSON array whose elements are two-element arrays
 * {@code [document, hitRate]}. A JSON object carrying a {@code detail} field is treated
 * as an error/empty answer.
 *
 * @param json  raw response body; a null or blank body yields an empty list
 * @param score minimum hit rate (inclusive) a pair must reach; null disables the filter
 * @return the pairs whose hit rate is at least {@code score}, in response order
 */
private static List<QueryResultPairVO> parseHitRateTestResults (String json, Double score) {
    if (json == null || json.trim().isEmpty()) {
        return new ArrayList<>();
    }
    String payload = json.trim();

    if (!payload.startsWith("[")) {
        // Not an array: check for a "detail" field describing why there is no result.
        JSONObject jsonObject = JSON.parseObject(payload);
        if (jsonObject != null && jsonObject.containsKey("detail")) {
            String detail = jsonObject.getString("detail");
            if (detail != null && detail.contains("No documents found for the given query")) {
                throw exception(new ErrorCode(100_100_1, "未找到符合条件的文档,请检查查询条件!"));
            }
            return new ArrayList<>();
        }
        // Any other non-array payload falls through; parseArray below rejects it.
    }

    JSONArray jsonArray = JSON.parseArray(payload);
    List<QueryResultPairVO> results = new ArrayList<>();
    // A missing threshold accepts every pair instead of failing on unboxing.
    double threshold = score == null ? Double.NEGATIVE_INFINITY : score;

    for (int i = 0; i < jsonArray.size(); i++) {
        JSONArray pairArray = jsonArray.getJSONArray(i);
        // Element 0 is the document, element 1 is the hit rate.
        JSONObject documentJson = pairArray.getJSONObject(0);
        DocumentInfoVO document = JSON.parseObject(documentJson.toJSONString(), DocumentInfoVO.class);
        double rate = pairArray.getDoubleValue(1);
        if (rate >= threshold) {
            QueryResultPairVO pair = new QueryResultPairVO();
            pair.setDocument(document);
            pair.setHitRate(rate);
            results.add(pair);
        }
    }
    return results;
}
/**
* Manual smoke test: posts a sample multi-file query to a hard-coded endpoint and
* prints the raw response body to the console.
*
* NOTE(review): the endpoint address is hard-coded and this entry point lives in a
* service class -- confirm whether it should be moved to a test or removed.
*
* @param args unused
*/
public static void main (String[] args) {
List<String> ids = new ArrayList<>();
ids.add("1111");
ids.add("1234");
QueryMultipleReqVO vo = new QueryMultipleReqVO();
vo.setQuery("可乐鸡翅怎么做");
vo.setFileIds(ids);
vo.setK(4);
String jsonString = JSON.toJSONString(vo);
String url = "http://192.168.18.66:8123/query_multiple";
// Build the request with the fluent API
String result2 = HttpRequest.post(url)
.header(Header.ACCEPT, "application/json")
.header(Header.CONTENT_TYPE, "application/json")
.body(jsonString)
.timeout(20000)
.execute().body();
cn.hutool.core.lang.Console.log(result2);
// extracted(result2);
}
}

View File

@ -46,4 +46,9 @@ public class KnowledgeRagEmbedReqVO {
* 分块大小
*/
private Integer chunkSize;
/**
* 分块重叠
*/
private Integer chunkOverlap;
}

View File

@ -0,0 +1,29 @@
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
import lombok.Data;
/**
* Document information.
*/
@Data
public class DocumentInfoVO {
/**
* Document ID; may be null.
*/
private String id;
/**
* Metadata.
*/
private MetadataVO metadata;
/**
* Page content.
*/
private String pageContent;
/**
* Document type.
*/
private String type;
}

View File

@ -0,0 +1,30 @@
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
import lombok.Data;
/**
* Document metadata.
*/
@Data
public class MetadataVO {
/**
* File ID.
*/
private String fileId;
/**
* User ID.
*/
private String userId;
/**
* File digest (summary).
*/
private String digest;
/**
* Source path of the file.
*/
private String source;
}

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
import com.alibaba.fastjson.annotation.JSONField;
import lombok.Data;
import javax.validation.constraints.NotNull;
import java.util.List;
/**
* Request for querying a knowledge base across multiple files.
*/
@Data
public class QueryMultipleReqVO {
/**
* Query text. Must not be null.
*/
@NotNull(message = "查询内容不能为空")
private String query;
/**
* List of file IDs. Must not be null; serialized by fastjson under the name "file_ids".
*/
@NotNull(message = "文件ID列表不能为空")
@JSONField(name = "file_ids")
private List<String> fileIds;
/**
* Number of results to return (the "k" value). No validation constraint is active on this field.
*/
// Validation currently disabled: @NotNull(message = "k值不能为空")
private Integer k;
// Score value sent with the query. NOTE(review): presumably a minimum-score threshold -- confirm.
private Double score;
}

View File

@ -0,0 +1,19 @@
package cn.iocoder.yudao.module.llm.service.http.vo.query.multiple;
import lombok.Data;
/**
* A query result pair: document information together with its hit rate.
*/
@Data
public class QueryResultPairVO {
/**
* Document information.
*/
private DocumentInfoVO document;
/**
* Hit rate.
*/
private Double hitRate;
}

View File

@ -1,9 +1,7 @@
package cn.iocoder.yudao.module.llm.service.knowledgebase;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
import javax.validation.Valid;
@ -70,4 +68,11 @@ public interface KnowledgeBaseService {
* @param updateReqVO 更新信息
*/
void updateKnowledgeBaseInfo (@Valid KnowledgeBaseSaveReqVO updateReqVO);
/**
* 执行知识库命中测试
* @param testReqVO 测试信息
* @return 返回结果
*/
List<KnowledgeHitRateTestResultVO> executeHitRateTest (@Valid KnowledgeHitRateTestReqVO testReqVO);
}

View File

@ -5,9 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBasePageReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.KnowledgeBaseSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgebase.vo.*;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgedocuments.vo.KnowledgeDocumentsRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.knowledgedocuments.vo.KnowledgeDocumentsSaveReqVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgebase.KnowledgeBaseDO;
@ -17,24 +15,17 @@ import cn.iocoder.yudao.module.llm.dal.mysql.knowledgedocuments.KnowledgeDocumen
import cn.iocoder.yudao.module.llm.service.application.ApplicationService;
import cn.iocoder.yudao.module.llm.service.async.AsyncKnowledgeBase;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import kong.unirest.Unirest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.validation.annotation.Validated;
import javax.annotation.Resource;
import javax.annotation.Tainted;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.KNOWLEDGE_BASE_NAME_NOT_EXISTS;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.KNOWLEDGE_BASE_NOT_EXISTS;
import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
/**
* 知识库 Service 实现类
@ -66,68 +57,139 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
}
@Override
// @Transactional(rollbackFor = Exception.class)
// @Transactional(rollbackFor = Exception.class)
public void updateKnowledgeBase (KnowledgeBaseSaveReqVO updateReqVO) {
// 1. 校验知识库是否存在
validateKnowledgeParam(updateReqVO);
// 2. 更新知识库主表基础信息
KnowledgeBaseDO updateObj = BeanUtils.toBean(updateReqVO, KnowledgeBaseDO.class);
knowledgeBaseMapper.updateById(updateObj);
// 3. 处理附表知识文档数据
handleKnowledgeDocuments(updateReqVO, updateObj);
}
/**
* 校验知识库参数
*
* @param updateReqVO 更新知识库参数
*/
private void validateKnowledgeParam (KnowledgeBaseSaveReqVO updateReqVO) {
// 1. 校验知识库是否存在
validateKnowledgeBaseExists(updateReqVO.getId());
// 2. 校验知识库名称是否重复
validateKnowledgeBaseNameExists(updateReqVO);
// 3. 更新知识库主表
KnowledgeBaseDO updateObj = BeanUtils.toBean(updateReqVO, KnowledgeBaseDO.class);
knowledgeBaseMapper.updateById(updateObj);
// 3. 校验分块大小和分块重叠是否正确
validateChunkParameters(updateReqVO.getChunkSize(), updateReqVO.getChunkOverlap());
}
// Unirest.config().reset();
// Unirest.config()
// .socketTimeout(86400000)
// .connectTimeout(100000)
// .concurrency(10, 5)
// .setDefaultHeader("Accept", "application/json");
/**
 * Validates the chunk size and chunk overlap of a knowledge base.
 *
 * <p>Both values are optional. A null value is not validated, so a request that omits
 * it no longer fails with a NullPointerException on unboxing.
 * NOTE(review): presumably a default is applied downstream when a value is omitted -- confirm.
 *
 * <p>A violation is reported by throwing the exception built by {@code exception(...)}
 * with the matching error code:
 * <ul>
 *   <li>chunk size below 1</li>
 *   <li>chunk overlap below 0</li>
 *   <li>chunk overlap not smaller than chunk size (checked only when both are given)</li>
 * </ul>
 *
 * @param chunkSize    chunk size; may be null
 * @param chunkOverlap chunk overlap; may be null
 */
private void validateChunkParameters (Integer chunkSize, Integer chunkOverlap) {
    if (chunkSize != null && chunkSize < 1) {
        throw exception(CHUNK_SIZE_MUST_BE_GREATER_THAN_ZERO);
    }
    if (chunkOverlap != null && chunkOverlap < 0) {
        throw exception(CHUNK_OVERLAP_MUST_BE_GREATER_THAN_OR_EQUAL_TO_ZERO);
    }
    if (chunkSize != null && chunkOverlap != null && chunkOverlap >= chunkSize) {
        throw exception(CHUNK_OVERLAP_MUST_BE_LESS_THAN_CHUNK_SIZE);
    }
}
// 4. 处理附表知识文档数据
if (!CollectionUtils.isAnyEmpty(updateReqVO.getKnowledgeDocuments())) {
// 4.1 获取需要保留的文档 ID
List<Long> retainedIds = updateReqVO.getKnowledgeDocuments().stream()
.map(KnowledgeDocumentsSaveReqVO::getId)
.filter(Objects::nonNull)
.collect(Collectors.toList());
// 4.2 删除不需要保留的文档
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId());
if (!CollectionUtils.isAnyEmpty(retainedIds)) {
deleteWrapper.notIn(KnowledgeDocumentsDO::getId, retainedIds);
}
knowledgeDocumentsMapper.delete(deleteWrapper);
// 4.3 更新或插入文档数据
List<KnowledgeDocumentsDO> newDocuments = new ArrayList<>();
updateReqVO.getKnowledgeDocuments().forEach(doc -> {
KnowledgeDocumentsDO docDO = BeanUtils.toBean(doc, KnowledgeDocumentsDO.class);
docDO.setKnowledgeBaseId(updateReqVO.getId());
docDO.setChunkSize(updateObj.getKnowledgeLength());
if (doc.getId() == null) {
newDocuments.add(docDO); // 收集新增文档
}
knowledgeDocumentsMapper.insertOrUpdate(docDO); // 更新或插入文档
});
// 4.4 异步处理新增文档和删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds);
} else {
// 5. 如果传入的文档列表为空则删除所有关联文档
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId()));
// 5.1 异步处理删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
asyncKnowledgeBase.createKnowledgeBase(null, deleteIds);
}
/**
 * Synchronizes the documents attached to a knowledge base with the update request.
 *
 * <p>An empty document list removes every attached document. Otherwise documents missing
 * from the request are deleted, the submitted ones are inserted or updated, and the new
 * and deleted documents are handed to {@code asyncKnowledgeBase} together with the
 * chunking parameters.
 *
 * @param updateReqVO update request carrying the documents and chunk settings
 * @param updateObj   knowledge-base data object built from the request
 */
private void handleKnowledgeDocuments (KnowledgeBaseSaveReqVO updateReqVO, KnowledgeBaseDO updateObj) {
    Long knowledgeBaseId = updateReqVO.getId();
    List<KnowledgeDocumentsSaveReqVO> documents = updateReqVO.getKnowledgeDocuments();

    // Nothing submitted: drop all attached documents and stop.
    if (CollectionUtils.isAnyEmpty(documents)) {
        deleteAllDocuments(knowledgeBaseId);
        return;
    }

    // Ids of already-persisted documents that the request wants to keep.
    List<Long> retainedIds = new ArrayList<>();
    for (KnowledgeDocumentsSaveReqVO document : documents) {
        Long documentId = document.getId();
        if (documentId != null) {
            retainedIds.add(documentId);
        }
    }
    deleteUnretainedDocuments(knowledgeBaseId, retainedIds);

    List<KnowledgeDocumentsDO> newDocuments =
            updateOrInsertDocuments(documents, knowledgeBaseId, updateObj.getKnowledgeLength());

    // Chunking parameters passed along to the embedding step.
    Map<String, Integer> knowledgeParameters = new HashMap<>();
    knowledgeParameters.put("chunkSize", updateReqVO.getChunkSize());
    knowledgeParameters.put("chunkOverlap", updateReqVO.getChunkOverlap());

    List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId);
    asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds, knowledgeParameters);
}
/**
 * Deletes every document attached to a knowledge base and, when the mapper reports
 * deleted ids, passes them to {@code asyncKnowledgeBase} for follow-up processing.
 *
 * @param knowledgeBaseId knowledge base ID
 */
private void deleteAllDocuments (Long knowledgeBaseId) {
    LambdaQueryWrapperX<KnowledgeDocumentsDO> byKnowledgeBase = new LambdaQueryWrapperX<>();
    byKnowledgeBase.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeBaseId);
    knowledgeDocumentsMapper.delete(byKnowledgeBase);

    List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(knowledgeBaseId);
    if (CollectionUtils.isAnyEmpty(deleteIds)) {
        return;
    }
    // No new documents and no chunk parameters: only the deletions need processing.
    asyncKnowledgeBase.createKnowledgeBase(new ArrayList<>(), deleteIds, new HashMap<>());
}
/**
 * Deletes the documents of a knowledge base that are not in the retained set.
 * When {@code retainedIds} is empty, every document of the knowledge base is deleted.
 *
 * @param knowledgeBaseId knowledge base ID
 * @param retainedIds     IDs of the documents to keep
 */
private void deleteUnretainedDocuments (Long knowledgeBaseId, List<Long> retainedIds) {
    LambdaQueryWrapperX<KnowledgeDocumentsDO> condition = new LambdaQueryWrapperX<>();
    condition.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeBaseId);

    boolean hasRetained = !CollectionUtils.isAnyEmpty(retainedIds);
    if (hasRetained) {
        // Exclude the documents the caller wants to keep.
        condition.notIn(KnowledgeDocumentsDO::getId, retainedIds);
    }
    knowledgeDocumentsMapper.delete(condition);
}
/**
 * Inserts or updates each submitted document and collects the ones that are new.
 *
 * @param documents       documents submitted with the request
 * @param knowledgeBaseId knowledge base ID assigned to every document
 * @param chunkSize       not read by this method.
 *                        NOTE(review): confirm whether it should be applied to the documents.
 * @return the documents that had no ID in the request, i.e. newly added ones
 */
private List<KnowledgeDocumentsDO> updateOrInsertDocuments (List<KnowledgeDocumentsSaveReqVO> documents, Long knowledgeBaseId, Integer chunkSize) {
    List<KnowledgeDocumentsDO> added = new ArrayList<>();
    for (KnowledgeDocumentsSaveReqVO submitted : documents) {
        KnowledgeDocumentsDO entity = BeanUtils.toBean(submitted, KnowledgeDocumentsDO.class);
        entity.setKnowledgeBaseId(knowledgeBaseId);
        // A document without an ID in the request is a new one.
        if (submitted.getId() == null) {
            added.add(entity);
        }
        knowledgeDocumentsMapper.insertOrUpdate(entity);
    }
    return added;
}
@Override
@ -236,6 +298,57 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
knowledgeBaseMapper.updateById(knowledgeBaseDO);
}
/**
 * Runs a hit-rate test against a knowledge base.
 *
 * <p>The request's {@code k} and {@code score} are always overwritten with the values
 * configured on the knowledge base. When a configured value is missing or out of range
 * the defaults are used: 4 for top-k (must be positive) and 0.2 for the score
 * (must lie in (0, 1]).
 *
 * @param testReqVO test request; its k and score are overwritten by this method
 * @return the hit-test results, or an empty list when nothing matched
 */
@Override
public List<KnowledgeHitRateTestResultVO> executeHitRateTest (KnowledgeHitRateTestReqVO testReqVO) {
    Long knowledgeId = testReqVO.getKnowledgeId();
    KnowledgeBaseDO baseDO = knowledgeBaseMapper.selectOne(KnowledgeBaseDO::getId, knowledgeId);
    if (baseDO == null) {
        throw exception(KNOWLEDGE_BASE_NOT_EXISTS);
    }

    // Top-k: fall back to the default when unset or non-positive.
    Integer topK = baseDO.getTopK();
    if (topK == null || topK <= 0) {
        topK = 4;
    }
    testReqVO.setK(topK);

    // Score threshold: validated on the score itself, independently of top-k.
    Double score = baseDO.getScore();
    if (score == null || score <= 0.0 || score > 1) {
        score = 0.2;
    }
    testReqVO.setScore(score);

    // Load the documents attached to the knowledge base.
    List<KnowledgeDocumentsDO> documentsDOS = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
            .eq(KnowledgeDocumentsDO::getKnowledgeBaseId, knowledgeId));
    if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(documentsDOS)) {
        throw exception(KNOWLEDGE_DOCUMENTS_NOT_EXISTS);
    }

    List<Long> fileIds = documentsDOS.stream()
            .map(KnowledgeDocumentsDO::getFileId)
            .collect(Collectors.toList());

    List<KnowledgeHitRateTestResultVO> result = asyncKnowledgeBase.executeHitRateTest(testReqVO, fileIds);
    if (com.baomidou.mybatisplus.core.toolkit.CollectionUtils.isEmpty(result)) {
        return Collections.emptyList();
    }
    return result;
}
/**
* 校验知识库是否存在
*