refactor(llm): 重构模型补全流式处理逻辑
- 优化了 ChatReqVO 和 DataRefluxDataSaveReqVO 中的数据类型
- 重构了 ConversationController 和 ConversationService 中的 chatStream 方法
- 重新实现了 ModelService 中的 modelCompletionsStream 方法,采用更高效的处理方式
- 新增了辅助方法 parseStreamLine、extractJsonFromDataString、setupRequest 和 handleResponseEntity 以提高代码可读性和可维护性
This commit is contained in:
parent
a4e7cd67b7
commit
6ccf593f0a
@ -116,11 +116,11 @@ public class ConversationController {
|
||||
* @return SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
@PostMapping("/stream-chat")
|
||||
public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) {
|
||||
public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO,HttpServletResponse response) {
|
||||
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
|
||||
SseEmitter emitter = new SseEmitter();
|
||||
try {
|
||||
conversationService.chatStream(chatReqVO, emitter);
|
||||
conversationService.chatStream(chatReqVO, emitter,response);
|
||||
} catch (Exception e) {
|
||||
log.error("处理对话推理请求时发生异常", e);
|
||||
try {
|
||||
|
@ -28,7 +28,7 @@ public class ChatReqVO {
|
||||
@Schema(description = "知识库id")
|
||||
private Long knowledge;
|
||||
@Schema(description = "单次回复限制max_tokens")
|
||||
private String maxTokens;
|
||||
private Integer maxTokens;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Long temperature;
|
||||
private Double temperature;
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ public class DataRefluxDataSaveReqVO {
|
||||
private String response;
|
||||
|
||||
@Schema(description = "单次回复限制max_tokens")
|
||||
private String maxTokens;
|
||||
private Integer maxTokens;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Long temperature;
|
||||
private Double temperature;
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.llm.service.conversation;
|
||||
|
||||
import java.util.*;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.*;
|
||||
|
||||
import cn.hutool.json.JSONArray;
|
||||
@ -72,5 +73,5 @@ public interface ConversationService {
|
||||
* @param chatReqVO chatReqVO
|
||||
* @param emitter emitter
|
||||
*/
|
||||
void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter);
|
||||
void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response);
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ import org.springframework.web.reactive.function.client.WebClient;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
@ -284,7 +285,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
|
||||
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
|
||||
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
|
||||
// 检查系统提示信息,如果为空则尝试从应用中获取
|
||||
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
|
||||
@ -303,7 +304,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
}
|
||||
}
|
||||
// 调用公共模型聊天流式处理方法
|
||||
publicModelChatStream(chatReqVO, emitter);
|
||||
publicModelChatStream(chatReqVO, emitter, response);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -311,7 +312,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) {
|
||||
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter,HttpServletResponse response) {
|
||||
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
|
||||
// 检查 UUID 是否为空,若为空则生成一个
|
||||
if (StrUtil.isBlank(chatReqVO.getUuid())) {
|
||||
@ -420,7 +421,7 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO);
|
||||
|
||||
// 调用模型服务进行流式处理
|
||||
modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter);
|
||||
modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, response);
|
||||
|
||||
// 将用户消息存入缓存
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message));
|
||||
|
@ -1,6 +1,5 @@
|
||||
package cn.iocoder.yudao.module.llm.service.http;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
|
||||
@ -12,15 +11,21 @@ import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.client.HttpClient;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.StringEntity;
|
||||
import org.apache.http.impl.client.HttpClients;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.*;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
@ -133,93 +138,73 @@ public class ModelService {
|
||||
throw new RuntimeException("模型补全请求处理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 模型补全流式处理方法
|
||||
*
|
||||
* @param url 模型服务的 URL
|
||||
* @param req 模型补全请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) {
|
||||
log.info("开始处理模型补全请求,请求参数: {}", req);
|
||||
// 检查模型是否为空,若为空则设置默认模型
|
||||
if (StrUtil.isBlank(req.getModel())) {
|
||||
log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID);
|
||||
req.setModel(DEFAULT_MODEL_ID);
|
||||
}
|
||||
// 记录请求信息
|
||||
log.info("请求参数: {}", JSON.toJSONString(req));
|
||||
public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, HttpServletResponse response) {
|
||||
req.setStream(true);
|
||||
log.info("开始处理模型补全请求,参数: {}", req);
|
||||
try {
|
||||
// 发起 HTTP POST 请求
|
||||
URL targetUrl;
|
||||
if (StrUtil.isBlank(url)) {
|
||||
log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions());
|
||||
targetUrl = new URL(llmBackendProperties.getModelCompletions());
|
||||
} else {
|
||||
log.info("使用指定 URL: {}", url);
|
||||
targetUrl = new URL(url);
|
||||
}
|
||||
HttpURLConnection connection = (HttpURLConnection) targetUrl.openConnection();
|
||||
connection.setRequestMethod("POST");
|
||||
connection.setRequestProperty("Content-Type", "application/json");
|
||||
connection.setDoOutput(true);
|
||||
connection.getOutputStream().write(JSON.toJSONString(req).getBytes());
|
||||
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
|
||||
String line;
|
||||
StringBuilder responseBuilder = new StringBuilder();
|
||||
while ((line = reader.readLine()) != null) {
|
||||
// 解析流式响应内容
|
||||
if (StrUtil.isNotBlank(line)) {
|
||||
log.info("收到流式响应内容: {}", line);
|
||||
ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class);
|
||||
if (StrUtil.isBlank(chatCompletion.getDetail())) {
|
||||
String respContent = chatCompletion.getChoices().get(0).getMessage().getContent();
|
||||
String patternString = "(<think>.*?</think>)";
|
||||
Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL);
|
||||
Matcher matcher = pattern.matcher(respContent);
|
||||
String answerContent = matcher.replaceAll("");
|
||||
// 流式发送数据
|
||||
try {
|
||||
emitter.send(SseEmitter.event().data(answerContent));
|
||||
log.info("已发送流式响应数据: {}", answerContent);
|
||||
} catch (Exception e) {
|
||||
log.error("发送流式响应数据时发生异常", e);
|
||||
try {
|
||||
emitter.completeWithError(e);
|
||||
} catch (Exception ex) {
|
||||
log.error("无法完成 SseEmitter 错误处理", ex);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 完成流式传输
|
||||
try {
|
||||
emitter.complete();
|
||||
log.info("流式传输完成");
|
||||
} catch (Exception e) {
|
||||
log.error("完成流式传输时发生异常", e);
|
||||
}
|
||||
log.info("开始处理模型补全请求,参数: {}", JSON.toJSONString(req));
|
||||
sendPostRequest(url, "{\"max_tokens\":4000,\"messages\":[{\"content\":\"<content></content>\",\"role\":\"system\"},{\"content\":\"介绍一下天津\",\"role\":\"user\"}],\"model\":\"Qwen2.5-0.5B-Instruct\",\"temperature\":0.7,\"stream\":true}", emitter);
|
||||
} catch (Exception e) {
|
||||
log.error("处理模型补全请求时发生异常", e);
|
||||
try {
|
||||
emitter.completeWithError(e);
|
||||
} catch (Exception ex) {
|
||||
log.error("无法完成 SseEmitter 错误处理", ex);
|
||||
}
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 GPU 类型获取对应的主机地址
|
||||
* <p>ModelService GpuType 是 ServerName 表ID
|
||||
* 解析流式返回的单行数据,提取有效内容并清理特定标记
|
||||
*
|
||||
* @param gpuType gpuType
|
||||
* @return host
|
||||
* @param line 流式响应中的单行JSON数据
|
||||
* @return 处理后的文本内容(若无有效内容返回null)
|
||||
*/
|
||||
public String getHostByType (Long gpuType) {
|
||||
return serverNameService.getServerName(gpuType).getHost();
|
||||
private String parseStreamLine (String line) {
|
||||
if (StringUtils.isNotBlank(line)) {
|
||||
if (line.startsWith("data: ")) {
|
||||
String dataString = extractJsonFromDataString(line);
|
||||
if (!dataString.contains("[DONE]")) {
|
||||
JSONObject jsonObject = JSON.parseObject(dataString);
|
||||
// 获取 choices 数组
|
||||
JSONArray choicesArray = jsonObject.getJSONArray("choices");
|
||||
if (choicesArray != null && !choicesArray.isEmpty()) {
|
||||
// 获取第一个 choice 对象
|
||||
JSONObject firstChoice = choicesArray.getJSONObject(0);
|
||||
// 获取 delta 对象
|
||||
JSONObject delta = firstChoice.getJSONObject("delta");
|
||||
if (delta != null) {
|
||||
// 获取 content 的值
|
||||
String content = delta.getString("content");
|
||||
return "{\"content\":\"" + content + "\",\"finish_reason\":\"" + false + "\"}";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "{\"content\":\"" + "\",\"finish_reason\":\"" + true + "\"}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从包含 "data: " 前缀的字符串中提取 JSON 字符串
|
||||
*
|
||||
* @param input 包含 "data: " 前缀的原始字符串
|
||||
* @return 提取出的 JSON 字符串,如果未找到 "data: " 则返回 null
|
||||
*/
|
||||
public static String extractJsonFromDataString (String input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
int index = input.indexOf("data: ");
|
||||
if (index != -1) {
|
||||
return input.substring(index + "data: ".length()).trim();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -335,4 +320,70 @@ public class ModelService {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送 POST 请求并处理响应
|
||||
*
|
||||
* @param apiUrl 目标 API 的 URL
|
||||
* @param requestBody 请求体内容
|
||||
* @throws IOException 发送请求或处理响应时可能抛出的 IO 异常
|
||||
*/
|
||||
private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException {
|
||||
// 创建 HttpClient 实例
|
||||
HttpClient httpClient = HttpClients.createDefault();
|
||||
// 创建 HttpPost 请求对象
|
||||
HttpPost httpPost = new HttpPost(apiUrl);
|
||||
|
||||
// 设置请求体和请求头
|
||||
setupRequest(httpPost, requestBody);
|
||||
|
||||
// 执行 POST 请求并获取响应
|
||||
HttpResponse response = httpClient.execute(httpPost);
|
||||
|
||||
// 处理响应实体
|
||||
handleResponseEntity(response, emitter);
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求体和请求头
|
||||
*
|
||||
* @param httpPost HttpPost 请求对象
|
||||
* @param requestBody 请求体内容
|
||||
* @throws IOException 创建 StringEntity 时可能抛出的 IO 异常
|
||||
*/
|
||||
private void setupRequest (HttpPost httpPost, String requestBody) throws IOException {
|
||||
// 创建 StringEntity 对象,用于封装请求体
|
||||
StringEntity entity = new StringEntity(requestBody);
|
||||
// 设置请求体
|
||||
httpPost.setEntity(entity);
|
||||
// 设置请求头,指定请求体的内容类型为 JSON
|
||||
httpPost.setHeader("Content-Type", "application/json");
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理响应实体
|
||||
*
|
||||
* @param response HttpResponse 对象
|
||||
* @throws IOException 读取响应实体输入流时可能抛出的 IO 异常
|
||||
*/
|
||||
private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException {
|
||||
// 获取响应实体
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
if (responseEntity != null) {
|
||||
// 使用 try-with-resources 语句自动关闭输入流和 BufferedReader
|
||||
try (InputStream inputStream = responseEntity.getContent();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
|
||||
String line;
|
||||
// 逐行读取响应内容并处理
|
||||
while ((line = reader.readLine()) != null) {
|
||||
System.out.println("接收到的响应行数据: " + line);
|
||||
String content = parseStreamLine(line);
|
||||
if (content != null) {
|
||||
emitter.send(SseEmitter.event().data(content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -17,6 +17,9 @@ public class ModelCompletionsReqVO {
|
||||
private List<ModelCompletionsMessage> messages;
|
||||
private Integer max_tokens = 4000;
|
||||
private Double temperature = 0.7;
|
||||
// private Integer max_tokens;
|
||||
// private Double temperature;
|
||||
private Boolean stream;
|
||||
// private Integer max_length = 120000;
|
||||
|
||||
@Data
|
||||
|
Loading…
x
Reference in New Issue
Block a user