From 0823596f97ab37014a0305f2f70a878911cdbbae Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Sat, 1 Mar 2025 13:17:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor(module-llm):=20-=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=BE=AE=E8=B0=83=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E6=9C=8D=E5=8A=A1=20-=20=E4=BC=98=E5=8C=96=E4=BA=86?= =?UTF-8?q?=E5=BE=AE=E8=B0=83=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E5=AE=9A=E6=97=B6=E4=BB=BB=E5=8A=A1=E7=9A=84=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E9=80=BB=E8=BE=91=20-=20=E5=A2=9E=E5=8A=A0=E4=BA=86?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95=20-=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E4=BA=86=E6=A3=80=E6=9F=A5=E7=82=B9=E4=BF=A1=E6=81=AF=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../backend/config/LLMBackendProperties.java | 3 + .../FineTuningTaskSyncService.java | 232 +++++++++++------- 2 files changed, 153 insertions(+), 82 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java index 9745977ca..245f845ea 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/framework/backend/config/LLMBackendProperties.java @@ -119,5 +119,8 @@ public class LLMBackendProperties { private String embedQuery; + /** + * 获取调优检查点列表 + */ private String checkFileList; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java index 3a2ef34bd..16719484b 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java @@ -7,14 +7,12 @@ import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper; import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants; import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum; -import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes; import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO; -import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO; +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; -import com.esotericsoftware.minlog.Log; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; @@ -26,7 +24,6 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; import java.time.Duration; @@ -49,46 +46,78 @@ public class FineTuningTaskSyncService { @Resource private FineTuningTaskHttpService fineTuningTaskHttpService; - @Scheduled(cron = "0 */1 * * * ?") - public void updateFineTuningTaskStatus() { - Log.info("FineTuningTaskSync 定时任务启动"); - List fineTuningTaskDOList = fineTuningTaskMapper.selectList(); - for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) { + private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : "; - if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) - || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){ - ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType()); - if (fineTuningTaskDO.getJobId() == null){ + @Scheduled(cron = "0/20 * * * * ?") + public void updateFineTuningTaskStatus () { + // 方法入口:微调任务状态同步定时任务 + log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG); + + // 1. 查询所有待处理的微调任务 + log.debug("正在查询所有微调任务..."); + List fineTuningTaskDOList = fineTuningTaskMapper.selectList(); + log.info("{} [任务查询] 共获取到{}个微调任务。", FINE_TUNING_LOG, fineTuningTaskDOList.size()); + + for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) { + log.info("{} [任务处理] 开始处理任务DO:{}", FINE_TUNING_LOG, JSON.toJSONString(fineTuningTaskDO)); + + // 检查任务状态是否为训练中或等待中 + if (Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())) { + log.info("{} [状态更新] 检测到需更新状态的任务ID:{},当前状态:{}", + FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus()); + + // 3. 获取GPU服务器信息 + + String hostUrl = getHostUrl(fineTuningTaskDO); + + // 4. 检查任务是否有 Job ID + if (fineTuningTaskDO.getJobId() == null) { + log.warn("{} 微调任务未关联 Job ID,任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId()); continue; } - String hostUrl = serverNameDO!=null ?serverNameDO.getHost():""; - String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}"; - String respJobs = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs); + + // 5. 构建API查询参数 + String queryJobs = "?filter={\"job_id\":\"" + fineTuningTaskDO.getJobId() + "\"}"; + log.info("{} [API请求] 发起任务查询 job_id={} 查询参数: {}", FINE_TUNING_LOG, fineTuningTaskDO.getJobId(), queryJobs); + + // 6. 调用HTTP服务获取任务详情 + String respJobs = fineTuningTaskHttpService.modelTableQuery( + new HashMap<>(), hostUrl, "fine_tuning_train_job", queryJobs); + log.info("{} [API响应] 任务ID:{} 原始响应数据:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), respJobs); + + // 7. 解析API响应数据 AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO(); + try { + ObjectMapper mapper = new ObjectMapper(); - List aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference>() {}); - if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){ - if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){ - resp = aigcFineTuningDetailRespVO.get(0); - } + List respList = mapper.readValue( + respJobs, new TypeReference>() { + }); + + if (respList != null && !respList.isEmpty()) { + resp = respList.get(0); + log.info("{} [状态更新] 任务ID:{},当前状态:{},API响应状态:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus(), resp.getTrain_status()); + } - }catch (Exception e){ - log.error("获取微调任务状态失败{}",e.getMessage()); + } catch (Exception e) { + log.error("{} 解析微调任务状态时发生异常。任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), e); } - if (resp == null){ + + if (resp == null) { continue; } + FineTuningTaskDO updateObj = new FineTuningTaskDO(); - if (ObjectUtil.isAllEmpty(resp.getTrain_status())){ + if (ObjectUtil.isAllEmpty(resp.getTrain_status())) { continue; } Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status()); - if(status != null){ + if (status != null) { updateObj.setId(fineTuningTaskDO.getId()); } //如果微调任务由训练中变为已完成,则取模型列表中获取该任务生成的模型id - if((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)){ + if ((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)) { String jobModelName = fineTuningTaskDO.getJobModelName(); /* String resStr = trainHttpService.modelsList(""); Log.info("获取aicg模型列表返回数据内容{}",resStr); @@ -117,82 +146,121 @@ public class FineTuningTaskSyncService { }*/ updateObj.setStatus(2); // 获取模型id - String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}"; - String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels); - log.info("获取 aigc models 表数据 info {}",resModels); + String querModels = "?filter={\"model_name\":\"" + resp.getFine_tuned_model() + "\"}"; + String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl, "models", querModels); + log.info("获取 aigc models 表数据 info {}", resModels); JSONArray jsonArrayModels = JSONArray.parseArray(resModels); // 时长获取 - Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(),LocalDateTime.now()); + Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(), LocalDateTime.now()); long minutes = duration.toMinutes(); updateObj.setTrainDuration(String.valueOf(minutes)); - if(jsonArrayModels.size() > 0){ + if (jsonArrayModels.size() > 0) { JSONObject jsonObjectModels = jsonArrayModels.getJSONObject(0); updateObj.setJobModelListId(jsonObjectModels.getLong("id")); } - }catch (Exception e){ - log.error(" error {}",e.getMessage()); + } catch (Exception e) { + log.error(" error {}", e.getMessage()); } - try { - //获取检查点信息 - //todo 模型工厂的功能有问题,暂时写死 -// jobModelName = "Qwen2.5-0.5B-Instruct-147"; - String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName); - List checkpoints = new ArrayList<>(); - List fileUrls = new ArrayList<>(); - List fileList = JSONArray.parseArray(checkFileList,String.class); - Map map = new HashMap<>(); - for (String s : fileList) { - if(s.contains("checkpoint")){ - checkpoints.add(s); - } - } - if(checkpoints.size() > 0){ - for (String checkpoint : checkpoints) { - String filePath = "/" + checkpoint + "/trainer_state.json"; - fileUrls.add(filePath); - String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath); - try { - URL url = new URL(fileUrl); - URLConnection urlConnection = url.openConnection(); - InputStream inputStream = urlConnection.getInputStream(); - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - String line; - StringBuilder content = new StringBuilder(); - while ((line = reader.readLine()) != null) { - content.append(line + "\n"); - } - reader.close(); - map.put(checkpoint,JSONObject.parseObject(content.toString())); - } catch (MalformedURLException e) { - e.printStackTrace(); - } catch (IOException e) { - e.printStackTrace(); - } - } - } - - updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls)); - updateObj.setCheckPointData(JSONObject.toJSONString(map)); - - }catch (Exception e){ - log.error(" error {}",e.getMessage()); - } + getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj); fineTuningTaskMapper.updateById(updateObj); - } - else { + } else { updateObj.setStatus(status); } - fineTuningTaskMapper.updateById(updateObj); + fineTuningTaskMapper.updateById(updateObj); + } else { + // 处理已成功但是没有获取到检查点错误 + String jobModelName = fineTuningTaskDO.getJobModelName(); + String hostUrl = getHostUrl(fineTuningTaskDO); + FineTuningTaskDO updateObj = new FineTuningTaskDO(); + updateObj.setId(fineTuningTaskDO.getId()); + getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj); + fineTuningTaskMapper.updateById(updateObj); } } } + private String getHostUrl (FineTuningTaskDO fineTuningTaskDO) { + ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType()); + String hostUrl = (serverNameDO != null) ? serverNameDO.getHost() : ""; + log.info("{} [主机信息] 任务ID:{} GPU 类型: {} 关联主机URL:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getGpuType(), hostUrl); + return hostUrl; + } + + private void getCheckPoint (FineTuningTaskDO fineTuningTaskDO, String jobModelName, String hostUrl, FineTuningTaskDO updateObj) { + // 获取检查点信息 + try { + // String jobModelName = fineTuningTaskDO.getJobModelName(); + log.info("正在获取检查点信息,模型名称: {}", jobModelName); + String checkFileList = trainHttpService.getCheckFileList(hostUrl, jobModelName); + List checkpoints = new ArrayList<>(); + List fileUrls = new ArrayList<>(); + Map map = new HashMap<>(); + + // List fileList = JSONArray.parseArray(checkFileList, String.class); + // for (String s : fileList) { + // if (s.contains("checkpoint")) { + // checkpoints.add(s); + // } + // } + // 判断是否是数组 + if (checkFileList.startsWith("[")) { + List fileList = JSONArray.parseArray(checkFileList, String.class); + // 处理文件列表 + for (String s : fileList) { + if (s.contains("checkpoint")) { + checkpoints.add(s); + } + } + } else { + // TODO 处理错误信息 + JSONObject errorObj = JSON.parseObject(checkFileList); + String errorMsg = errorObj.getString("error"); + // 抛出异常或记录日志 + log.error("获取检查点文件列表时出错:{}", errorMsg); + } + + if (!checkpoints.isEmpty()) { + log.info("找到 {} 个检查点文件。", checkpoints.size()); + for (String checkpoint : checkpoints) { + String filePath = "/" + checkpoint + "/trainer_state.json"; + fileUrls.add(filePath); + String fileUrl = trainHttpService.getCheckFile(hostUrl, jobModelName, filePath); + + try { + URL url = new URL(fileUrl); + URLConnection urlConnection = url.openConnection(); + InputStream inputStream = urlConnection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + StringBuilder content = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + content.append(line).append("\n"); + } + reader.close(); + map.put(checkpoint, JSONObject.parseObject(content.toString())); + log.info("检查点文件解析成功。检查点: {}", checkpoint); + } catch (IOException e) { + log.error("读取检查点文件时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e); + } + } + } else { + log.warn("未找到检查点文件,任务ID: {}", fineTuningTaskDO.getId()); + } + + updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls)); + updateObj.setCheckPointData(JSONObject.toJSONString(map)); + log.info("检查点信息更新完成。"); + } catch (Exception e) { + log.error("获取检查点信息时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e); + } + } + }