diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java index 9b7a39862..3883862e7 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java @@ -44,7 +44,7 @@ public class BaseModelTaskService { Long modelId = baseModelDO.getModelId(); String query = "?filter={\"id\":" + modelId + "}"; - String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy", query); + String res = trainHttpService.modelTableQuery(new HashMap<>(), "","model_deploy", query); ObjectMapper mapper = new ObjectMapper(); mapper.registerModule(new JavaTimeModule()); diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java index 57fbac1b9..4c7d860af 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskSyncService.java @@ -1,15 +1,20 @@ package cn.iocoder.yudao.module.llm.service.finetuningtask; import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO; +import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper; +import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper; import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants; import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum; import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO; +import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.esotericsoftware.minlog.Log; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; @@ -35,6 +40,8 @@ public class FineTuningTaskSyncService { @Resource FineTuningTaskMapper fineTuningTaskMapper; + @Resource + ServerNameMapper serverNameMapper; @Scheduled(cron = "0 */3 * * * ?") public void updateFineTuningTaskStatus() { @@ -44,12 +51,27 @@ public class FineTuningTaskSyncService { if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){ - AigcFineTuningDetailRespVO resp = trainHttpService.finetuningDetail(new HashMap<>(), fineTuningTaskDO.getJobId()); + ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType()); + String hostUrl = serverNameDO!=null ?serverNameDO.getHost():""; + String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}"; + String respJobs = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs); + AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO(); + try { + ObjectMapper mapper = new ObjectMapper(); + List aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference>() {}); + if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){ + if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){ + resp = aigcFineTuningDetailRespVO.get(0); + } + } + }catch (Exception e){ + log.error("获取微调任务状态失败{}",e.getMessage()); + } if (resp == null){ continue; } FineTuningTaskDO updateObj = new FineTuningTaskDO(); - Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrainStatus()); + Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status()); if(status != null){ updateObj.setId(fineTuningTaskDO.getId()); } @@ -83,8 +105,8 @@ public class FineTuningTaskSyncService { }*/ updateObj.setStatus(2); // 获取模型id - String querModels = "?filter={\"model_name\":\""+resp.getFineTunedModel()+"\"}"; - String resModels = trainHttpService.modelTableQuery(new HashMap<>(), "models",querModels); + String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}"; + String resModels = trainHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels); log.info("获取 aigc models 表数据 info {}",resModels); JSONArray jsonArrayModels = JSONArray.parseArray(resModels); @@ -106,7 +128,7 @@ public class FineTuningTaskSyncService { //获取检查点信息 //todo 模型工厂的功能有问题,暂时写死 jobModelName = "Qwen2.5-0.5B-Instruct-147"; - String checkFileList = trainHttpService.getCheckFileList(jobModelName); + String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName); List checkpoints = new ArrayList<>(); List fileUrls = new ArrayList<>(); List fileList = JSONArray.parseArray(checkFileList,String.class); @@ -120,7 +142,7 @@ public class FineTuningTaskSyncService { for (String checkpoint : checkpoints) { String filePath = "/" + checkpoint + "/trainer_state.json"; fileUrls.add(filePath); - String fileUrl = trainHttpService.getCheckFile(jobModelName, filePath); + String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath); try { URL url = new URL(fileUrl); URLConnection urlConnection = url.openConnection(); diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/TrainHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/TrainHttpService.java index 1e86a4386..1d6f98066 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/TrainHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/TrainHttpService.java @@ -165,8 +165,8 @@ public class TrainHttpService { /** * 根据表名称查询数据 */ - public String modelTableQuery(Map headers, String tableName,String query){ - String url = String.format(llmBackendProperties.getTableDataQuery(),tableName); + public String modelTableQuery(Map headers, String urlHost,String tableName,String query){ + String url = String.format(urlHost+llmBackendProperties.getTableDataQuery(),tableName); url = url+query; String res = HttpUtils.get(url, headers); log.info(" model query info :{}", res); @@ -337,14 +337,14 @@ public class TrainHttpService { return HttpUtils.get(url+finetuningLog+params, null); } - public String getCheckFileList(String name){ + public String getCheckFileList(String url,String name){ String checkFileList = llmBackendProperties.getCheckFileList(); - return HttpUtils.get(checkFileList + name, null); + return HttpUtils.get(url+checkFileList + name, null); } - public String getCheckFile(String name,String path){ + public String getCheckFile(String url,String name,String path){ String checkFileList = llmBackendProperties.getCheckFileList(); - return checkFileList + name + "&&file_path=" + path; + return url+checkFileList + name + "&&file_path=" + path; } } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/AigcFineTuningDetailRespVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/AigcFineTuningDetailRespVO.java index cf38a5771..c562596cc 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/AigcFineTuningDetailRespVO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/vo/AigcFineTuningDetailRespVO.java @@ -15,33 +15,58 @@ import java.util.List; public class AigcFineTuningDetailRespVO { private int id; - private String jobId; - private String baseModel; - private int trainEpoch; - private String trainStatus; - private String trainDuration; + private String job_id; + private String base_model; + private int train_epoch; + private String train_status; + private String train_duration; private int process; - private String fineTunedModel; + private String fine_tuned_model; private String remark; - private String finishedAt; - private ZonedDateTime createdAt; - private String trainPublisher; - private String trainLog; - private String errorMessage; - private boolean lora; - private TrainAnalysis trainAnalysis; + private String finished_at; + private String created_at; + private String train_publisher; + private String train_log; + private String error_message; + private String lora; + private TrainAnalysis train_analysis; private String suffix; - private int modelMaxLength; - private int trainBatchSize; - private String learningRate; - private String fileUrl; - private String fileId; - private String startTrainTime; - private int procPerNode; - private int evalBatchSize; - private int accumulationSteps; + private int model_max_length; + private int train_batch_size; + private String learning_rate; + private String file_url; + private String file_id; + private String start_train_time; + private int proc_per_node; + private int eval_batch_size; + private int accumulation_steps; private String scenario; private Diagnosis diagnosis; + private String deleted_at; + private String paas_job_name; + private String weight_decay; + private String data_path; + private String progress; + private String warmup_ratio; + private String output_dir; + private String progress_epochs; + private String channel_id; + private String logging_steps; + private String script_file; + private String progress_loss; + private String template_id; + private String master_port; + private String progress_learning_rate; + private String eval_steps; + private String tenant_id; + private String save_steps; + private String base_model_path; + private String train_script; + private String updated_at; + private String save_total_limit; + private String validation_file; + + // Nested classes for TrainAnalysis and Diagnosis diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/modelservice/ModelServiceTaskSyncService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/modelservice/ModelServiceTaskSyncService.java index c36db3bf2..c9d53bad6 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/modelservice/ModelServiceTaskSyncService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/modelservice/ModelServiceTaskSyncService.java @@ -4,7 +4,9 @@ import cn.hutool.json.JSONObject; import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils; import cn.iocoder.yudao.module.llm.controller.admin.dataset.dto.DataJsonTemplate; import cn.iocoder.yudao.module.llm.dal.dataobject.modelservice.ModelServiceDO; +import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO; import cn.iocoder.yudao.module.llm.dal.mysql.modelservice.ModelServiceMapper; +import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper; import cn.iocoder.yudao.module.llm.enums.ModelDeployConstantEnum; import cn.iocoder.yudao.module.llm.handler.AigcCustomDateTimeDeserializer; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; @@ -44,6 +46,9 @@ public class ModelServiceTaskSyncService { @Resource private ModelServiceMapper modelServiceMapper; + @Resource + ServerNameMapper serverNameMapper; + @Scheduled(cron = "0 */2 * * * ?") @@ -63,7 +68,9 @@ public class ModelServiceTaskSyncService { // 使用 TypeReference 解析 JSON 字符串为 List try { String query = "?filter={\"id\":"+jobid+"}"; - String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy",query); + ServerNameDO serverNameDO = serverNameMapper.selectById(modelServiceDO.getGpuType()); + String hostUrl = serverNameDO!=null ?serverNameDO.getHost():""; + String res = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"model_deploy",query); log.info("获取 aigc model_deploy 表数据 info {}",res); ObjectMapper mapper = new ObjectMapper(); mapper.registerModule(new JavaTimeModule());