feat(llm): 增加知识库文档查询和段落命中率统计功能

- 在 ChatReqVO 中添加 groupId 字段,用于区分不同的对话分组
- 新增 getParagraphHitRate 方法,用于获取段落命中率信息
- 优化 chatStream 方法,增加知识库文档查询逻辑
- 新增 ParagraphHitRateListVO、ParagraphHitRateVO 和 ParagraphHitRateWordVO 类,用于段落命中率统计
This commit is contained in:
Liuyang 2025-03-05 17:11:44 +08:00
parent 840f8003b7
commit 511b99fe62
8 changed files with 322 additions and 54 deletions

View File

@ -11,7 +11,6 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
import com.alibaba.fastjson.JSON;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
@ -27,8 +26,6 @@ import javax.validation.Valid;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
@ -113,26 +110,26 @@ public class ConversationController {
public SseEmitter streamChat (@Valid @RequestBody ChatReqVO chatReqVO, HttpServletResponse response) {
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
SseEmitter emitter = new SseEmitter(120_000L);
// ExecutorService executor = Executors.newSingleThreadExecutor();
// try {
// executor.execute(() -> {
// try {
// conversationService.chatStream(chatReqVO, emitter, response);
// } catch (Exception e) {
// emitter.completeWithError(e);
// } finally {
// executor.shutdown();
// }
// });
// } catch (Exception e) {
// log.error("处理对话推理请求时发生异常", e);
// try {
// emitter.completeWithError(e);
// } catch (Exception ex) {
// log.error("无法完成 SseEmitter 错误处理", ex);
// }
// }
// log.info("返回 SseEmitter 对象,准备进行流式响应");
// ExecutorService executor = Executors.newSingleThreadExecutor();
// try {
// executor.execute(() -> {
// try {
// conversationService.chatStream(chatReqVO, emitter, response);
// } catch (Exception e) {
// emitter.completeWithError(e);
// } finally {
// executor.shutdown();
// }
// });
// } catch (Exception e) {
// log.error("处理对话推理请求时发生异常", e);
// try {
// emitter.completeWithError(e);
// } catch (Exception ex) {
// log.error("无法完成 SseEmitter 错误处理", ex);
// }
// }
// log.info("返回 SseEmitter 对象,准备进行流式响应");
// 异步处理避免阻塞主线程
CompletableFuture.runAsync(() -> {
try {
@ -150,6 +147,13 @@ public class ConversationController {
return emitter;
}
@GetMapping("/paragraphHitRate")
@Operation(summary = "获得大模型对话记录分页")
@PreAuthorize("@ss.hasPermission('llm:conversation:query')")
public CommonResult<ParagraphHitRateListVO> getParagraphHitRate (@Valid String uuid, @Valid String groupId) {
return success(conversationService.getParagraphHitRate(uuid,groupId));
}
@PostMapping("/text-to-image")
@Operation(summary = "文字转图片接口")
public CommonResult<JSONArray> textToImage (@Valid @RequestBody TextToImageReqVo req) {

View File

@ -31,4 +31,6 @@ public class ChatReqVO {
private Integer maxTokens;
@Schema(description = "随机性temperature")
private Double temperature;
@Schema(description = "分组id")
private String groupId;
}

View File

@ -0,0 +1,15 @@
package cn.iocoder.yudao.module.llm.controller.admin.conversation.vo;
import lombok.Data;
import java.util.List;
/**
* @Description 段落命中率列表
*/
@Data
public class ParagraphHitRateListVO {
private String uuid;
private String groupId;
private List<ParagraphHitRateWordVO> wordList;
}

View File

@ -0,0 +1,26 @@
package cn.iocoder.yudao.module.llm.controller.admin.conversation.vo;
import lombok.Data;
import java.util.List;
/**
* @Description 段落命中率
*/
@Data
public class ParagraphHitRateVO {
/**
* 段落
*/
private String paragraph;
/**
* 命中率
*/
private String hitRate;
/**
* 字数
*/
private Integer wordCount;
}

View File

@ -0,0 +1,18 @@
package cn.iocoder.yudao.module.llm.controller.admin.conversation.vo;
import lombok.Data;
import java.util.List;
/**
* @Description 段落命中率
*/
@Data
public class ParagraphHitRateWordVO {
/**
* 文件名称
*/
private String documentName;
private List<ParagraphHitRateVO> paragraphHitRate;
}

View File

@ -74,4 +74,11 @@ public interface ConversationService {
* @param emitter emitter
*/
void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response);
/**
* 获取段落命中率
* @param uuid
* @return
*/
ParagraphHitRateListVO getParagraphHitRate (@Valid String uuid,@Valid String groupId);
}

View File

@ -3,16 +3,15 @@ package cn.iocoder.yudao.module.llm.service.conversation;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.iocoder.yudao.framework.common.exception.ServiceException;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.application.vo.ApplicationSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ConversationPageReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ConversationSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.datarefluxdata.vo.DataRefluxDataSaveReqVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
@ -30,10 +29,12 @@ import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
import com.alibaba.excel.util.StringUtils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -41,6 +42,8 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -82,6 +85,12 @@ public class ConversationServiceImpl implements ConversationService {
// 聊天会话历史记录缓存时间
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
/**
* 知识库文档缓存Key
*/
private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
@Override
public Integer createConversation (ConversationSaveReqVO createReqVO) {
// 插入
@ -281,7 +290,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO 对话请求对象
* @param emitter SseEmitter 对象用于流式发送响应
*/
@Override
@Override
public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
// 检查系统提示信息如果为空则尝试从应用中获取
@ -304,6 +313,57 @@ public class ConversationServiceImpl implements ConversationService {
publicModelChatStream(chatReqVO, emitter);
}
/**
* 获取段落命中率
*
* @param uuid
* @return
*/
@Override
public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
List<String> redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
if (CollectionUtils.isNotEmpty(redisResults)) {
for (String jsonResult : redisResults) {
try {
JSONObject jsonObject = JSONObject.parseObject(jsonResult);
log.info("[Processing] Raw JSON: {}", jsonResult);
if (!isValidHitRateData(jsonObject, uuid, groupId)) {
continue;
}
// 类型安全转换FastJSON特性
return jsonObject.toJavaObject(ParagraphHitRateListVO.class);
} catch (JSONException e) {
log.error("[JSON Parse Error] Invalid format: {} | Data: {}", e.getMessage(), jsonResult);
} catch (Exception e) {
log.warn("[Data Validation] Skip invalid record: {} | Reason: {}", jsonResult, e.getMessage());
}
}
}
return null;
}
/**
* 验证命中率数据有效性
*/
private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
return false;
}
String actualUuid = jsonObj.getString("uuid");
String actualGroupId = jsonObj.getString("groupId");
return Objects.equals(expectedUuid, actualUuid)
&& Objects.equals(expectedGroupId, actualGroupId);
}
/**
* 公共模型聊天流式处理方法
*
@ -313,10 +373,17 @@ public class ConversationServiceImpl implements ConversationService {
public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) {
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
// 检查 UUID 是否为空若为空则生成一个
if (StrUtil.isBlank(chatReqVO.getUuid())) {
String uuid = chatReqVO.getUuid();
if (StrUtil.isBlank(uuid)) {
log.info("UUID 为空,生成新的 UUID");
chatReqVO.setUuid(UUID.randomUUID().toString());
uuid = UUID.randomUUID().toString();
chatReqVO.setUuid(uuid);
}
// 为每一次对话设置ID
String groupId = UUID.randomUUID().toString();
chatReqVO.setGroupId(groupId);
String model = null;
String selfModelUrl = "";
// 根据模型类型获取模型信息
@ -362,29 +429,9 @@ public class ConversationServiceImpl implements ConversationService {
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
// 如果知识库 ID 不为空先调用知识库获取相关信息
StringBuilder knowledgeBase = new StringBuilder();
if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) {
log.info("知识库 ID 不为空,开始查询知识库,知识库 ID: {}", chatReqVO.getKnowledge());
LambdaQueryWrapper<KnowledgeDocumentsDO> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge());
List<KnowledgeDocumentsDO> fileList = knowledgeDocumentsMapper.selectList(queryWrapper);
for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) {
Long id = knowledgeDocumentsDO.getId();
KnowledgeRagEmbedQueryVO knowledgeRagEmbedQueryVO = new KnowledgeRagEmbedQueryVO();
knowledgeRagEmbedQueryVO.setFile_id(id.toString());
knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt());
String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO));
com.alibaba.fastjson.JSONArray jsonArray = JSON.parseArray(result);
if (jsonArray != null && !jsonArray.isEmpty()) {
com.alibaba.fastjson.JSONArray jsonArray1 = (com.alibaba.fastjson.JSONArray) jsonArray.get(0);
JSONObject jsonObject = (JSONObject) jsonArray1.get(0);
knowledgeBase.append(jsonObject.get("page_content"));
}
}
log.info("知识库查询完成,获取到的信息: {}", knowledgeBase.toString());
}
String mess = chatReqVO.getSystemPrompt() + "" + knowledgeBase.toString() + "";
StringBuilder knowledgeBase = getKnowledgeBase(chatReqVO);
String mess = chatReqVO.getSystemPrompt() + knowledgeBase.toString();
// 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
@ -416,10 +463,10 @@ public class ConversationServiceImpl implements ConversationService {
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
modelCompletionsReqVO.setMessages(messages);
modelCompletionsReqVO.setModel(model);
log.info("构建模型补全请求对象请求参数1: {}", modelCompletionsReqVO);
// log.info("构建模型补全请求对象请求参数1: {}", modelCompletionsReqVO);
// 调用模型服务进行流式处理
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter,chatReqVO.getUuid());
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId());
if (modelCompletionsRespVO == null) {
throw exception(MODEL_COMPLETIONS_ERROR);
}
@ -442,6 +489,150 @@ public class ConversationServiceImpl implements ConversationService {
log.info("数据回流信息保存完成");
}
@NotNull
private StringBuilder getKnowledgeBase (ChatReqVO chatReqVO) {
final String LOG_PREFIX = "[KnowledgeBase]";
StringBuilder knowledgeBase = new StringBuilder();
// 参数有效性检查
if (chatReqVO.getKnowledge() == null || chatReqVO.getKnowledge() == 0L) {
log.info("{} 未启用知识库检索knowledgeId: {}", LOG_PREFIX, chatReqVO.getKnowledge());
return knowledgeBase;
}
log.info("{} 开始知识库检索knowledgeId: {}", LOG_PREFIX, chatReqVO.getKnowledge());
// 如果知识库 ID 不为空先调用知识库获取相关信息
try {
// 1. 查询知识库文档列表
List<KnowledgeDocumentsDO> documentList = knowledgeDocumentsMapper.selectList(new LambdaQueryWrapper<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge()));
log.info("{} 查询到{}个关联文档", LOG_PREFIX, documentList.size());
// 解析响应数据
ParagraphHitRateListVO paragraphHitRateListVO = new ParagraphHitRateListVO();
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
List<ParagraphHitRateWordVO> words = new ArrayList<>();
// 2. 遍历处理每个文档
for (KnowledgeDocumentsDO document : documentList) {
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
words.add(rateWordVO);
}
paragraphHitRateListVO.setWordList(words);
// 请求结果添加到 Redis查询段落命中率
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, chatReqVO.getUuid());
stringRedisTemplate.opsForList().rightPushIfPresent(redisKey, JSON.toJSONString(paragraphHitRateListVO));
List<String> paragraphHitRateList = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
if (paragraphHitRateList != null && !paragraphHitRateList.isEmpty()) {
log.info("{} 知识库查询段落命中率: {}", "[KnowledgeBase]", paragraphHitRateList);
}
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
} catch (Exception e) {
log.error("{} 知识库处理异常: {}", LOG_PREFIX, e.getMessage(), e);
throw new ServiceException(100, "知识库处理失败,请稍后重试");
}
return knowledgeBase;
}
/**
* 处理单个知识库文档的检索逻辑
*/
private ParagraphHitRateWordVO processDocument (KnowledgeDocumentsDO document,
ChatReqVO chatReqVO,
StringBuilder knowledgeBase) {
ParagraphHitRateWordVO paragraphHitRateListVO = new ParagraphHitRateWordVO();
try {
log.info("{} 处理文档: {}[ID:{}]", "[KnowledgeBase]", document.getDocumentName(), document.getId());
// 构建查询请求
KnowledgeRagEmbedQueryVO queryVO = new KnowledgeRagEmbedQueryVO()
.setFile_id(document.getId().toString())
.setQuery(chatReqVO.getPrompt());
// 发送向量检索请求
long start = System.currentTimeMillis();
String response = HttpUtils.post(
llmBackendProperties.getEmbedQuery(),
null,
JSON.toJSONString(queryVO)
);
log.info("{} 文档[{}]检索耗时: {}ms", "[KnowledgeBase]", document.getId(),
System.currentTimeMillis() - start);
log.info("[KnowledgeBase] 知识库请求结果:{}", response);
paragraphHitRateListVO.setDocumentName(document.getDocumentName());
paragraphHitRateListVO.setParagraphHitRate(parseEmbeddingResponse(response, knowledgeBase, chatReqVO.getUuid()));
} catch (Exception e) {
log.warn("{} 文档[{}]处理异常: {}", "[KnowledgeBase]", document.getId(), e.getMessage());
// 单个文档失败不影响整体流程
}
return paragraphHitRateListVO;
}
/**
* 解析向量检索响应结果
*/
private List<ParagraphHitRateVO> parseEmbeddingResponse (String response,
StringBuilder knowledgeBase,
String uuid) {
if (StringUtils.isBlank(response)) {
log.warn("{} 收到空响应", "[KnowledgeBase]");
return Collections.emptyList();
}
List<ParagraphHitRateVO> paragraphHitRateList = new ArrayList<>();
try {
com.alibaba.fastjson.JSONArray resultArray = JSON.parseArray(response);
if (resultArray == null || resultArray.isEmpty()) {
log.info("{} 无有效检索结果", "[KnowledgeBase]");
return Collections.emptyList();
}
// 获取第一个结果集
com.alibaba.fastjson.JSONArray firstResult = resultArray.getJSONArray(0);
if (firstResult.isEmpty()) {
return Collections.emptyList();
}
// 提取页面内容
JSONObject content = firstResult.getJSONObject(0);
String pageContent = content.getString("page_content");
log.info("{} 内容: {}", "[KnowledgeBase]", JSON.toJSONString(pageContent));
JSONObject metadata = content.getJSONObject("metadata");
String fileId = metadata.getString("file_id");
Double rate = firstResult.getDouble(1);
DecimalFormat df = new DecimalFormat("#%");
// 明确指定四舍五入模式
df.setRoundingMode(RoundingMode.HALF_UP);
String rateResult = df.format(rate);
log.info("{} 命中率: {}", "[KnowledgeBase]", rateResult);
if (StringUtils.isNotBlank(pageContent)) {
knowledgeBase.append("\n[知识库内容] ").append(pageContent);
log.info("{} 添加知识内容,长度: {}", "[KnowledgeBase]", pageContent.length());
ParagraphHitRateVO paragraphHitRateVO = new ParagraphHitRateVO();
paragraphHitRateVO.setParagraph(pageContent);
paragraphHitRateVO.setHitRate(rateResult);
paragraphHitRateVO.setWordCount(pageContent.length());
paragraphHitRateList.add(paragraphHitRateVO);
}
} catch (JSONException e) {
log.error("{} 响应解析失败: {} | 原始响应: {}", "[KnowledgeBase]", e.getMessage(), response);
throw new RuntimeException("知识库响应解析异常");
}
return paragraphHitRateList;
}
/**
* 私有模型聊天

View File

@ -17,6 +17,11 @@ public class ChatReqVO {
*/
private String uuid;
/**
* 对话分组id
*/
private String groupId;
/**
* 是否结束对话
*/