refactor(llm): 优化聊天逻辑和知识库处理

-调整系统提示和知识库的处理顺序
- 优化聊天
This commit is contained in:
sunxiqing 2025-03-14 23:24:14 +08:00
parent 9783c5ceb7
commit 1d7d615af4

View File

@ -64,6 +64,16 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@Slf4j
public class ConversationServiceImpl implements ConversationService {
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
// 聊天会话历史记录缓存Key
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
// 聊天会话历史记录缓存时间
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
/**
* 知识库文档缓存Key
*/
private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
@Resource
private ConversationMapper conversationMapper;
@Resource
@ -84,24 +94,11 @@ public class ConversationServiceImpl implements ConversationService {
private KnowledgeDocumentsMapper knowledgeDocumentsMapper;
@Resource
private LLMBackendProperties llmBackendProperties;
@Resource
private KnowledgeBaseService knowledgeBaseService;
// 聊天会话历史记录缓存Key
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
// 聊天会话历史记录缓存时间
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
/**
* 知识库文档缓存Key
*/
private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
@Override
public Integer createConversation (ConversationSaveReqVO createReqVO) {
public Integer createConversation(ConversationSaveReqVO createReqVO) {
// 插入
ConversationDO conversation = BeanUtils.toBean(createReqVO, ConversationDO.class);
conversationMapper.insert(conversation);
@ -110,7 +107,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public void updateConversation (ConversationSaveReqVO updateReqVO) {
public void updateConversation(ConversationSaveReqVO updateReqVO) {
// 校验存在
validateConversationExists(updateReqVO.getId());
// 更新
@ -119,31 +116,31 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public void deleteConversation (Integer id) {
public void deleteConversation(Integer id) {
// 校验存在
validateConversationExists(id);
// 删除
conversationMapper.deleteById(id);
}
private void validateConversationExists (Integer id) {
private void validateConversationExists(Integer id) {
if (conversationMapper.selectById(id) == null) {
throw exception(CONVERSATION_NOT_EXISTS);
}
}
@Override
public ConversationDO getConversation (Integer id) {
public ConversationDO getConversation(Integer id) {
return conversationMapper.selectById(id);
}
@Override
public PageResult<ConversationDO> getConversationPage (ConversationPageReqVO pageReqVO) {
public PageResult<ConversationDO> getConversationPage(ConversationPageReqVO pageReqVO) {
return conversationMapper.selectPage(pageReqVO);
}
@Override
public ChatRespVO chat (ChatReqVO chatReqVO) {
public ChatRespVO chat(ChatReqVO chatReqVO) {
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
if (chatReqVO.getApplicationId() != null) {
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
@ -162,7 +159,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public JSONArray textToImage (TextToImageReqVo req) {
public JSONArray textToImage(TextToImageReqVo req) {
TextToImageRespVo textToImageRespVo = modelService.textToImage(req);
return textToImageRespVo.getData();
}
@ -174,7 +171,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO
* @return
*/
public ChatRespVO publicModelChat (ChatReqVO chatReqVO) {
public ChatRespVO publicModelChat(ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());
@ -300,7 +297,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param emitter SseEmitter 对象用于流式发送响应
*/
@Override
public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
// 检查系统提示信息如果为空则尝试从应用中获取
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
@ -317,11 +314,11 @@ public class ConversationServiceImpl implements ConversationService {
// 知识库ID
if (application.getModelServiceId()!= null){
if (application.getModelServiceId() != null) {
chatReqVO.setKnowledge(application.getModelServiceId());
}
if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())){
if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())) {
chatReqVO.setSystemPrompt(application.getPrompt());
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
log.info("应用中未找到系统提示信息,使用默认提示信息");
@ -343,7 +340,7 @@ public class ConversationServiceImpl implements ConversationService {
* @return
*/
@Override
public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) {
public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
List<String> redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
@ -374,7 +371,7 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 验证命中率数据有效性
*/
private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
return false;
@ -384,7 +381,7 @@ public class ConversationServiceImpl implements ConversationService {
String actualGroupId = jsonObj.getString("groupId");
return Objects.equals(expectedUuid, actualUuid)
&& Objects.equals(expectedGroupId, actualGroupId);
&& Objects.equals(expectedGroupId, actualGroupId);
}
@ -394,7 +391,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO 对话请求对象
* @param emitter SseEmitter 对象用于流式发送响应
*/
public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) {
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
// 检查 UUID 是否为空若为空则生成一个
String uuid = chatReqVO.getUuid();
@ -461,7 +458,7 @@ public class ConversationServiceImpl implements ConversationService {
// 处理 knowledgeBaseString
if (StringUtils.isNotBlank(knowledgeBaseString)) {
knowledgeBaseString = "<context>" + knowledgeBaseString + "</context>";
}else {
} else {
knowledgeBaseString = "<context>" + "</context>";
}
@ -470,28 +467,38 @@ public class ConversationServiceImpl implements ConversationService {
? PROMPT
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
}
String mess = systemPrompt + " \n "+knowledgeBaseString;
String mess = systemPrompt + " \n " + knowledgeBaseString;
if (chatReqVO.getKnowledge() != null) {
log.info("不存在聊天历史记录,创建新的系统消息");
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
} else {
// 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
log.info("存在聊天历史记录,处理历史记录消息");
for (String messageHistory : messageHistoryList) {
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
if ("system".equals(modelCompletionsMessage.getRole())) {
modelCompletionsMessage.setContent(mess);
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
}
messages.add(modelCompletionsMessage);
}
} else {
log.info("不存在聊天历史记录,创建新的系统消息");
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
}
}
// // 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
// log.info("存在聊天历史记录,处理历史记录消息");
// for (String messageHistory : messageHistoryList) {
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
// if ("system".equals(modelCompletionsMessage.getRole())) {
// modelCompletionsMessage.setContent(mess);
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
// }
// messages.add(modelCompletionsMessage);
// }
// } else {
log.info("不存在聊天历史记录,创建新的系统消息");
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
// }
// 创建用户消息
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
@ -509,7 +516,7 @@ public class ConversationServiceImpl implements ConversationService {
// log.info("构建模型补全请求对象请求参数1: {}", modelCompletionsReqVO);
// 调用模型服务进行流式处理
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(),knowledgeBaseString);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(), knowledgeBaseString);
if (modelCompletionsRespVO == null) {
throw exception(MODEL_COMPLETIONS_ERROR);
}
@ -533,7 +540,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@NotNull
private StringBuilder getKnowledgeBase (ChatReqVO chatReqVO) {
private StringBuilder getKnowledgeBase(ChatReqVO chatReqVO) {
final String LOG_PREFIX = "[KnowledgeBase]";
StringBuilder knowledgeBase = new StringBuilder();
@ -557,12 +564,12 @@ public class ConversationServiceImpl implements ConversationService {
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
KnowledgeHitRateTestReqVO testReqVO = new KnowledgeHitRateTestReqVO();
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
testReqVO.setQuery(chatReqVO.getPrompt());
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
@ -575,8 +582,8 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
private StringBuilder handlerResult (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)){
private StringBuilder handlerResult(List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)) {
return new StringBuilder();
}
@ -591,8 +598,8 @@ public class ConversationServiceImpl implements ConversationService {
return knowledgeBase;
}
private void saveRedis (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)){
private void saveRedis(List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
if (CollectionUtils.isEmpty(result)) {
return;
}
List<ParagraphHitRateWordVO> words = new ArrayList<>();
@ -605,13 +612,13 @@ public class ConversationServiceImpl implements ConversationService {
List<ParagraphHitRateWordVO> finalWords = words;
groupedByFileId.forEach((fileId, list) -> {
System.out.println("File ID: " + fileId);
list.forEach(i->{
list.forEach(i -> {
ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
// 设置文档名称
rateWordVO.setDocumentName(i.getFileName());
// 设置段落命中率
List<ParagraphHitRateVO> paragraphHitRate=new ArrayList<>();
List<ParagraphHitRateVO> paragraphHitRate = new ArrayList<>();
for (KnowledgeHitRateTestResultVO i1 : list) {
ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
rateVO.setParagraph(i1.getPageContent());
@ -649,9 +656,9 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 处理单个知识库文档的检索逻辑
*/
private ParagraphHitRateWordVO processDocument (KnowledgeDocumentsDO document,
ChatReqVO chatReqVO,
StringBuilder knowledgeBase) {
private ParagraphHitRateWordVO processDocument(KnowledgeDocumentsDO document,
ChatReqVO chatReqVO,
StringBuilder knowledgeBase) {
ParagraphHitRateWordVO paragraphHitRateListVO = new ParagraphHitRateWordVO();
try {
log.info("{} 处理文档: {}[ID:{}]", "[KnowledgeBase]", document.getDocumentName(), document.getId());
@ -688,9 +695,9 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 解析向量检索响应结果
*/
private List<ParagraphHitRateVO> parseEmbeddingResponse (String response,
StringBuilder knowledgeBase,
String uuid) {
private List<ParagraphHitRateVO> parseEmbeddingResponse(String response,
StringBuilder knowledgeBase,
String uuid) {
if (StringUtils.isBlank(response)) {
log.warn("{} 收到空响应", "[KnowledgeBase]");
return Collections.emptyList();
@ -756,7 +763,7 @@ public class ConversationServiceImpl implements ConversationService {
* @param chatReqVO
* @return
*/
private ChatRespVO privateModelChat (ChatReqVO chatReqVO) {
private ChatRespVO privateModelChat(ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());