From 1d7d615af4ce655ffb8b4af012848f53cc6fa5a1 Mon Sep 17 00:00:00 2001
From: sunxiqing <2240398334@qq.com>
Date: Fri, 14 Mar 2025 23:24:14 +0800
Subject: [PATCH] =?UTF-8?q?refactor(llm):=20=E4=BC=98=E5=8C=96=E8=81=8A?=
=?UTF-8?q?=E5=A4=A9=E9=80=BB=E8=BE=91=E5=92=8C=E7=9F=A5=E8=AF=86=E5=BA=93?=
=?UTF-8?q?=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
-调整系统提示和知识库的处理顺序
- 优化聊天
---
.../conversation/ConversationServiceImpl.java | 143 +++++++++---------
1 file changed, 75 insertions(+), 68 deletions(-)
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
index 435a16c7d..15a6cc474 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java
@@ -64,6 +64,16 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@Slf4j
public class ConversationServiceImpl implements ConversationService {
+ public static final String PROMPT = "Use the following context as your learned knowledge, inside XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
+ // 聊天会话历史记录缓存Key
+ private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
+ // 聊天会话历史记录缓存时间
+ private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
+ /**
+ * 知识库文档缓存Key
+ */
+ private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
+ private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
@Resource
private ConversationMapper conversationMapper;
@Resource
@@ -84,24 +94,11 @@ public class ConversationServiceImpl implements ConversationService {
private KnowledgeDocumentsMapper knowledgeDocumentsMapper;
@Resource
private LLMBackendProperties llmBackendProperties;
-
@Resource
private KnowledgeBaseService knowledgeBaseService;
- // 聊天会话历史记录缓存Key
- private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
- // 聊天会话历史记录缓存时间
- private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
-
- /**
- * 知识库文档缓存Key
- */
- private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
- private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
- public static final String PROMPT = "Use the following context as your learned knowledge, inside XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
-
@Override
- public Integer createConversation (ConversationSaveReqVO createReqVO) {
+ public Integer createConversation(ConversationSaveReqVO createReqVO) {
// 插入
ConversationDO conversation = BeanUtils.toBean(createReqVO, ConversationDO.class);
conversationMapper.insert(conversation);
@@ -110,7 +107,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
- public void updateConversation (ConversationSaveReqVO updateReqVO) {
+ public void updateConversation(ConversationSaveReqVO updateReqVO) {
// 校验存在
validateConversationExists(updateReqVO.getId());
// 更新
@@ -119,31 +116,31 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
- public void deleteConversation (Integer id) {
+ public void deleteConversation(Integer id) {
// 校验存在
validateConversationExists(id);
// 删除
conversationMapper.deleteById(id);
}
- private void validateConversationExists (Integer id) {
+ private void validateConversationExists(Integer id) {
if (conversationMapper.selectById(id) == null) {
throw exception(CONVERSATION_NOT_EXISTS);
}
}
@Override
- public ConversationDO getConversation (Integer id) {
+ public ConversationDO getConversation(Integer id) {
return conversationMapper.selectById(id);
}
@Override
- public PageResult getConversationPage (ConversationPageReqVO pageReqVO) {
+ public PageResult getConversationPage(ConversationPageReqVO pageReqVO) {
return conversationMapper.selectPage(pageReqVO);
}
@Override
- public ChatRespVO chat (ChatReqVO chatReqVO) {
+ public ChatRespVO chat(ChatReqVO chatReqVO) {
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
if (chatReqVO.getApplicationId() != null) {
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
@@ -162,7 +159,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
- public JSONArray textToImage (TextToImageReqVo req) {
+ public JSONArray textToImage(TextToImageReqVo req) {
TextToImageRespVo textToImageRespVo = modelService.textToImage(req);
return textToImageRespVo.getData();
}
@@ -174,7 +171,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO
* @return
*/
- public ChatRespVO publicModelChat (ChatReqVO chatReqVO) {
+ public ChatRespVO publicModelChat(ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid,就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());
@@ -300,7 +297,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param emitter SseEmitter 对象,用于流式发送响应
*/
@Override
- public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
+ public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
// 检查系统提示信息,如果为空则尝试从应用中获取
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
@@ -317,11 +314,11 @@ public class ConversationServiceImpl implements ConversationService {
// 知识库ID
- if (application.getModelServiceId()!= null){
+ if (application.getModelServiceId() != null) {
chatReqVO.setKnowledge(application.getModelServiceId());
}
- if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())){
+ if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())) {
chatReqVO.setSystemPrompt(application.getPrompt());
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
log.info("应用中未找到系统提示信息,使用默认提示信息");
@@ -343,7 +340,7 @@ public class ConversationServiceImpl implements ConversationService {
* @return
*/
@Override
- public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) {
+ public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
List redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
@@ -374,7 +371,7 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 验证命中率数据有效性
*/
- private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
+ private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
return false;
@@ -384,7 +381,7 @@ public class ConversationServiceImpl implements ConversationService {
String actualGroupId = jsonObj.getString("groupId");
return Objects.equals(expectedUuid, actualUuid)
- && Objects.equals(expectedGroupId, actualGroupId);
+ && Objects.equals(expectedGroupId, actualGroupId);
}
@@ -394,7 +391,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO 对话请求对象
* @param emitter SseEmitter 对象,用于流式发送响应
*/
- public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) {
+ public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
// 检查 UUID 是否为空,若为空则生成一个
String uuid = chatReqVO.getUuid();
@@ -461,7 +458,7 @@ public class ConversationServiceImpl implements ConversationService {
// 处理 knowledgeBaseString
if (StringUtils.isNotBlank(knowledgeBaseString)) {
knowledgeBaseString = "" + knowledgeBaseString + "";
- }else {
+ } else {
knowledgeBaseString = "" + "";
}
@@ -470,28 +467,38 @@ public class ConversationServiceImpl implements ConversationService {
? PROMPT
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
}
- String mess = systemPrompt + " \n "+knowledgeBaseString;
+ String mess = systemPrompt + " \n " + knowledgeBaseString;
+
+ if (chatReqVO.getKnowledge() != null) {
+ log.info("不存在聊天历史记录,创建新的系统消息");
+ ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
+ systemMessage.setRole("system");
+ systemMessage.setContent(mess);
+ stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
+ messages.add(systemMessage);
+ } else {
+ // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
+ List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
+ if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
+ log.info("存在聊天历史记录,处理历史记录消息");
+ for (String messageHistory : messageHistoryList) {
+ ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
+ if ("system".equals(modelCompletionsMessage.getRole())) {
+ modelCompletionsMessage.setContent(mess);
+ stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
+ }
+ messages.add(modelCompletionsMessage);
+ }
+ } else {
+ log.info("不存在聊天历史记录,创建新的系统消息");
+ ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
+ systemMessage.setRole("system");
+ systemMessage.setContent(mess);
+ stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
+ messages.add(systemMessage);
+ }
+ }
- // // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
- // List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
- // if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
- // log.info("存在聊天历史记录,处理历史记录消息");
- // for (String messageHistory : messageHistoryList) {
- // ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
- // if ("system".equals(modelCompletionsMessage.getRole())) {
- // modelCompletionsMessage.setContent(mess);
- // stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
- // }
- // messages.add(modelCompletionsMessage);
- // }
- // } else {
- log.info("不存在聊天历史记录,创建新的系统消息");
- ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
- systemMessage.setRole("system");
- systemMessage.setContent(mess);
- stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
- messages.add(systemMessage);
- // }
// 创建用户消息
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
@@ -509,7 +516,7 @@ public class ConversationServiceImpl implements ConversationService {
// log.info("构建模型补全请求对象,请求参数1: {}", modelCompletionsReqVO);
// 调用模型服务进行流式处理
- ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(),knowledgeBaseString);
+ ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(), knowledgeBaseString);
if (modelCompletionsRespVO == null) {
throw exception(MODEL_COMPLETIONS_ERROR);
}
@@ -533,7 +540,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@NotNull
- private StringBuilder getKnowledgeBase (ChatReqVO chatReqVO) {
+ private StringBuilder getKnowledgeBase(ChatReqVO chatReqVO) {
final String LOG_PREFIX = "[KnowledgeBase]";
StringBuilder knowledgeBase = new StringBuilder();
@@ -557,12 +564,12 @@ public class ConversationServiceImpl implements ConversationService {
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
- KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
+ KnowledgeHitRateTestReqVO testReqVO = new KnowledgeHitRateTestReqVO();
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
testReqVO.setQuery(chatReqVO.getPrompt());
List result = knowledgeBaseService.executeHitRateTest(testReqVO);
- knowledgeBase = handlerResult(result, paragraphHitRateListVO);
+ knowledgeBase = handlerResult(result, paragraphHitRateListVO);
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
@@ -575,8 +582,8 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
- private StringBuilder handlerResult (List result, ParagraphHitRateListVO paragraphHitRateListVO) {
- if (CollectionUtils.isEmpty(result)){
+ private StringBuilder handlerResult(List result, ParagraphHitRateListVO paragraphHitRateListVO) {
+ if (CollectionUtils.isEmpty(result)) {
return new StringBuilder();
}
@@ -591,8 +598,8 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
- private void saveRedis (List result, ParagraphHitRateListVO paragraphHitRateListVO) {
- if (CollectionUtils.isEmpty(result)){
+ private void saveRedis(List result, ParagraphHitRateListVO paragraphHitRateListVO) {
+ if (CollectionUtils.isEmpty(result)) {
return;
}
List words = new ArrayList<>();
@@ -605,13 +612,13 @@ public class ConversationServiceImpl implements ConversationService {
List finalWords = words;
groupedByFileId.forEach((fileId, list) -> {
System.out.println("File ID: " + fileId);
- list.forEach(i->{
+ list.forEach(i -> {
ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
// 设置文档名称
rateWordVO.setDocumentName(i.getFileName());
// 设置段落命中率
- List paragraphHitRate=new ArrayList<>();
+ List paragraphHitRate = new ArrayList<>();
for (KnowledgeHitRateTestResultVO i1 : list) {
ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
rateVO.setParagraph(i1.getPageContent());
@@ -649,9 +656,9 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 处理单个知识库文档的检索逻辑
*/
- private ParagraphHitRateWordVO processDocument (KnowledgeDocumentsDO document,
- ChatReqVO chatReqVO,
- StringBuilder knowledgeBase) {
+ private ParagraphHitRateWordVO processDocument(KnowledgeDocumentsDO document,
+ ChatReqVO chatReqVO,
+ StringBuilder knowledgeBase) {
ParagraphHitRateWordVO paragraphHitRateListVO = new ParagraphHitRateWordVO();
try {
log.info("{} 处理文档: {}[ID:{}]", "[KnowledgeBase]", document.getDocumentName(), document.getId());
@@ -688,9 +695,9 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 解析向量检索响应结果
*/
- private List parseEmbeddingResponse (String response,
- StringBuilder knowledgeBase,
- String uuid) {
+ private List parseEmbeddingResponse(String response,
+ StringBuilder knowledgeBase,
+ String uuid) {
if (StringUtils.isBlank(response)) {
log.warn("{} 收到空响应", "[KnowledgeBase]");
return Collections.emptyList();
@@ -756,7 +763,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO
* @return
*/
- private ChatRespVO privateModelChat (ChatReqVO chatReqVO) {
+ private ChatRespVO privateModelChat(ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid,就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());