refactor(llm): 优化聊天逻辑和知识库处理
-调整系统提示和知识库的处理顺序 - 优化聊天
This commit is contained in:
parent
9783c5ceb7
commit
1d7d615af4
@ -64,6 +64,16 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
|
||||
@Slf4j
|
||||
public class ConversationServiceImpl implements ConversationService {
|
||||
|
||||
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
|
||||
// 聊天会话历史记录缓存Key
|
||||
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
|
||||
// 聊天会话历史记录缓存时间
|
||||
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
|
||||
/**
|
||||
* 知识库文档缓存Key
|
||||
*/
|
||||
private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
|
||||
private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
|
||||
@Resource
|
||||
private ConversationMapper conversationMapper;
|
||||
@Resource
|
||||
@ -84,24 +94,11 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
private KnowledgeDocumentsMapper knowledgeDocumentsMapper;
|
||||
@Resource
|
||||
private LLMBackendProperties llmBackendProperties;
|
||||
|
||||
@Resource
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
// 聊天会话历史记录缓存Key
|
||||
private final static String CHAT_HIStORY_REDIS_KEY = "llm:chat:history";
|
||||
// 聊天会话历史记录缓存时间
|
||||
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
|
||||
|
||||
/**
|
||||
* 知识库文档缓存Key
|
||||
*/
|
||||
private final static String KNOWLEDGE_DOCUMENTS_REDIS_KEY = "llm:knowledge:documents";
|
||||
private final static Long KNOWLEDGE_DOCUMENTS_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
|
||||
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
|
||||
|
||||
@Override
|
||||
public Integer createConversation (ConversationSaveReqVO createReqVO) {
|
||||
public Integer createConversation(ConversationSaveReqVO createReqVO) {
|
||||
// 插入
|
||||
ConversationDO conversation = BeanUtils.toBean(createReqVO, ConversationDO.class);
|
||||
conversationMapper.insert(conversation);
|
||||
@ -110,7 +107,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateConversation (ConversationSaveReqVO updateReqVO) {
|
||||
public void updateConversation(ConversationSaveReqVO updateReqVO) {
|
||||
// 校验存在
|
||||
validateConversationExists(updateReqVO.getId());
|
||||
// 更新
|
||||
@ -119,31 +116,31 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteConversation (Integer id) {
|
||||
public void deleteConversation(Integer id) {
|
||||
// 校验存在
|
||||
validateConversationExists(id);
|
||||
// 删除
|
||||
conversationMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private void validateConversationExists (Integer id) {
|
||||
private void validateConversationExists(Integer id) {
|
||||
if (conversationMapper.selectById(id) == null) {
|
||||
throw exception(CONVERSATION_NOT_EXISTS);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConversationDO getConversation (Integer id) {
|
||||
public ConversationDO getConversation(Integer id) {
|
||||
return conversationMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<ConversationDO> getConversationPage (ConversationPageReqVO pageReqVO) {
|
||||
public PageResult<ConversationDO> getConversationPage(ConversationPageReqVO pageReqVO) {
|
||||
return conversationMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRespVO chat (ChatReqVO chatReqVO) {
|
||||
public ChatRespVO chat(ChatReqVO chatReqVO) {
|
||||
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
|
||||
if (chatReqVO.getApplicationId() != null) {
|
||||
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
|
||||
@ -162,7 +159,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public JSONArray textToImage (TextToImageReqVo req) {
|
||||
public JSONArray textToImage(TextToImageReqVo req) {
|
||||
TextToImageRespVo textToImageRespVo = modelService.textToImage(req);
|
||||
return textToImageRespVo.getData();
|
||||
}
|
||||
@ -174,7 +171,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param chatReqVO
|
||||
* @return
|
||||
*/
|
||||
public ChatRespVO publicModelChat (ChatReqVO chatReqVO) {
|
||||
public ChatRespVO publicModelChat(ChatReqVO chatReqVO) {
|
||||
if (StringUtils.isBlank(chatReqVO.getUuid())) {
|
||||
// 如果没有uuid,就生成一个
|
||||
chatReqVO.setUuid(UUID.randomUUID().toString());
|
||||
@ -300,7 +297,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
@Override
|
||||
public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
|
||||
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
|
||||
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
|
||||
// 检查系统提示信息,如果为空则尝试从应用中获取
|
||||
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
|
||||
@ -317,11 +314,11 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
|
||||
|
||||
// 知识库ID
|
||||
if (application.getModelServiceId()!= null){
|
||||
if (application.getModelServiceId() != null) {
|
||||
chatReqVO.setKnowledge(application.getModelServiceId());
|
||||
}
|
||||
|
||||
if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())){
|
||||
if (!org.apache.commons.lang3.StringUtils.isBlank(application.getPrompt())) {
|
||||
chatReqVO.setSystemPrompt(application.getPrompt());
|
||||
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
|
||||
log.info("应用中未找到系统提示信息,使用默认提示信息");
|
||||
@ -343,7 +340,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) {
|
||||
public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
|
||||
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
|
||||
List<String> redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
|
||||
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
|
||||
@ -374,7 +371,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
/**
|
||||
* 验证命中率数据有效性
|
||||
*/
|
||||
private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
|
||||
private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
|
||||
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
|
||||
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
|
||||
return false;
|
||||
@ -384,7 +381,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
String actualGroupId = jsonObj.getString("groupId");
|
||||
|
||||
return Objects.equals(expectedUuid, actualUuid)
|
||||
&& Objects.equals(expectedGroupId, actualGroupId);
|
||||
&& Objects.equals(expectedGroupId, actualGroupId);
|
||||
}
|
||||
|
||||
|
||||
@ -394,7 +391,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) {
|
||||
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
|
||||
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
|
||||
// 检查 UUID 是否为空,若为空则生成一个
|
||||
String uuid = chatReqVO.getUuid();
|
||||
@ -461,7 +458,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
// 处理 knowledgeBaseString
|
||||
if (StringUtils.isNotBlank(knowledgeBaseString)) {
|
||||
knowledgeBaseString = "<context>" + knowledgeBaseString + "</context>";
|
||||
}else {
|
||||
} else {
|
||||
knowledgeBaseString = "<context>" + "</context>";
|
||||
}
|
||||
|
||||
@ -470,28 +467,38 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
? PROMPT
|
||||
: chatReqVO.getSystemPrompt() + " \n " + PROMPT;
|
||||
}
|
||||
String mess = systemPrompt + " \n "+knowledgeBaseString;
|
||||
String mess = systemPrompt + " \n " + knowledgeBaseString;
|
||||
|
||||
if (chatReqVO.getKnowledge() != null) {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
} else {
|
||||
// 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
|
||||
log.info("存在聊天历史记录,处理历史记录消息");
|
||||
for (String messageHistory : messageHistoryList) {
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
|
||||
if ("system".equals(modelCompletionsMessage.getRole())) {
|
||||
modelCompletionsMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
|
||||
}
|
||||
messages.add(modelCompletionsMessage);
|
||||
}
|
||||
} else {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
}
|
||||
|
||||
// // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
|
||||
// log.info("存在聊天历史记录,处理历史记录消息");
|
||||
// for (String messageHistory : messageHistoryList) {
|
||||
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
|
||||
// if ("system".equals(modelCompletionsMessage.getRole())) {
|
||||
// modelCompletionsMessage.setContent(mess);
|
||||
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
|
||||
// }
|
||||
// messages.add(modelCompletionsMessage);
|
||||
// }
|
||||
// } else {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
// }
|
||||
|
||||
// 创建用户消息
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
@ -509,7 +516,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
// log.info("构建模型补全请求对象,请求参数1: {}", modelCompletionsReqVO);
|
||||
|
||||
// 调用模型服务进行流式处理
|
||||
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(),knowledgeBaseString);
|
||||
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, chatReqVO.getUuid(), chatReqVO.getGroupId(), knowledgeBaseString);
|
||||
if (modelCompletionsRespVO == null) {
|
||||
throw exception(MODEL_COMPLETIONS_ERROR);
|
||||
}
|
||||
@ -533,7 +540,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private StringBuilder getKnowledgeBase (ChatReqVO chatReqVO) {
|
||||
private StringBuilder getKnowledgeBase(ChatReqVO chatReqVO) {
|
||||
final String LOG_PREFIX = "[KnowledgeBase]";
|
||||
StringBuilder knowledgeBase = new StringBuilder();
|
||||
|
||||
@ -557,12 +564,12 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
paragraphHitRateListVO.setUuid(chatReqVO.getUuid());
|
||||
paragraphHitRateListVO.setGroupId(chatReqVO.getGroupId());
|
||||
|
||||
KnowledgeHitRateTestReqVO testReqVO=new KnowledgeHitRateTestReqVO();
|
||||
KnowledgeHitRateTestReqVO testReqVO = new KnowledgeHitRateTestReqVO();
|
||||
testReqVO.setKnowledgeId(chatReqVO.getKnowledge());
|
||||
testReqVO.setQuery(chatReqVO.getPrompt());
|
||||
|
||||
List<KnowledgeHitRateTestResultVO> result = knowledgeBaseService.executeHitRateTest(testReqVO);
|
||||
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
|
||||
knowledgeBase = handlerResult(result, paragraphHitRateListVO);
|
||||
|
||||
|
||||
log.info("{} 知识库构建完成,内容长度: {}", LOG_PREFIX, knowledgeBase.length());
|
||||
@ -575,8 +582,8 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
return knowledgeBase;
|
||||
}
|
||||
|
||||
private StringBuilder handlerResult (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)){
|
||||
private StringBuilder handlerResult(List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)) {
|
||||
return new StringBuilder();
|
||||
}
|
||||
|
||||
@ -591,8 +598,8 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
return knowledgeBase;
|
||||
}
|
||||
|
||||
private void saveRedis (List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)){
|
||||
private void saveRedis(List<KnowledgeHitRateTestResultVO> result, ParagraphHitRateListVO paragraphHitRateListVO) {
|
||||
if (CollectionUtils.isEmpty(result)) {
|
||||
return;
|
||||
}
|
||||
List<ParagraphHitRateWordVO> words = new ArrayList<>();
|
||||
@ -605,13 +612,13 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
List<ParagraphHitRateWordVO> finalWords = words;
|
||||
groupedByFileId.forEach((fileId, list) -> {
|
||||
System.out.println("File ID: " + fileId);
|
||||
list.forEach(i->{
|
||||
list.forEach(i -> {
|
||||
ParagraphHitRateWordVO rateWordVO = new ParagraphHitRateWordVO();
|
||||
// 设置文档名称
|
||||
rateWordVO.setDocumentName(i.getFileName());
|
||||
|
||||
// 设置段落命中率
|
||||
List<ParagraphHitRateVO> paragraphHitRate=new ArrayList<>();
|
||||
List<ParagraphHitRateVO> paragraphHitRate = new ArrayList<>();
|
||||
for (KnowledgeHitRateTestResultVO i1 : list) {
|
||||
ParagraphHitRateVO rateVO = new ParagraphHitRateVO();
|
||||
rateVO.setParagraph(i1.getPageContent());
|
||||
@ -649,9 +656,9 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
/**
|
||||
* 处理单个知识库文档的检索逻辑
|
||||
*/
|
||||
private ParagraphHitRateWordVO processDocument (KnowledgeDocumentsDO document,
|
||||
ChatReqVO chatReqVO,
|
||||
StringBuilder knowledgeBase) {
|
||||
private ParagraphHitRateWordVO processDocument(KnowledgeDocumentsDO document,
|
||||
ChatReqVO chatReqVO,
|
||||
StringBuilder knowledgeBase) {
|
||||
ParagraphHitRateWordVO paragraphHitRateListVO = new ParagraphHitRateWordVO();
|
||||
try {
|
||||
log.info("{} 处理文档: {}[ID:{}]", "[KnowledgeBase]", document.getDocumentName(), document.getId());
|
||||
@ -688,9 +695,9 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
/**
|
||||
* 解析向量检索响应结果
|
||||
*/
|
||||
private List<ParagraphHitRateVO> parseEmbeddingResponse (String response,
|
||||
StringBuilder knowledgeBase,
|
||||
String uuid) {
|
||||
private List<ParagraphHitRateVO> parseEmbeddingResponse(String response,
|
||||
StringBuilder knowledgeBase,
|
||||
String uuid) {
|
||||
if (StringUtils.isBlank(response)) {
|
||||
log.warn("{} 收到空响应", "[KnowledgeBase]");
|
||||
return Collections.emptyList();
|
||||
@ -756,7 +763,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param chatReqVO
|
||||
* @return
|
||||
*/
|
||||
private ChatRespVO privateModelChat (ChatReqVO chatReqVO) {
|
||||
private ChatRespVO privateModelChat(ChatReqVO chatReqVO) {
|
||||
if (StringUtils.isBlank(chatReqVO.getUuid())) {
|
||||
// 如果没有uuid,就生成一个
|
||||
chatReqVO.setUuid(UUID.randomUUID().toString());
|
||||
|
Loading…
x
Reference in New Issue
Block a user