feat(module-llm):增加 topP 参数并调整相关逻辑

- 在 ChatReqVO 中添加 topP 字段
- 在 ConversationServiceImpl 中处理 topP 参数
- 更新 ModelCompletionsReqVO,将 top_p 设为可配置项
- 调整 ModelService 中的默认参数设置
This commit is contained in:
Liuyang 2025-03-12 19:51:36 +08:00
parent 5bbc4931d9
commit 0e308a7f13
4 changed files with 40 additions and 38 deletions

View File

@ -31,6 +31,8 @@ public class ChatReqVO {
private Integer maxTokens;
@Schema(description = "随机性temperature")
private Double temperature;
/**
 * Top-p sampling parameter supplied by the caller.
 * ConversationServiceImpl copies it into the model request's {@code top_p} field.
 * May be {@code null} when the client does not send it.
 */
@Schema(description = "核采样topP")
private Double topP;
@Schema(description = "分组id")
private String groupId;
}

View File

@ -36,7 +36,6 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -321,7 +320,7 @@ public class ConversationServiceImpl implements ConversationService {
* @return
*/
@Override
public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) {
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
List<String> redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
@ -352,7 +351,7 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 验证命中率数据有效性
*/
private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
return false;
@ -365,7 +364,8 @@ public class ConversationServiceImpl implements ConversationService {
&& Objects.equals(expectedGroupId, actualGroupId);
}
public static final String PROMPT="Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
/**
* 公共模型聊天流式处理方法
*
@ -430,8 +430,8 @@ public class ConversationServiceImpl implements ConversationService {
}
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
String systemPrompt="";
String knowledgeBaseString="";
String systemPrompt = "";
String knowledgeBaseString = "";
if (chatReqVO.getKnowledge() != null) {
StringBuilder knowledgeBase = getKnowledgeBase(chatReqVO);
knowledgeBaseString = knowledgeBase.toString();
@ -448,26 +448,26 @@ public class ConversationServiceImpl implements ConversationService {
}
String mess = systemPrompt + knowledgeBaseString;
// // 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
// log.info("存在聊天历史记录,处理历史记录消息");
// for (String messageHistory : messageHistoryList) {
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
// if ("system".equals(modelCompletionsMessage.getRole())) {
// modelCompletionsMessage.setContent(mess);
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
// }
// messages.add(modelCompletionsMessage);
// }
// } else {
log.info("不存在聊天历史记录,创建新的系统消息");
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
// }
// // 查询历史记录消息并将查询出来的知识信息放入到 role = system 的消息中
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
// log.info("存在聊天历史记录,处理历史记录消息");
// for (String messageHistory : messageHistoryList) {
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
// if ("system".equals(modelCompletionsMessage.getRole())) {
// modelCompletionsMessage.setContent(mess);
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
// }
// messages.add(modelCompletionsMessage);
// }
// } else {
log.info("不存在聊天历史记录,创建新的系统消息");
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
// }
// 创建用户消息
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
@ -479,6 +479,9 @@ public class ConversationServiceImpl implements ConversationService {
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
modelCompletionsReqVO.setMessages(messages);
modelCompletionsReqVO.setModel(model);
modelCompletionsReqVO.setTemperature(chatReqVO.getTemperature());
modelCompletionsReqVO.setTop_p(chatReqVO.getTopP());
modelCompletionsReqVO.setMax_tokens(chatReqVO.getMaxTokens());
// log.info("构建模型补全请求对象请求参数1: {}", modelCompletionsReqVO);
// 调用模型服务进行流式处理
@ -533,15 +536,15 @@ public class ConversationServiceImpl implements ConversationService {
// 2. 遍历处理每个文档
for (KnowledgeDocumentsDO document : documentList) {
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
if (rateWordVO!=null){
if (rateWordVO != null) {
words.add(rateWordVO);
}
}
if (CollectionUtils.isEmpty(words)){
if (CollectionUtils.isEmpty(words)) {
paragraphHitRateListVO.setWordList(Collections.emptyList());
paragraphHitRateListVO.setGroupId("");
}else {
} else {
paragraphHitRateListVO.setWordList(words);
}
@ -592,7 +595,7 @@ public class ConversationServiceImpl implements ConversationService {
paragraphHitRateListVO.setDocumentName(document.getDocumentName());
List<ParagraphHitRateVO> rateVOS = parseEmbeddingResponse(response, knowledgeBase, chatReqVO.getUuid());
if (CollectionUtils.isEmpty(rateVOS)){
if (CollectionUtils.isEmpty(rateVOS)) {
return null;
}
paragraphHitRateListVO.setParagraphHitRate(rateVOS);
@ -644,12 +647,12 @@ public class ConversationServiceImpl implements ConversationService {
String rateResult = df.format(rate);
log.info("{} 命中率: {}", "[KnowledgeBase]", rateResult);
if (StringUtils.isBlank(rateResult)||rate<=0.0){
if (StringUtils.isBlank(rateResult) || rate <= 0.0) {
return Collections.emptyList();
}
if (StringUtils.isNotBlank(pageContent)) {
// knowledgeBase.append("\n[知识库内容] [内容如下]").append(pageContent);
// knowledgeBase.append("\n[知识库内容] [内容如下]").append(pageContent);
knowledgeBase.append(pageContent);
log.info("{} 添加知识内容,长度: {}", "[KnowledgeBase]", pageContent.length());

View File

@ -149,9 +149,6 @@ public class ModelService {
*/
public ModelCompletionsRespVO modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, String uuid, String groupId) {
req.setStream(true);
req.setTemperature(0.2);
req.setTop_p(0.9);
req.setMax_tokens(4096);
log.info("开始处理模型补全请求");
// 检查模型是否为空若为空则设置默认模型

View File

@ -15,9 +15,9 @@ public class ModelCompletionsReqVO {
private String model;
private List<ModelCompletionsMessage> messages;
private Integer max_tokens = 4096;
private Double temperature = 0.2;
private Double top_p = 0.9;
private Integer max_tokens;
private Double temperature;
private Double top_p;
private Boolean stream;
@Data