feat(module-llm):增加 topP 参数并调整相关逻辑
- 在 ChatReqVO 中添加 topP 字段 - 在 ConversationServiceImpl 中处理 topP 参数- 更新 ModelCompletionsReqVO,将 top_p 设为可配置项 - 调整 ModelService 中的默认参数设置
This commit is contained in:
parent
5bbc4931d9
commit
0e308a7f13
@ -31,6 +31,8 @@ public class ChatReqVO {
|
||||
private Integer maxTokens;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Double temperature;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Double topP;
|
||||
@Schema(description = "分组id")
|
||||
private String groupId;
|
||||
}
|
||||
|
@ -36,7 +36,6 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
@ -321,7 +320,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public ParagraphHitRateListVO getParagraphHitRate(String uuid, String groupId) {
|
||||
public ParagraphHitRateListVO getParagraphHitRate (String uuid, String groupId) {
|
||||
String redisKey = String.format("%s:%s", KNOWLEDGE_DOCUMENTS_REDIS_KEY, uuid);
|
||||
List<String> redisResults = stringRedisTemplate.opsForList().range(redisKey, 0, -1);
|
||||
log.info("[Redis Query] Key: {} | Results count: {}", redisKey, redisResults != null ? redisResults.size() : 0);
|
||||
@ -352,7 +351,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
/**
|
||||
* 验证命中率数据有效性
|
||||
*/
|
||||
private boolean isValidHitRateData(JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
|
||||
private boolean isValidHitRateData (JSONObject jsonObj, String expectedUuid, String expectedGroupId) {
|
||||
if (!jsonObj.containsKey("uuid") || !jsonObj.containsKey("groupId")) {
|
||||
log.warn("[Validation] Missing required fields. Existing keys: {}", jsonObj.keySet());
|
||||
return false;
|
||||
@ -365,7 +364,8 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
&& Objects.equals(expectedGroupId, actualGroupId);
|
||||
}
|
||||
|
||||
public static final String PROMPT="Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
|
||||
public static final String PROMPT = "Use the following context as your learned knowledge, inside <context></context> XML tags.When answer to user:\\n- If you don't know, just say that you don't know.\\n- If you don't know when you are not sure, ask for clarification.\\nAvoid mentioning that you obtained the information from the context.\\nAnd answer according to the language of the user's question.\\n\\n";
|
||||
|
||||
/**
|
||||
* 公共模型聊天流式处理方法
|
||||
*
|
||||
@ -430,8 +430,8 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
|
||||
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
|
||||
String systemPrompt="";
|
||||
String knowledgeBaseString="";
|
||||
String systemPrompt = "";
|
||||
String knowledgeBaseString = "";
|
||||
if (chatReqVO.getKnowledge() != null) {
|
||||
StringBuilder knowledgeBase = getKnowledgeBase(chatReqVO);
|
||||
knowledgeBaseString = knowledgeBase.toString();
|
||||
@ -448,26 +448,26 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
String mess = systemPrompt + knowledgeBaseString;
|
||||
|
||||
// // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
|
||||
// log.info("存在聊天历史记录,处理历史记录消息");
|
||||
// for (String messageHistory : messageHistoryList) {
|
||||
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
|
||||
// if ("system".equals(modelCompletionsMessage.getRole())) {
|
||||
// modelCompletionsMessage.setContent(mess);
|
||||
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
|
||||
// }
|
||||
// messages.add(modelCompletionsMessage);
|
||||
// }
|
||||
// } else {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
// }
|
||||
// // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
// List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
// if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
|
||||
// log.info("存在聊天历史记录,处理历史记录消息");
|
||||
// for (String messageHistory : messageHistoryList) {
|
||||
// ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
|
||||
// if ("system".equals(modelCompletionsMessage.getRole())) {
|
||||
// modelCompletionsMessage.setContent(mess);
|
||||
// stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
|
||||
// }
|
||||
// messages.add(modelCompletionsMessage);
|
||||
// }
|
||||
// } else {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
// }
|
||||
|
||||
// 创建用户消息
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
@ -479,6 +479,9 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
|
||||
modelCompletionsReqVO.setMessages(messages);
|
||||
modelCompletionsReqVO.setModel(model);
|
||||
modelCompletionsReqVO.setTemperature(chatReqVO.getTemperature());
|
||||
modelCompletionsReqVO.setTop_p(chatReqVO.getTopP());
|
||||
modelCompletionsReqVO.setMax_tokens(chatReqVO.getMaxTokens());
|
||||
// log.info("构建模型补全请求对象,请求参数1: {}", modelCompletionsReqVO);
|
||||
|
||||
// 调用模型服务进行流式处理
|
||||
@ -533,15 +536,15 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
// 2. 遍历处理每个文档
|
||||
for (KnowledgeDocumentsDO document : documentList) {
|
||||
ParagraphHitRateWordVO rateWordVO = processDocument(document, chatReqVO, knowledgeBase);
|
||||
if (rateWordVO!=null){
|
||||
if (rateWordVO != null) {
|
||||
words.add(rateWordVO);
|
||||
}
|
||||
|
||||
}
|
||||
if (CollectionUtils.isEmpty(words)){
|
||||
if (CollectionUtils.isEmpty(words)) {
|
||||
paragraphHitRateListVO.setWordList(Collections.emptyList());
|
||||
paragraphHitRateListVO.setGroupId("");
|
||||
}else {
|
||||
} else {
|
||||
paragraphHitRateListVO.setWordList(words);
|
||||
}
|
||||
|
||||
@ -592,7 +595,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
|
||||
paragraphHitRateListVO.setDocumentName(document.getDocumentName());
|
||||
List<ParagraphHitRateVO> rateVOS = parseEmbeddingResponse(response, knowledgeBase, chatReqVO.getUuid());
|
||||
if (CollectionUtils.isEmpty(rateVOS)){
|
||||
if (CollectionUtils.isEmpty(rateVOS)) {
|
||||
return null;
|
||||
}
|
||||
paragraphHitRateListVO.setParagraphHitRate(rateVOS);
|
||||
@ -644,12 +647,12 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
String rateResult = df.format(rate);
|
||||
log.info("{} 命中率: {}", "[KnowledgeBase]", rateResult);
|
||||
|
||||
if (StringUtils.isBlank(rateResult)||rate<=0.0){
|
||||
if (StringUtils.isBlank(rateResult) || rate <= 0.0) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
if (StringUtils.isNotBlank(pageContent)) {
|
||||
// knowledgeBase.append("\n[知识库内容] [内容如下]").append(pageContent);
|
||||
// knowledgeBase.append("\n[知识库内容] [内容如下]").append(pageContent);
|
||||
knowledgeBase.append(pageContent);
|
||||
log.info("{} 添加知识内容,长度: {}", "[KnowledgeBase]", pageContent.length());
|
||||
|
||||
|
@ -149,9 +149,6 @@ public class ModelService {
|
||||
*/
|
||||
public ModelCompletionsRespVO modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, String uuid, String groupId) {
|
||||
req.setStream(true);
|
||||
req.setTemperature(0.2);
|
||||
req.setTop_p(0.9);
|
||||
req.setMax_tokens(4096);
|
||||
log.info("开始处理模型补全请求");
|
||||
|
||||
// 检查模型是否为空,若为空则设置默认模型
|
||||
|
@ -15,9 +15,9 @@ public class ModelCompletionsReqVO {
|
||||
|
||||
private String model;
|
||||
private List<ModelCompletionsMessage> messages;
|
||||
private Integer max_tokens = 4096;
|
||||
private Double temperature = 0.2;
|
||||
private Double top_p = 0.9;
|
||||
private Integer max_tokens;
|
||||
private Double temperature;
|
||||
private Double top_p;
|
||||
private Boolean stream;
|
||||
|
||||
@Data
|
||||
|
Loading…
x
Reference in New Issue
Block a user