diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java index f019cead0..52c6033cc 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java @@ -116,11 +116,11 @@ public class ConversationController { * @return SseEmitter 对象,用于流式发送响应 */ @PostMapping("/stream-chat") - public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) { + public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO,HttpServletResponse response) { log.info("收到对话推理请求,请求参数: {}", chatReqVO); SseEmitter emitter = new SseEmitter(); try { - conversationService.chatStream(chatReqVO, emitter); + conversationService.chatStream(chatReqVO, emitter,response); } catch (Exception e) { log.error("处理对话推理请求时发生异常", e); try { diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java index b567ea182..c7f211bf5 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java @@ -28,7 +28,7 @@ public class ChatReqVO { @Schema(description = "知识库id") private Long knowledge; @Schema(description = "单次回复限制max_tokens") - private String maxTokens; + private Integer maxTokens; @Schema(description = "随机性temperature") - private Long temperature; + private Double temperature; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java index bd1df6e13..2ec9c537f 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java @@ -33,7 +33,7 @@ public class DataRefluxDataSaveReqVO { private String response; @Schema(description = "单次回复限制max_tokens") - private String maxTokens; + private Integer maxTokens; @Schema(description = "随机性temperature") - private Long temperature; + private Double temperature; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java index 9e1388ee0..9a1397859 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.llm.service.conversation; import java.util.*; +import javax.servlet.http.HttpServletResponse; import javax.validation.*; import cn.hutool.json.JSONArray; @@ -72,5 +73,5 @@ public interface ConversationService { * @param chatReqVO chatReqVO * @param emitter emitter */ - void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter); + void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response); } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 9ad223380..5145c1185 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -43,6 +43,7 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.*; @@ -284,7 +285,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO 对话请求对象 * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) { log.info("开始处理对话请求,请求参数: {}", chatReqVO); // 检查系统提示信息,如果为空则尝试从应用中获取 if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) { @@ -303,7 +304,7 @@ public class ConversationServiceImpl implements ConversationService { } } // 调用公共模型聊天流式处理方法 - publicModelChatStream(chatReqVO, emitter); + publicModelChatStream(chatReqVO, emitter, response); } /** @@ -311,7 +312,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO 对话请求对象 * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter,HttpServletResponse response) { log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO); // 检查 UUID 是否为空,若为空则生成一个 if (StrUtil.isBlank(chatReqVO.getUuid())) { @@ -420,7 +421,7 @@ public class ConversationServiceImpl implements ConversationService { log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO); // 调用模型服务进行流式处理 - modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter); + modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, response); // 将用户消息存入缓存 stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message)); diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java index 1eff7fe69..11cc5cc27 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java @@ -1,6 +1,5 @@ package cn.iocoder.yudao.module.llm.service.http; -import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; @@ -12,15 +11,21 @@ import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.HttpClients; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; -import java.io.BufferedReader; -import java.io.InputStream; -import java.io.InputStreamReader; +import javax.servlet.http.HttpServletResponse; +import java.io.*; import java.net.HttpURLConnection; import java.net.URL; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -133,93 +138,73 @@ public class ModelService { throw new RuntimeException("模型补全请求处理失败", e); } } + /** * 模型补全流式处理方法 + * * @param url 模型服务的 URL * @param req 模型补全请求对象 - * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) { - log.info("开始处理模型补全请求,请求参数: {}", req); - // 检查模型是否为空,若为空则设置默认模型 - if (StrUtil.isBlank(req.getModel())) { - log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID); - req.setModel(DEFAULT_MODEL_ID); - } - // 记录请求信息 - log.info("请求参数: {}", JSON.toJSONString(req)); + public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, HttpServletResponse response) { + req.setStream(true); + log.info("开始处理模型补全请求,参数: {}", req); try { - // 发起 HTTP POST 请求 - URL targetUrl; - if (StrUtil.isBlank(url)) { - log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions()); - targetUrl = new URL(llmBackendProperties.getModelCompletions()); - } else { - log.info("使用指定 URL: {}", url); - targetUrl = new URL(url); - } - HttpURLConnection connection = (HttpURLConnection) targetUrl.openConnection(); - connection.setRequestMethod("POST"); - connection.setRequestProperty("Content-Type", "application/json"); - connection.setDoOutput(true); - connection.getOutputStream().write(JSON.toJSONString(req).getBytes()); - - BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); - String line; - StringBuilder responseBuilder = new StringBuilder(); - while ((line = reader.readLine()) != null) { - // 解析流式响应内容 - if (StrUtil.isNotBlank(line)) { - log.info("收到流式响应内容: {}", line); - ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class); - if (StrUtil.isBlank(chatCompletion.getDetail())) { - String respContent = chatCompletion.getChoices().get(0).getMessage().getContent(); - String patternString = "(.*?)"; - Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL); - Matcher matcher = pattern.matcher(respContent); - String answerContent = matcher.replaceAll(""); - // 流式发送数据 - try { - emitter.send(SseEmitter.event().data(answerContent)); - log.info("已发送流式响应数据: {}", answerContent); - } catch (Exception e) { - log.error("发送流式响应数据时发生异常", e); - try { - emitter.completeWithError(e); - } catch (Exception ex) { - log.error("无法完成 SseEmitter 错误处理", ex); - } - return; - } - } - } - } - // 完成流式传输 - try { - emitter.complete(); - log.info("流式传输完成"); - } catch (Exception e) { - log.error("完成流式传输时发生异常", e); - } + log.info("开始处理模型补全请求,参数: {}", JSON.toJSONString(req)); + sendPostRequest(url, "{\"max_tokens\":4000,\"messages\":[{\"content\":\"\",\"role\":\"system\"},{\"content\":\"介绍一下天津\",\"role\":\"user\"}],\"model\":\"Qwen2.5-0.5B-Instruct\",\"temperature\":0.7,\"stream\":true}", emitter); } catch (Exception e) { - log.error("处理模型补全请求时发生异常", e); - try { - emitter.completeWithError(e); - } catch (Exception ex) { - log.error("无法完成 SseEmitter 错误处理", ex); - } + emitter.completeWithError(e); } } /** - * 通过 GPU 类型获取对应的主机地址 - *

ModelService GpuType 是 ServerName 表ID + * 解析流式返回的单行数据,提取有效内容并清理特定标记 * - * @param gpuType gpuType - * @return host + * @param line 流式响应中的单行JSON数据 + * @return 处理后的文本内容(若无有效内容返回null) */ - public String getHostByType (Long gpuType) { - return serverNameService.getServerName(gpuType).getHost(); + private String parseStreamLine (String line) { + if (StringUtils.isNotBlank(line)) { + if (line.startsWith("data: ")) { + String dataString = extractJsonFromDataString(line); + if (!dataString.contains("[DONE]")) { + JSONObject jsonObject = JSON.parseObject(dataString); + // 获取 choices 数组 + JSONArray choicesArray = jsonObject.getJSONArray("choices"); + if (choicesArray != null && !choicesArray.isEmpty()) { + // 获取第一个 choice 对象 + JSONObject firstChoice = choicesArray.getJSONObject(0); + // 获取 delta 对象 + JSONObject delta = firstChoice.getJSONObject("delta"); + if (delta != null) { + // 获取 content 的值 + String content = delta.getString("content"); + return "{\"content\":\"" + content + "\",\"finish_reason\":\"" + false + "\"}"; + } + } + } else { + return "{\"content\":\"" + "\",\"finish_reason\":\"" + true + "\"}"; + } + } + } + + return null; + } + + /** + * 从包含 "data: " 前缀的字符串中提取 JSON 字符串 + * + * @param input 包含 "data: " 前缀的原始字符串 + * @return 提取出的 JSON 字符串,如果未找到 "data: " 则返回 null + */ + public static String extractJsonFromDataString (String input) { + if (input == null) { + return null; + } + int index = input.indexOf("data: "); + if (index != -1) { + return input.substring(index + "data: ".length()).trim(); + } + return null; } /** @@ -335,4 +320,70 @@ public class ModelService { } return null; } + + /** + * 发送 POST 请求并处理响应 + * + * @param apiUrl 目标 API 的 URL + * @param requestBody 请求体内容 + * @throws IOException 发送请求或处理响应时可能抛出的 IO 异常 + */ + private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException { + // 创建 HttpClient 实例 + HttpClient httpClient = HttpClients.createDefault(); + // 创建 HttpPost 请求对象 + HttpPost httpPost = new HttpPost(apiUrl); + + // 设置请求体和请求头 + setupRequest(httpPost, requestBody); + + // 执行 POST 请求并获取响应 + HttpResponse response = httpClient.execute(httpPost); + + // 处理响应实体 + handleResponseEntity(response, emitter); + } + + /** + * 设置请求体和请求头 + * + * @param httpPost HttpPost 请求对象 + * @param requestBody 请求体内容 + * @throws IOException 创建 StringEntity 时可能抛出的 IO 异常 + */ + private void setupRequest (HttpPost httpPost, String requestBody) throws IOException { + // 创建 StringEntity 对象,用于封装请求体 + StringEntity entity = new StringEntity(requestBody); + // 设置请求体 + httpPost.setEntity(entity); + // 设置请求头,指定请求体的内容类型为 JSON + httpPost.setHeader("Content-Type", "application/json"); + } + + /** + * 处理响应实体 + * + * @param response HttpResponse 对象 + * @throws IOException 读取响应实体输入流时可能抛出的 IO 异常 + */ + private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException { + // 获取响应实体 + HttpEntity responseEntity = response.getEntity(); + if (responseEntity != null) { + // 使用 try-with-resources 语句自动关闭输入流和 BufferedReader + try (InputStream inputStream = responseEntity.getContent(); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + // 逐行读取响应内容并处理 + while ((line = reader.readLine()) != null) { + System.out.println("接收到的响应行数据: " + line); + String content = parseStreamLine(line); + if (content != null) { + emitter.send(SseEmitter.event().data(content)); + } + } + } + } + } + } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java index a4d212b70..f66e5ee46 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java @@ -17,6 +17,9 @@ public class ModelCompletionsReqVO { private List messages; private Integer max_tokens = 4000; private Double temperature = 0.7; +// private Integer max_tokens; +// private Double temperature; + private Boolean stream; // private Integer max_length = 120000; @Data