From 0823596f97ab37014a0305f2f70a878911cdbbae Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 13:17:39 +0800 Subject: [PATCH 01/12] =?UTF-8?q?refactor(module-llm):=20-=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E5=BE=AE=E8=B0=83=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E6=9C=8D=E5=8A=A1=20-=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=BA=86=E5=BE=AE=E8=B0=83=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E5=AE=9A=E6=97=B6=E4=BB=BB=E5=8A=A1=E7=9A=84?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E9=80=BB=E8=BE=91=20-=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=BA=86=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95=20-=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E4=BA=86=E6=A3=80=E6=9F=A5=E7=82=B9=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../backend/config/LLMBackendProperties.java | 3 + .../FineTuningTaskSyncService.java | 232 +++++++++++------- 2 files changed, 153 insertions(+), 82 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java index 9745977ca..245f845ea 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java @@ -119,5 +119,8 @@ public class LLMBackendProperties { private String embedQuery; + /** + * 获取调优检查点列表 + */ private String checkFileList; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java index 3a2ef34bd..16719484b 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java @@ -7,14 +7,12 @@ import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper; import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants; import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum; -import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes; import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO; -import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO; +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; -import com.esotericsoftware.minlog.Log; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; @@ -26,7 +24,6 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; import java.time.Duration; @@ -49,46 +46,78 @@ public class FineTuningTaskSyncService { @Resource private FineTuningTaskHttpService fineTuningTaskHttpService; - @Scheduled(cron = "0 */1 * * * ?") - public void updateFineTuningTaskStatus() { - Log.info("FineTuningTaskSync 定时任务启动"); - List fineTuningTaskDOList = fineTuningTaskMapper.selectList(); - for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) { + private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : "; - if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) - || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){ - ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType()); - if (fineTuningTaskDO.getJobId() == null){ + @Scheduled(cron = "0/20 * * * * ?") + public void updateFineTuningTaskStatus () { + // 方法入口:微调任务状态同步定时任务 + log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG); + + // 1. 查询所有待处理的微调任务 + log.debug("正在查询所有微调任务..."); + List fineTuningTaskDOList = fineTuningTaskMapper.selectList(); + log.info("{} [任务查询] 共获取到{}个微调任务。", FINE_TUNING_LOG, fineTuningTaskDOList.size()); + + for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) { + log.info("{} [任务处理] 开始处理任务DO:{}", FINE_TUNING_LOG, JSON.toJSONString(fineTuningTaskDO)); + + // 检查任务状态是否为训练中或等待中 + if (Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())) { + log.info("{} [状态更新] 检测到需更新状态的任务ID:{},当前状态:{}", + FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus()); + + // 3. 获取GPU服务器信息 + + String hostUrl = getHostUrl(fineTuningTaskDO); + + // 4. 检查任务是否有 Job ID + if (fineTuningTaskDO.getJobId() == null) { + log.warn("{} 微调任务未关联 Job ID,任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId()); continue; } - String hostUrl = serverNameDO!=null ?serverNameDO.getHost():""; - String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}"; - String respJobs = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs); + + // 5. 构建API查询参数 + String queryJobs = "?filter={\"job_id\":\"" + fineTuningTaskDO.getJobId() + "\"}"; + log.info("{} [API请求] 发起任务查询 job_id={} 查询参数: {}", FINE_TUNING_LOG, fineTuningTaskDO.getJobId(), queryJobs); + + // 6. 调用HTTP服务获取任务详情 + String respJobs = fineTuningTaskHttpService.modelTableQuery( + new HashMap<>(), hostUrl, "fine_tuning_train_job", queryJobs); + log.info("{} [API响应] 任务ID:{} 原始响应数据:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), respJobs); + + // 7. 解析API响应数据 AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO(); + try { + ObjectMapper mapper = new ObjectMapper(); - List aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference>() {}); - if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){ - if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){ - resp = aigcFineTuningDetailRespVO.get(0); - } + List respList = mapper.readValue( + respJobs, new TypeReference>() { + }); + + if (respList != null && !respList.isEmpty()) { + resp = respList.get(0); + log.info("{} [状态更新] 任务ID:{},当前状态:{},API响应状态:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus(), resp.getTrain_status()); + } - }catch (Exception e){ - log.error("获取微调任务状态失败{}",e.getMessage()); + } catch (Exception e) { + log.error("{} 解析微调任务状态时发生异常。任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), e); } - if (resp == null){ + + if (resp == null) { continue; } + FineTuningTaskDO updateObj = new FineTuningTaskDO(); - if (ObjectUtil.isAllEmpty(resp.getTrain_status())){ + if (ObjectUtil.isAllEmpty(resp.getTrain_status())) { continue; } Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status()); - if(status != null){ + if (status != null) { updateObj.setId(fineTuningTaskDO.getId()); } //如果微调任务由训练中变为已完成,则取模型列表中获取该任务生成的模型id - if((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)){ + if ((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)) { String jobModelName = fineTuningTaskDO.getJobModelName(); /* String resStr = trainHttpService.modelsList(""); Log.info("获取aicg模型列表返回数据内容{}",resStr); @@ -117,82 +146,121 @@ public class FineTuningTaskSyncService { }*/ updateObj.setStatus(2); // 获取模型id - String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}"; - String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels); - log.info("获取 aigc models 表数据 info {}",resModels); + String querModels = "?filter={\"model_name\":\"" + resp.getFine_tuned_model() + "\"}"; + String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl, "models", querModels); + log.info("获取 aigc models 表数据 info {}", resModels); JSONArray jsonArrayModels = JSONArray.parseArray(resModels); // 时长获取 - Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(),LocalDateTime.now()); + Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(), LocalDateTime.now()); long minutes = duration.toMinutes(); updateObj.setTrainDuration(String.valueOf(minutes)); - if(jsonArrayModels.size() > 0){ + if (jsonArrayModels.size() > 0) { JSONObject jsonObjectModels = jsonArrayModels.getJSONObject(0); updateObj.setJobModelListId(jsonObjectModels.getLong("id")); } - }catch (Exception e){ - log.error(" error {}",e.getMessage()); + } catch (Exception e) { + log.error(" error {}", e.getMessage()); } - try { - //获取检查点信息 - //todo 模型工厂的功能有问题,暂时写死 -// jobModelName = "Qwen2.5-0.5B-Instruct-147"; - String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName); - List checkpoints = new ArrayList<>(); - List fileUrls = new ArrayList<>(); - List fileList = JSONArray.parseArray(checkFileList,String.class); - Map map = new HashMap<>(); - for (String s : fileList) { - if(s.contains("checkpoint")){ - checkpoints.add(s); - } - } - if(checkpoints.size() > 0){ - for (String checkpoint : checkpoints) { - String filePath = "/" + checkpoint + "/trainer_state.json"; - fileUrls.add(filePath); - String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath); - try { - URL url = new URL(fileUrl); - URLConnection urlConnection = url.openConnection(); - InputStream inputStream = urlConnection.getInputStream(); - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - String line; - StringBuilder content = new StringBuilder(); - while ((line = reader.readLine()) != null) { - content.append(line + "\n"); - } - reader.close(); - map.put(checkpoint,JSONObject.parseObject(content.toString())); - } catch (MalformedURLException e) { - e.printStackTrace(); - } catch (IOException e) { - e.printStackTrace(); - } - } - } - - updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls)); - updateObj.setCheckPointData(JSONObject.toJSONString(map)); - - }catch (Exception e){ - log.error(" error {}",e.getMessage()); - } + getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj); fineTuningTaskMapper.updateById(updateObj); - } - else { + } else { updateObj.setStatus(status); } - fineTuningTaskMapper.updateById(updateObj); + fineTuningTaskMapper.updateById(updateObj); + } else { + // 处理已成功但是没有获取到检查点错误 + String jobModelName = fineTuningTaskDO.getJobModelName(); + String hostUrl = getHostUrl(fineTuningTaskDO); + FineTuningTaskDO updateObj = new FineTuningTaskDO(); + updateObj.setId(fineTuningTaskDO.getId()); + getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj); + fineTuningTaskMapper.updateById(updateObj); } } } + private String getHostUrl (FineTuningTaskDO fineTuningTaskDO) { + ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType()); + String hostUrl = (serverNameDO != null) ? serverNameDO.getHost() : ""; + log.info("{} [主机信息] 任务ID:{} GPU 类型: {} 关联主机URL:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getGpuType(), hostUrl); + return hostUrl; + } + + private void getCheckPoint (FineTuningTaskDO fineTuningTaskDO, String jobModelName, String hostUrl, FineTuningTaskDO updateObj) { + // 获取检查点信息 + try { + // String jobModelName = fineTuningTaskDO.getJobModelName(); + log.info("正在获取检查点信息,模型名称: {}", jobModelName); + String checkFileList = trainHttpService.getCheckFileList(hostUrl, jobModelName); + List checkpoints = new ArrayList<>(); + List fileUrls = new ArrayList<>(); + Map map = new HashMap<>(); + + // List fileList = JSONArray.parseArray(checkFileList, String.class); + // for (String s : fileList) { + // if (s.contains("checkpoint")) { + // checkpoints.add(s); + // } + // } + // 判断是否是数组 + if (checkFileList.startsWith("[")) { + List fileList = JSONArray.parseArray(checkFileList, String.class); + // 处理文件列表 + for (String s : fileList) { + if (s.contains("checkpoint")) { + checkpoints.add(s); + } + } + } else { + // TODO 处理错误信息 + JSONObject errorObj = JSON.parseObject(checkFileList); + String errorMsg = errorObj.getString("error"); + // 抛出异常或记录日志 + log.error("获取检查点文件列表时出错:{}", errorMsg); + } + + if (!checkpoints.isEmpty()) { + log.info("找到 {} 个检查点文件。", checkpoints.size()); + for (String checkpoint : checkpoints) { + String filePath = "/" + checkpoint + "/trainer_state.json"; + fileUrls.add(filePath); + String fileUrl = trainHttpService.getCheckFile(hostUrl, jobModelName, filePath); + + try { + URL url = new URL(fileUrl); + URLConnection urlConnection = url.openConnection(); + InputStream inputStream = urlConnection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + StringBuilder content = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + content.append(line).append("\n"); + } + reader.close(); + map.put(checkpoint, JSONObject.parseObject(content.toString())); + log.info("检查点文件解析成功。检查点: {}", checkpoint); + } catch (IOException e) { + log.error("读取检查点文件时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e); + } + } + } else { + log.warn("未找到检查点文件,任务ID: {}", fineTuningTaskDO.getId()); + } + + updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls)); + updateObj.setCheckPointData(JSONObject.toJSONString(map)); + log.info("检查点信息更新完成。"); + } catch (Exception e) { + log.error("获取检查点信息时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e); + } + } + } From 80590896cd7264b4370a62da142445d4dd8b8adb Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 13:23:35 +0800 Subject: [PATCH 02/12] =?UTF-8?q?refactor(llm):=20=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E5=BE=AE=E8=B0=83=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E5=AE=9A=E6=97=B6=E4=BB=BB=E5=8A=A1=E7=9A=84=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E9=A2=91=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将定时任务的执行频率从每 20 秒一次修改为每分钟一次 --- .../llm/service/finetuningtask/FineTuningTaskSyncService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java index 16719484b..b1a73527d 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java @@ -48,7 +48,7 @@ public class FineTuningTaskSyncService { private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : "; - @Scheduled(cron = "0/20 * * * * ?") + @Scheduled(cron = "0 */1 * * * ?") public void updateFineTuningTaskStatus () { // 方法入口:微调任务状态同步定时任务 log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG); From 438559fbc559403f84d53a405ba11fc3d70e46b5 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 13:44:38 +0800 Subject: [PATCH 03/12] =?UTF-8?q?fix(llm):=20=E4=BF=AE=E5=A4=8D=20md=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E8=BD=AC=E6=8D=A2=E5=90=8E=E7=9A=84=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 md 文件转换后的文件格式从 docx 修改为 txt -确保文件名后缀正确替换,避免产生错误的文件类型 --- .../iocoder/yudao/module/llm/service/http/RagHttpService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index 09ee4d04b..bc5a28512 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -428,7 +428,7 @@ public class RagHttpService { if ("md".equals(fileSuffix)) { log.info("正在处理 md 文件"); try { - tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".docx")); + tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt")); } catch (Exception e) { throw new RuntimeException(e); } From 866f838245a06fea8d9beec6ced091a0fa1e9b93 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 13:52:27 +0800 Subject: [PATCH 04/12] =?UTF-8?q?refactor(module-llm):=20=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E6=8E=89=20MD=20=E6=96=87=E4=BB=B6=E5=A4=84=E7=90=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 注释掉了处理 MD 文件的代码块 - 保留了创建 OkHttpClient 实例的逻辑 --- .../module/llm/service/http/RagHttpService.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index bc5a28512..e005c7fc1 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -425,14 +425,14 @@ public class RagHttpService { } } - if ("md".equals(fileSuffix)) { - log.info("正在处理 md 文件"); - try { - tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt")); - } catch (Exception e) { - throw new RuntimeException(e); - } - } +// if ("md".equals(fileSuffix)) { +// log.info("正在处理 md 文件"); +// try { +// tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt")); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// } // 创建 OkHttpClient 实例 log.info("创建 OkHttpClient 实例,设置超时时间为 3 分钟"); From 0edfd15f1097c7fd40bd026bb26344fcf3659673 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 17:27:26 +0800 Subject: [PATCH 05/12] =?UTF-8?q?refactor(module-llm):=20=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E6=8E=89=20doc=20=E6=96=87=E4=BB=B6=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 注释掉了处理 doc 文件的代码块 - 保留了处理 md 文件的注释代码块 --- .../llm/service/http/RagHttpService.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index e005c7fc1..394be88a5 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -415,15 +415,15 @@ public class RagHttpService { Path tempFilePath = downloadFileToTemp(fileUrl, fileName); log.info("文件已下载到临时目录: {}", tempFilePath); - String fileSuffix = getFileSuffix(fileName); - if ("doc".equals(fileSuffix)) { - log.info("正在处理 doc 文件"); - try { - tempFilePath= converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); - } catch (Exception e) { - throw new RuntimeException(e); - } - } +// String fileSuffix = getFileSuffix(fileName); +// if ("doc".equals(fileSuffix)) { +// log.info("正在处理 doc 文件"); +// try { +// tempFilePath= converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// } // if ("md".equals(fileSuffix)) { // log.info("正在处理 md 文件"); From b6e65d777a72fb695d66fcca4091fe3705bb40d7 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 17:27:42 +0800 Subject: [PATCH 06/12] =?UTF-8?q?refactor(module-llm):=20=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E6=8E=89=20doc=20=E6=96=87=E4=BB=B6=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 注释掉了处理 doc 文件的代码块 - 保留了处理 md 文件的注释代码块 --- .../llm/controller/admin/conversation/vo/ChatReqVO.java | 4 ++++ .../admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java index 14ce1d336..b567ea182 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java @@ -27,4 +27,8 @@ public class ChatReqVO { private String systemPrompt; @Schema(description = "知识库id") private Long knowledge; + @Schema(description = "单次回复限制max_tokens") + private String maxTokens; + @Schema(description = "随机性temperature") + private Long temperature; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java index 981f2d932..bd1df6e13 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java @@ -32,4 +32,8 @@ public class DataRefluxDataSaveReqVO { @Schema(description = "Response") private String response; + @Schema(description = "单次回复限制max_tokens") + private String maxTokens; + @Schema(description = "随机性temperature") + private Long temperature; } From 6e19d81a725a4c72fe5504d841d22b803a3a7429 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 18:18:34 +0800 Subject: [PATCH 07/12] =?UTF-8?q?feat(module-llm):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E6=B5=81=E5=BC=8F=E5=A4=84=E7=90=86=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 /stream-chat 接口,使用 SSE 进行流式响应- 实现 chatStream 方法,处理对话流式请求 - 添加 modelCompletionsStream 方法,支持模型补全流式处理 - 更新 ConversationService 接口,增加流式处理相关方法 - 在 pom.xml 中添加 spring-webflux 依赖 --- yudao-module-llm/yudao-module-llm-biz/pom.xml | 4 + .../conversation/ConversationController.java | 33 ++++ .../conversation/ConversationService.java | 11 +- .../conversation/ConversationServiceImpl.java | 182 +++++++++++++++++- .../module/llm/service/http/ModelService.java | 81 ++++++++ 5 files changed, 300 insertions(+), 11 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml index f8a8a4c5d..e01a56f9b 100644 --- a/yudao-module-llm/yudao-module-llm-biz/pom.xml +++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml @@ -125,6 +125,10 @@ poi-scratchpad 5.2.3 + + org.springframework + spring-webflux + diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java index 6096a5850..f019cead0 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java @@ -2,8 +2,12 @@ package cn.iocoder.yudao.module.llm.controller.admin.conversation; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; +import cn.iocoder.yudao.framework.common.exception.ServiceException; +import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO; import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo; import com.alibaba.fastjson.JSON; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; import org.springframework.validation.annotation.Validated; @@ -15,8 +19,10 @@ import io.swagger.v3.oas.annotations.Operation; import javax.validation.constraints.*; import javax.validation.*; import javax.servlet.http.*; +import java.time.LocalDateTime; import java.util.*; import java.io.IOException; +import java.util.concurrent.CompletableFuture; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; @@ -32,11 +38,13 @@ import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*; import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*; import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO; import cn.iocoder.yudao.module.llm.service.conversation.ConversationService; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @Tag(name = "管理后台 - 大模型对话记录") @RestController @RequestMapping("/llm/conversation") @Validated +@Slf4j public class ConversationController { @Resource @@ -101,6 +109,31 @@ public class ConversationController { public CommonResult chat(@Valid @RequestBody ChatReqVO chatReqVO) { return success(conversationService.chat(chatReqVO)); } + + /** + * 对话推理接口,使用 SSE 进行流式响应 + * @param chatReqVO 对话请求对象 + * @return SseEmitter 对象,用于流式发送响应 + */ + @PostMapping("/stream-chat") + public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) { + log.info("收到对话推理请求,请求参数: {}", chatReqVO); + SseEmitter emitter = new SseEmitter(); + try { + conversationService.chatStream(chatReqVO, emitter); + } catch (Exception e) { + log.error("处理对话推理请求时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + } + log.info("返回 SseEmitter 对象,准备进行流式响应"); + return emitter; + } + + @PostMapping("/text-to-image") @Operation(summary = "文字转图片接口") public CommonResult textToImage(@Valid @RequestBody TextToImageReqVo req) { diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java index 7c582bb70..9e1388ee0 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java @@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*; import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageParam; +import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO; import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; /** * 大模型对话记录 Service 接口 @@ -64,4 +66,11 @@ public interface ConversationService { * @return */ JSONArray textToImage(TextToImageReqVo req); -} \ No newline at end of file + + /** + * 聊天流 + * @param chatReqVO chatReqVO + * @param emitter emitter + */ + void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter); +} diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 4bbffedce..9ad223380 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -1,5 +1,7 @@ package cn.iocoder.yudao.module.llm.service.conversation; +import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONArray; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; @@ -32,11 +34,16 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; +import lombok.extern.slf4j.Slf4j; import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import java.io.IOException; import java.util.*; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -49,6 +56,7 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*; */ @Service @Validated +@Slf4j public class ConversationServiceImpl implements ConversationService { @Resource @@ -130,19 +138,11 @@ public class ConversationServiceImpl implements ConversationService { ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class); applicationService.updateApplication(applicationSaveReqVO); } - /*Long promptId = application.getPromptId(); - PromptTemplatesRespVO promptTemplates = promptTemplatesService.getPromptTemplates(promptId); - if(promptTemplates != null){ - chatReqVO.setSystemPrompt(promptTemplates.getTemplateText()); - }*/ + chatReqVO.setSystemPrompt(application.getPrompt()); } } - /* if (Objects.equals(1, chatReqVO.getModelType())){ - return publicModelChat(chatReqVO); - }else { - return privateModelChat(chatReqVO); - }*/ + return publicModelChat(chatReqVO); } @@ -273,9 +273,171 @@ public class ConversationServiceImpl implements ConversationService { dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt()); dataRefluxDataSaveReqVO.setResponse(modelCompletionsRespVO.getAnswer()); dataRefluxDataSaveReqVO.setSystem(modelCompletionsRespVO.getSystem()); + dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens()); + dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature()); dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO); return chatRespVO; } + + /** + * 处理对话请求,以流式方式返回结果 + * @param chatReqVO 对话请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + log.info("开始处理对话请求,请求参数: {}", chatReqVO); + // 检查系统提示信息,如果为空则尝试从应用中获取 + if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) { + if (chatReqVO.getApplicationId() != null) { + log.info("系统提示信息为空,尝试从应用中获取,应用 ID: {}", chatReqVO.getApplicationId()); + ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId()); + List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); + if (CollectionUtils.isEmpty(messageHistoryList)) { + log.info("聊天历史记录为空,更新应用聊天计数"); + application.setChatCount(application.getChatCount() + 1); + ApplicationSaveReqVO applicationSaveReqVO = BeanUtil.toBean(application, ApplicationSaveReqVO.class); + applicationService.updateApplication(applicationSaveReqVO); + } + chatReqVO.setSystemPrompt(application.getPrompt()); + log.info("已更新系统提示信息为: {}", chatReqVO.getSystemPrompt()); + } + } + // 调用公共模型聊天流式处理方法 + publicModelChatStream(chatReqVO, emitter); + } + + /** + * 公共模型聊天流式处理方法 + * @param chatReqVO 对话请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO); + // 检查 UUID 是否为空,若为空则生成一个 + if (StrUtil.isBlank(chatReqVO.getUuid())) { + log.info("UUID 为空,生成新的 UUID"); + chatReqVO.setUuid(UUID.randomUUID().toString()); + } + String model = null; + String selfModelUrl = ""; + // 根据模型类型获取模型信息 + if (Objects.equals(1, chatReqVO.getModelType())) { + log.info("使用预制模型,模型 ID: {}", chatReqVO.getModelId()); + BaseModelDO baseModelDO = baseModelService.getBaseModel(chatReqVO.getModelId()); + if (baseModelDO == null) { + log.error("预制模型不存在,模型 ID: {}", chatReqVO.getModelId()); + try { + emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + selfModelUrl = baseModelDO.getChatUrl(); + model = baseModelDO.getAigcModelName(); + log.info("获取到预制模型信息,模型名称: {}, 聊天 URL: {}", model, selfModelUrl); + } else if (Objects.equals(0, chatReqVO.getModelType())) { + log.info("使用自定义模型,模型 ID: {}", chatReqVO.getModelId()); + ModelServiceDO modelServiceDO = modelServiceMapper.selectById(chatReqVO.getModelId()); + if (modelServiceDO == null) { + log.error("自定义模型服务不存在,模型 ID: {}", chatReqVO.getModelId()); + try { + emitter.completeWithError(new RuntimeException("MODEL_SERVICE_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + model = modelServiceDO.getBaseModelName(); + selfModelUrl = modelServiceDO.getModelUrl(); + log.info("获取到自定义模型信息,模型名称: {}, 模型 URL: {}", model, selfModelUrl); + } else { + log.error("无效的模型类型,模型类型: {}", chatReqVO.getModelType()); + try { + emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS")); + } catch (Exception e) { + log.error("无法完成 SseEmitter 错误处理", e); + } + return; + } + + List messages = new ArrayList<>(); + + // 如果知识库 ID 不为空,先调用知识库获取相关信息 + StringBuilder knowledgeBase = new StringBuilder(); + if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) { + log.info("知识库 ID 不为空,开始查询知识库,知识库 ID: {}", chatReqVO.getKnowledge()); + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); + queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge()); + List fileList = knowledgeDocumentsMapper.selectList(queryWrapper); + for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) { + Long id = knowledgeDocumentsDO.getId(); + KnowledgeRagEmbedQueryVO knowledgeRagEmbedQueryVO = new KnowledgeRagEmbedQueryVO(); + knowledgeRagEmbedQueryVO.setFile_id(id.toString()); + knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt()); + String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO)); + com.alibaba.fastjson.JSONArray jsonArray = JSON.parseArray(result); + if (jsonArray != null && !jsonArray.isEmpty()) { + JSONArray jsonArray1 = (JSONArray) jsonArray.get(0); + JSONObject jsonObject = (JSONObject) jsonArray1.get(0); + knowledgeBase.append(jsonObject.get("page_content")); + } + } + log.info("知识库查询完成,获取到的信息: {}", knowledgeBase.toString()); + } + String mess = chatReqVO.getSystemPrompt() + "" + knowledgeBase.toString() + ""; + // 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中 + List messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1); + if (messageHistoryList != null && !messageHistoryList.isEmpty()) { + log.info("存在聊天历史记录,处理历史记录消息"); + for (String messageHistory : messageHistoryList) { + ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class); + if (modelCompletionsMessage.getRole().equals("system")) { + modelCompletionsMessage.setContent(mess); + stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage)); + } + messages.add(modelCompletionsMessage); + } + } else { + log.info("不存在聊天历史记录,创建新的系统消息"); + ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage(); + systemMessage.setRole("system"); + systemMessage.setContent(mess); + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage)); + messages.add(systemMessage); + } + + // 创建用户消息 + ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage(); + message.setRole("user"); + message.setContent(chatReqVO.getPrompt()); + messages.add(message); + + // 构建模型补全请求对象 + ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO(); + modelCompletionsReqVO.setMessages(messages); + modelCompletionsReqVO.setModel(model); + log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO); + + // 调用模型服务进行流式处理 + modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter); + + // 将用户消息存入缓存 + stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message)); + + // 保存数据回流信息 + DataRefluxDataSaveReqVO dataRefluxDataSaveReqVO = new DataRefluxDataSaveReqVO(); + dataRefluxDataSaveReqVO.setModelServiceId(chatReqVO.getModelId()); + dataRefluxDataSaveReqVO.setModelType(chatReqVO.getModelType()); + dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt()); + dataRefluxDataSaveReqVO.setSystem("助手"); + dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens()); + dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature()); + dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO); + log.info("数据回流信息保存完成"); + } + + /** * 私有模型聊天 * @param chatReqVO diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java index 017daec5a..1eff7fe69 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java @@ -1,5 +1,6 @@ package cn.iocoder.yudao.module.llm.service.http; +import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; @@ -12,11 +13,14 @@ import com.alibaba.fastjson.JSONObject; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -129,6 +133,83 @@ public class ModelService { throw new RuntimeException("模型补全请求处理失败", e); } } + /** + * 模型补全流式处理方法 + * @param url 模型服务的 URL + * @param req 模型补全请求对象 + * @param emitter SseEmitter 对象,用于流式发送响应 + */ + public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) { + log.info("开始处理模型补全请求,请求参数: {}", req); + // 检查模型是否为空,若为空则设置默认模型 + if (StrUtil.isBlank(req.getModel())) { + log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID); + req.setModel(DEFAULT_MODEL_ID); + } + // 记录请求信息 + log.info("请求参数: {}", JSON.toJSONString(req)); + try { + // 发起 HTTP POST 请求 + URL targetUrl; + if (StrUtil.isBlank(url)) { + log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions()); + targetUrl = new URL(llmBackendProperties.getModelCompletions()); + } else { + log.info("使用指定 URL: {}", url); + targetUrl = new URL(url); + } + HttpURLConnection connection = (HttpURLConnection) targetUrl.openConnection(); + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setDoOutput(true); + connection.getOutputStream().write(JSON.toJSONString(req).getBytes()); + + BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + String line; + StringBuilder responseBuilder = new StringBuilder(); + while ((line = reader.readLine()) != null) { + // 解析流式响应内容 + if (StrUtil.isNotBlank(line)) { + log.info("收到流式响应内容: {}", line); + ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class); + if (StrUtil.isBlank(chatCompletion.getDetail())) { + String respContent = chatCompletion.getChoices().get(0).getMessage().getContent(); + String patternString = "(.*?)"; + Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL); + Matcher matcher = pattern.matcher(respContent); + String answerContent = matcher.replaceAll(""); + // 流式发送数据 + try { + emitter.send(SseEmitter.event().data(answerContent)); + log.info("已发送流式响应数据: {}", answerContent); + } catch (Exception e) { + log.error("发送流式响应数据时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + return; + } + } + } + } + // 完成流式传输 + try { + emitter.complete(); + log.info("流式传输完成"); + } catch (Exception e) { + log.error("完成流式传输时发生异常", e); + } + } catch (Exception e) { + log.error("处理模型补全请求时发生异常", e); + try { + emitter.completeWithError(e); + } catch (Exception ex) { + log.error("无法完成 SseEmitter 错误处理", ex); + } + } + } /** * 通过 GPU 类型获取对应的主机地址 From 3fecb7e3789cc55a232397494622604681517aa9 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 20:58:40 +0800 Subject: [PATCH 08/12] =?UTF-8?q?refactor(module-llm):=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 恢复了对 .doc 文件的处理逻辑,将其转换为 .docx - 注释掉了对 .md 文件的处理逻辑- 优化了代码格式和缩进 --- .../llm/service/http/RagHttpService.java | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index 394be88a5..aa0998ef9 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -415,24 +415,24 @@ public class RagHttpService { Path tempFilePath = downloadFileToTemp(fileUrl, fileName); log.info("文件已下载到临时目录: {}", tempFilePath); -// String fileSuffix = getFileSuffix(fileName); -// if ("doc".equals(fileSuffix)) { -// log.info("正在处理 doc 文件"); -// try { -// tempFilePath= converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); -// } catch (Exception e) { -// throw new RuntimeException(e); -// } -// } + String fileSuffix = getFileSuffix(fileName); + if ("doc".equals(fileSuffix)) { + log.info("正在处理 doc 文件"); + try { + tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); + } catch (Exception e) { + throw new RuntimeException(e); + } + } -// if ("md".equals(fileSuffix)) { -// log.info("正在处理 md 文件"); -// try { -// tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt")); -// } catch (Exception e) { -// throw new RuntimeException(e); -// } -// } + // if ("md".equals(fileSuffix)) { + // log.info("正在处理 md 文件"); + // try { + // tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt")); + // } catch (Exception e) { + // throw new RuntimeException(e); + // } + // } // 创建 OkHttpClient 实例 log.info("创建 OkHttpClient 实例,设置超时时间为 3 分钟"); @@ -626,7 +626,7 @@ public class RagHttpService { return path; } - public static Path converterDocToDocx(String inputPath, String outputPath) throws Exception { + public static Path converterDocToDocx (String inputPath, String outputPath) throws Exception { // 读取DOC文档 try (HWPFDocument doc = new HWPFDocument(Files.newInputStream(Paths.get(inputPath)))) { XWPFDocument docx = new XWPFDocument(); @@ -650,6 +650,7 @@ public class RagHttpService { return Paths.get(outputPath); } } + /** * 处理响应结果 */ From a4e7cd67b7814d8076fb016ecbfce972a92f3a4b Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sun, 2 Mar 2025 01:21:01 +0800 Subject: [PATCH 09/12] =?UTF-8?q?refactor(module-llm):=20=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E6=8E=89=20doc=20=E6=96=87=E4=BB=B6=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 注释掉了处理 doc 文件转换为 docx 文件的代码块 - 保留了其他文件类型处理的注释代码 --- .../llm/service/http/RagHttpService.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index aa0998ef9..ac683ba3e 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -415,15 +415,15 @@ public class RagHttpService { Path tempFilePath = downloadFileToTemp(fileUrl, fileName); log.info("文件已下载到临时目录: {}", tempFilePath); - String fileSuffix = getFileSuffix(fileName); - if ("doc".equals(fileSuffix)) { - log.info("正在处理 doc 文件"); - try { - tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); - } catch (Exception e) { - throw new RuntimeException(e); - } - } +// String fileSuffix = getFileSuffix(fileName); +// if ("doc".equals(fileSuffix)) { +// log.info("正在处理 doc 文件"); +// try { +// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx")); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// } // if ("md".equals(fileSuffix)) { // log.info("正在处理 md 文件"); From 6ccf593f0aabe9c1edc3c4fd517dc6c2a17b30dd Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sun, 2 Mar 2025 10:31:11 +0800 Subject: [PATCH 10/12] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=A1=A5=E5=85=A8=E6=B5=81=E5=BC=8F=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化了 ChatReqVO 和 DataRefluxDataSaveReqVO 中的数据类型 - 重构了 ConversationController 和 ConversationService 中的 chatStream 方法 - 重新实现了 ModelService 中的 modelCompletionsStream 方法,采用更高效的处理方式 - 新增了辅助方法 parseStreamLine、extractJsonFromDataString、setupRequest 和 handleResponseEntity 以提高代码可读性和可维护性 --- .../conversation/ConversationController.java | 4 +- .../admin/conversation/vo/ChatReqVO.java | 4 +- .../vo/DataRefluxDataSaveReqVO.java | 4 +- .../conversation/ConversationService.java | 3 +- .../conversation/ConversationServiceImpl.java | 9 +- .../module/llm/service/http/ModelService.java | 207 +++++++++++------- .../http/vo/ModelCompletionsReqVO.java | 3 + 7 files changed, 145 insertions(+), 89 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java index f019cead0..52c6033cc 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/ConversationController.java @@ -116,11 +116,11 @@ public class ConversationController { * @return SseEmitter 对象,用于流式发送响应 */ @PostMapping("/stream-chat") - public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO) { + public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO,HttpServletResponse response) { log.info("收到对话推理请求,请求参数: {}", chatReqVO); SseEmitter emitter = new SseEmitter(); try { - conversationService.chatStream(chatReqVO, emitter); + conversationService.chatStream(chatReqVO, emitter,response); } catch (Exception e) { log.error("处理对话推理请求时发生异常", e); try { diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java index b567ea182..c7f211bf5 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/conversation/vo/ChatReqVO.java @@ -28,7 +28,7 @@ public class ChatReqVO { @Schema(description = "知识库id") private Long knowledge; @Schema(description = "单次回复限制max_tokens") - private String maxTokens; + private Integer maxTokens; @Schema(description = "随机性temperature") - private Long temperature; + private Double temperature; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java index bd1df6e13..2ec9c537f 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/controller/admin/datarefluxdata/vo/DataRefluxDataSaveReqVO.java @@ -33,7 +33,7 @@ public class DataRefluxDataSaveReqVO { private String response; @Schema(description = "单次回复限制max_tokens") - private String maxTokens; + private Integer maxTokens; @Schema(description = "随机性temperature") - private Long temperature; + private Double temperature; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java index 9e1388ee0..9a1397859 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationService.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.llm.service.conversation; import java.util.*; +import javax.servlet.http.HttpServletResponse; import javax.validation.*; import cn.hutool.json.JSONArray; @@ -72,5 +73,5 @@ public interface ConversationService { * @param chatReqVO chatReqVO * @param emitter emitter */ - void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter); + void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response); } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java index 9ad223380..5145c1185 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/conversation/ConversationServiceImpl.java @@ -43,6 +43,7 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.*; @@ -284,7 +285,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO 对话请求对象 * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) { log.info("开始处理对话请求,请求参数: {}", chatReqVO); // 检查系统提示信息,如果为空则尝试从应用中获取 if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) { @@ -303,7 +304,7 @@ public class ConversationServiceImpl implements ConversationService { } } // 调用公共模型聊天流式处理方法 - publicModelChatStream(chatReqVO, emitter); + publicModelChatStream(chatReqVO, emitter, response); } /** @@ -311,7 +312,7 @@ public class ConversationServiceImpl implements ConversationService { * @param chatReqVO 对话请求对象 * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter) { + public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter,HttpServletResponse response) { log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO); // 检查 UUID 是否为空,若为空则生成一个 if (StrUtil.isBlank(chatReqVO.getUuid())) { @@ -420,7 +421,7 @@ public class ConversationServiceImpl implements ConversationService { log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO); // 调用模型服务进行流式处理 - modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter); + modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, response); // 将用户消息存入缓存 stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message)); diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java index 1eff7fe69..11cc5cc27 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java @@ -1,6 +1,5 @@ package cn.iocoder.yudao.module.llm.service.http; -import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.util.http.HttpUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; @@ -12,15 +11,21 @@ import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.HttpClients; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; -import java.io.BufferedReader; -import java.io.InputStream; -import java.io.InputStreamReader; +import javax.servlet.http.HttpServletResponse; +import java.io.*; import java.net.HttpURLConnection; import java.net.URL; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -133,93 +138,73 @@ public class ModelService { throw new RuntimeException("模型补全请求处理失败", e); } } + /** * 模型补全流式处理方法 + * * @param url 模型服务的 URL * @param req 模型补全请求对象 - * @param emitter SseEmitter 对象,用于流式发送响应 */ - public void modelCompletionsStream(String url, ModelCompletionsReqVO req, SseEmitter emitter) { - log.info("开始处理模型补全请求,请求参数: {}", req); - // 检查模型是否为空,若为空则设置默认模型 - if (StrUtil.isBlank(req.getModel())) { - log.info("模型 ID 为空,设置为默认模型: {}", DEFAULT_MODEL_ID); - req.setModel(DEFAULT_MODEL_ID); - } - // 记录请求信息 - log.info("请求参数: {}", JSON.toJSONString(req)); + public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, HttpServletResponse response) { + req.setStream(true); + log.info("开始处理模型补全请求,参数: {}", req); try { - // 发起 HTTP POST 请求 - URL targetUrl; - if (StrUtil.isBlank(url)) { - log.info("URL 为空,使用默认 URL: {}", llmBackendProperties.getModelCompletions()); - targetUrl = new URL(llmBackendProperties.getModelCompletions()); - } else { - log.info("使用指定 URL: {}", url); - targetUrl = new URL(url); - } - HttpURLConnection connection = (HttpURLConnection) targetUrl.openConnection(); - connection.setRequestMethod("POST"); - connection.setRequestProperty("Content-Type", "application/json"); - connection.setDoOutput(true); - connection.getOutputStream().write(JSON.toJSONString(req).getBytes()); - - BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); - String line; - StringBuilder responseBuilder = new StringBuilder(); - while ((line = reader.readLine()) != null) { - // 解析流式响应内容 - if (StrUtil.isNotBlank(line)) { - log.info("收到流式响应内容: {}", line); - ChatCompletion chatCompletion = JSON.parseObject(line, ChatCompletion.class); - if (StrUtil.isBlank(chatCompletion.getDetail())) { - String respContent = chatCompletion.getChoices().get(0).getMessage().getContent(); - String patternString = "(.*?)"; - Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL); - Matcher matcher = pattern.matcher(respContent); - String answerContent = matcher.replaceAll(""); - // 流式发送数据 - try { - emitter.send(SseEmitter.event().data(answerContent)); - log.info("已发送流式响应数据: {}", answerContent); - } catch (Exception e) { - log.error("发送流式响应数据时发生异常", e); - try { - emitter.completeWithError(e); - } catch (Exception ex) { - log.error("无法完成 SseEmitter 错误处理", ex); - } - return; - } - } - } - } - // 完成流式传输 - try { - emitter.complete(); - log.info("流式传输完成"); - } catch (Exception e) { - log.error("完成流式传输时发生异常", e); - } + log.info("开始处理模型补全请求,参数: {}", JSON.toJSONString(req)); + sendPostRequest(url, "{\"max_tokens\":4000,\"messages\":[{\"content\":\"\",\"role\":\"system\"},{\"content\":\"介绍一下天津\",\"role\":\"user\"}],\"model\":\"Qwen2.5-0.5B-Instruct\",\"temperature\":0.7,\"stream\":true}", emitter); } catch (Exception e) { - log.error("处理模型补全请求时发生异常", e); - try { - emitter.completeWithError(e); - } catch (Exception ex) { - log.error("无法完成 SseEmitter 错误处理", ex); - } + emitter.completeWithError(e); } } /** - * 通过 GPU 类型获取对应的主机地址 - *

ModelService GpuType 是 ServerName 表ID + * 解析流式返回的单行数据,提取有效内容并清理特定标记 * - * @param gpuType gpuType - * @return host + * @param line 流式响应中的单行JSON数据 + * @return 处理后的文本内容(若无有效内容返回null) */ - public String getHostByType (Long gpuType) { - return serverNameService.getServerName(gpuType).getHost(); + private String parseStreamLine (String line) { + if (StringUtils.isNotBlank(line)) { + if (line.startsWith("data: ")) { + String dataString = extractJsonFromDataString(line); + if (!dataString.contains("[DONE]")) { + JSONObject jsonObject = JSON.parseObject(dataString); + // 获取 choices 数组 + JSONArray choicesArray = jsonObject.getJSONArray("choices"); + if (choicesArray != null && !choicesArray.isEmpty()) { + // 获取第一个 choice 对象 + JSONObject firstChoice = choicesArray.getJSONObject(0); + // 获取 delta 对象 + JSONObject delta = firstChoice.getJSONObject("delta"); + if (delta != null) { + // 获取 content 的值 + String content = delta.getString("content"); + return "{\"content\":\"" + content + "\",\"finish_reason\":\"" + false + "\"}"; + } + } + } else { + return "{\"content\":\"" + "\",\"finish_reason\":\"" + true + "\"}"; + } + } + } + + return null; + } + + /** + * 从包含 "data: " 前缀的字符串中提取 JSON 字符串 + * + * @param input 包含 "data: " 前缀的原始字符串 + * @return 提取出的 JSON 字符串,如果未找到 "data: " 则返回 null + */ + public static String extractJsonFromDataString (String input) { + if (input == null) { + return null; + } + int index = input.indexOf("data: "); + if (index != -1) { + return input.substring(index + "data: ".length()).trim(); + } + return null; } /** @@ -335,4 +320,70 @@ public class ModelService { } return null; } + + /** + * 发送 POST 请求并处理响应 + * + * @param apiUrl 目标 API 的 URL + * @param requestBody 请求体内容 + * @throws IOException 发送请求或处理响应时可能抛出的 IO 异常 + */ + private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException { + // 创建 HttpClient 实例 + HttpClient httpClient = HttpClients.createDefault(); + // 创建 HttpPost 请求对象 + HttpPost httpPost = new HttpPost(apiUrl); + + // 设置请求体和请求头 + setupRequest(httpPost, requestBody); + + // 执行 POST 请求并获取响应 + HttpResponse response = httpClient.execute(httpPost); + + // 处理响应实体 + handleResponseEntity(response, emitter); + } + + /** + * 设置请求体和请求头 + * + * @param httpPost HttpPost 请求对象 + * @param requestBody 请求体内容 + * @throws IOException 创建 StringEntity 时可能抛出的 IO 异常 + */ + private void setupRequest (HttpPost httpPost, String requestBody) throws IOException { + // 创建 StringEntity 对象,用于封装请求体 + StringEntity entity = new StringEntity(requestBody); + // 设置请求体 + httpPost.setEntity(entity); + // 设置请求头,指定请求体的内容类型为 JSON + httpPost.setHeader("Content-Type", "application/json"); + } + + /** + * 处理响应实体 + * + * @param response HttpResponse 对象 + * @throws IOException 读取响应实体输入流时可能抛出的 IO 异常 + */ + private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException { + // 获取响应实体 + HttpEntity responseEntity = response.getEntity(); + if (responseEntity != null) { + // 使用 try-with-resources 语句自动关闭输入流和 BufferedReader + try (InputStream inputStream = responseEntity.getContent(); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + // 逐行读取响应内容并处理 + while ((line = reader.readLine()) != null) { + System.out.println("接收到的响应行数据: " + line); + String content = parseStreamLine(line); + if (content != null) { + emitter.send(SseEmitter.event().data(content)); + } + } + } + } + } + } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java index a4d212b70..f66e5ee46 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/ModelCompletionsReqVO.java @@ -17,6 +17,9 @@ public class ModelCompletionsReqVO { private List messages; private Integer max_tokens = 4000; private Double temperature = 0.7; +// private Integer max_tokens; +// private Double temperature; + private Boolean stream; // private Integer max_length = 120000; @Data From 071516df0c6e251d91a38cee90b57f4cc5040f0c Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sun, 2 Mar 2025 10:31:18 +0800 Subject: [PATCH 11/12] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=A1=A5=E5=85=A8=E6=B5=81=E5=BC=8F=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化了 ChatReqVO 和 DataRefluxDataSaveReqVO 中的数据类型 - 重构了 ConversationController 和 ConversationService 中的 chatStream 方法 - 重新实现了 ModelService 中的 modelCompletionsStream 方法,采用更高效的处理方式 - 新增了辅助方法 parseStreamLine、extractJsonFromDataString、setupRequest 和 handleResponseEntity 以提高代码可读性和可维护性 --- yudao-module-llm/yudao-module-llm-biz/pom.xml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml index e01a56f9b..7b395d1f1 100644 --- a/yudao-module-llm/yudao-module-llm-biz/pom.xml +++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml @@ -40,7 +40,10 @@ cn.iocoder.boot yudao-spring-boot-starter-excel - + + org.springframework.boot + spring-boot-starter-websocket + cn.iocoder.boot From 6370cb223ec59028477438b311708c353d3ada92 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sun, 2 Mar 2025 10:35:07 +0800 Subject: [PATCH 12/12] =?UTF-8?q?refactor(yudao-module-llm):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E6=A8=A1=E5=9E=8B=E8=A1=A5=E5=85=A8=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 sendPostRequest 方法从私有改为受保护的 - 优化了请求体的处理,使用 JSON.toJSONString 方法序列化请求对象 - 重构了 SseEmitter 的使用方式,提高了代码的可读性和可维护性 - 删除了冗余的私有方法,简化了代码结构 --- .../module/llm/service/http/ModelService.java | 135 +++++++++--------- 1 file changed, 66 insertions(+), 69 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java index 11cc5cc27..f7d790d36 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/ModelService.java @@ -23,9 +23,6 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import javax.servlet.http.HttpServletResponse; import java.io.*; -import java.net.HttpURLConnection; -import java.net.URL; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -150,12 +147,77 @@ public class ModelService { log.info("开始处理模型补全请求,参数: {}", req); try { log.info("开始处理模型补全请求,参数: {}", JSON.toJSONString(req)); - sendPostRequest(url, "{\"max_tokens\":4000,\"messages\":[{\"content\":\"\",\"role\":\"system\"},{\"content\":\"介绍一下天津\",\"role\":\"user\"}],\"model\":\"Qwen2.5-0.5B-Instruct\",\"temperature\":0.7,\"stream\":true}", emitter); + sendPostRequest(url, JSON.toJSONString(req), emitter); } catch (Exception e) { emitter.completeWithError(e); } } + /** + * 发送 POST 请求并处理响应 + * + * @param apiUrl 目标 API 的 URL + * @param requestBody 请求体内容 + * @throws IOException 发送请求或处理响应时可能抛出的 IO 异常 + */ + private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException { + // 创建 HttpClient 实例 + HttpClient httpClient = HttpClients.createDefault(); + // 创建 HttpPost 请求对象 + HttpPost httpPost = new HttpPost(apiUrl); + + // 设置请求体和请求头 + setupRequest(httpPost, requestBody); + + // 执行 POST 请求并获取响应 + HttpResponse response = httpClient.execute(httpPost); + + // 处理响应实体 + handleResponseEntity(response, emitter); + } + + /** + * 设置请求体和请求头 + * + * @param httpPost HttpPost 请求对象 + * @param requestBody 请求体内容 + * @throws IOException 创建 StringEntity 时可能抛出的 IO 异常 + */ + private void setupRequest (HttpPost httpPost, String requestBody) throws IOException { + // 创建 StringEntity 对象,用于封装请求体 + StringEntity entity = new StringEntity(requestBody); + // 设置请求体 + httpPost.setEntity(entity); + // 设置请求头,指定请求体的内容类型为 JSON + httpPost.setHeader("Content-Type", "application/json"); + } + + /** + * 处理响应实体 + * + * @param response HttpResponse 对象 + * @throws IOException 读取响应实体输入流时可能抛出的 IO 异常 + */ + private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException { + // 获取响应实体 + HttpEntity responseEntity = response.getEntity(); + if (responseEntity != null) { + // 使用 try-with-resources 语句自动关闭输入流和 BufferedReader + try (InputStream inputStream = responseEntity.getContent(); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + // 逐行读取响应内容并处理 + while ((line = reader.readLine()) != null) { + System.out.println("接收到的响应行数据: " + line); + String content = parseStreamLine(line); + if (content != null) { + emitter.send(SseEmitter.event().data(content)); + } + } + } + } + } + /** * 解析流式返回的单行数据,提取有效内容并清理特定标记 * @@ -321,69 +383,4 @@ public class ModelService { return null; } - /** - * 发送 POST 请求并处理响应 - * - * @param apiUrl 目标 API 的 URL - * @param requestBody 请求体内容 - * @throws IOException 发送请求或处理响应时可能抛出的 IO 异常 - */ - private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException { - // 创建 HttpClient 实例 - HttpClient httpClient = HttpClients.createDefault(); - // 创建 HttpPost 请求对象 - HttpPost httpPost = new HttpPost(apiUrl); - - // 设置请求体和请求头 - setupRequest(httpPost, requestBody); - - // 执行 POST 请求并获取响应 - HttpResponse response = httpClient.execute(httpPost); - - // 处理响应实体 - handleResponseEntity(response, emitter); - } - - /** - * 设置请求体和请求头 - * - * @param httpPost HttpPost 请求对象 - * @param requestBody 请求体内容 - * @throws IOException 创建 StringEntity 时可能抛出的 IO 异常 - */ - private void setupRequest (HttpPost httpPost, String requestBody) throws IOException { - // 创建 StringEntity 对象,用于封装请求体 - StringEntity entity = new StringEntity(requestBody); - // 设置请求体 - httpPost.setEntity(entity); - // 设置请求头,指定请求体的内容类型为 JSON - httpPost.setHeader("Content-Type", "application/json"); - } - - /** - * 处理响应实体 - * - * @param response HttpResponse 对象 - * @throws IOException 读取响应实体输入流时可能抛出的 IO 异常 - */ - private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException { - // 获取响应实体 - HttpEntity responseEntity = response.getEntity(); - if (responseEntity != null) { - // 使用 try-with-resources 语句自动关闭输入流和 BufferedReader - try (InputStream inputStream = responseEntity.getContent(); - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { - String line; - // 逐行读取响应内容并处理 - while ((line = reader.readLine()) != null) { - System.out.println("接收到的响应行数据: " + line); - String content = parseStreamLine(line); - if (content != null) { - emitter.send(SseEmitter.event().data(content)); - } - } - } - } - } - }