修改模型调优

This commit is contained in:
limin 2025-02-18 14:52:25 +08:00
parent fded71ba43
commit 1ca41b7cf6
5 changed files with 90 additions and 36 deletions

View File

@ -44,7 +44,7 @@ public class BaseModelTaskService {
Long modelId = baseModelDO.getModelId();
String query = "?filter={\"id\":" + modelId + "}";
String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy", query);
String res = trainHttpService.modelTableQuery(new HashMap<>(), "","model_deploy", query);
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JavaTimeModule());

View File

@ -1,15 +1,20 @@
package cn.iocoder.yudao.module.llm.service.finetuningtask;
import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper;
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants;
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.esotericsoftware.minlog.Log;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@ -35,6 +40,8 @@ public class FineTuningTaskSyncService {
@Resource
FineTuningTaskMapper fineTuningTaskMapper;
@Resource
ServerNameMapper serverNameMapper;
@Scheduled(cron = "0 */3 * * * ?")
public void updateFineTuningTaskStatus() {
@ -44,12 +51,27 @@ public class FineTuningTaskSyncService {
if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus())
|| Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){
AigcFineTuningDetailRespVO resp = trainHttpService.finetuningDetail(new HashMap<>(), fineTuningTaskDO.getJobId());
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}";
String respJobs = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs);
AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO();
try {
ObjectMapper mapper = new ObjectMapper();
List<AigcFineTuningDetailRespVO> aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference<List<AigcFineTuningDetailRespVO>>() {});
if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){
if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){
resp = aigcFineTuningDetailRespVO.get(0);
}
}
}catch (Exception e){
log.error("获取微调任务状态失败{}",e.getMessage());
}
if (resp == null){
continue;
}
FineTuningTaskDO updateObj = new FineTuningTaskDO();
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrainStatus());
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status());
if(status != null){
updateObj.setId(fineTuningTaskDO.getId());
}
@ -83,8 +105,8 @@ public class FineTuningTaskSyncService {
}*/
updateObj.setStatus(2);
// 获取模型id
String querModels = "?filter={\"model_name\":\""+resp.getFineTunedModel()+"\"}";
String resModels = trainHttpService.modelTableQuery(new HashMap<>(), "models",querModels);
String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}";
String resModels = trainHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels);
log.info("获取 aigc models 表数据 info {}",resModels);
JSONArray jsonArrayModels = JSONArray.parseArray(resModels);
@ -106,7 +128,7 @@ public class FineTuningTaskSyncService {
//获取检查点信息
//todo 模型工厂的功能有问题暂时写死
jobModelName = "Qwen2.5-0.5B-Instruct-147";
String checkFileList = trainHttpService.getCheckFileList(jobModelName);
String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName);
List<String> checkpoints = new ArrayList<>();
List<String> fileUrls = new ArrayList<>();
List<String> fileList = JSONArray.parseArray(checkFileList,String.class);
@ -120,7 +142,7 @@ public class FineTuningTaskSyncService {
for (String checkpoint : checkpoints) {
String filePath = "/" + checkpoint + "/trainer_state.json";
fileUrls.add(filePath);
String fileUrl = trainHttpService.getCheckFile(jobModelName, filePath);
String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath);
try {
URL url = new URL(fileUrl);
URLConnection urlConnection = url.openConnection();

View File

@ -165,8 +165,8 @@ public class TrainHttpService {
/**
* 根据表名称查询数据
*/
public String modelTableQuery(Map<String, String> headers, String tableName,String query){
String url = String.format(llmBackendProperties.getTableDataQuery(),tableName);
public String modelTableQuery(Map<String, String> headers, String urlHost,String tableName,String query){
String url = String.format(urlHost+llmBackendProperties.getTableDataQuery(),tableName);
url = url+query;
String res = HttpUtils.get(url, headers);
log.info(" model query info :{}", res);
@ -337,14 +337,14 @@ public class TrainHttpService {
return HttpUtils.get(url+finetuningLog+params, null);
}
public String getCheckFileList(String name){
public String getCheckFileList(String url,String name){
String checkFileList = llmBackendProperties.getCheckFileList();
return HttpUtils.get(checkFileList + name, null);
return HttpUtils.get(url+checkFileList + name, null);
}
public String getCheckFile(String name,String path){
public String getCheckFile(String url,String name,String path){
String checkFileList = llmBackendProperties.getCheckFileList();
return checkFileList + name + "&&file_path=" + path;
return url+checkFileList + name + "&&file_path=" + path;
}
}

View File

@ -15,33 +15,58 @@ import java.util.List;
public class AigcFineTuningDetailRespVO {
private int id;
private String jobId;
private String baseModel;
private int trainEpoch;
private String trainStatus;
private String trainDuration;
private String job_id;
private String base_model;
private int train_epoch;
private String train_status;
private String train_duration;
private int process;
private String fineTunedModel;
private String fine_tuned_model;
private String remark;
private String finishedAt;
private ZonedDateTime createdAt;
private String trainPublisher;
private String trainLog;
private String errorMessage;
private boolean lora;
private TrainAnalysis trainAnalysis;
private String finished_at;
private String created_at;
private String train_publisher;
private String train_log;
private String error_message;
private String lora;
private TrainAnalysis train_analysis;
private String suffix;
private int modelMaxLength;
private int trainBatchSize;
private String learningRate;
private String fileUrl;
private String fileId;
private String startTrainTime;
private int procPerNode;
private int evalBatchSize;
private int accumulationSteps;
private int model_max_length;
private int train_batch_size;
private String learning_rate;
private String file_url;
private String file_id;
private String start_train_time;
private int proc_per_node;
private int eval_batch_size;
private int accumulation_steps;
private String scenario;
private Diagnosis diagnosis;
private String deleted_at;
private String paas_job_name;
private String weight_decay;
private String data_path;
private String progress;
private String warmup_ratio;
private String output_dir;
private String progress_epochs;
private String channel_id;
private String logging_steps;
private String script_file;
private String progress_loss;
private String template_id;
private String master_port;
private String progress_learning_rate;
private String eval_steps;
private String tenant_id;
private String save_steps;
private String base_model_path;
private String train_script;
private String updated_at;
private String save_total_limit;
private String validation_file;
// Nested classes for TrainAnalysis and Diagnosis

View File

@ -4,7 +4,9 @@ import cn.hutool.json.JSONObject;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.module.llm.controller.admin.dataset.dto.DataJsonTemplate;
import cn.iocoder.yudao.module.llm.dal.dataobject.modelservice.ModelServiceDO;
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
import cn.iocoder.yudao.module.llm.dal.mysql.modelservice.ModelServiceMapper;
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
import cn.iocoder.yudao.module.llm.enums.ModelDeployConstantEnum;
import cn.iocoder.yudao.module.llm.handler.AigcCustomDateTimeDeserializer;
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
@ -44,6 +46,9 @@ public class ModelServiceTaskSyncService {
@Resource
private ModelServiceMapper modelServiceMapper;
@Resource
ServerNameMapper serverNameMapper;
@Scheduled(cron = "0 */2 * * * ?")
@ -63,7 +68,9 @@ public class ModelServiceTaskSyncService {
// 使用 TypeReference 解析 JSON 字符串为 List<String>
try {
String query = "?filter={\"id\":"+jobid+"}";
String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy",query);
ServerNameDO serverNameDO = serverNameMapper.selectById(modelServiceDO.getGpuType());
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
String res = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"model_deploy",query);
log.info("获取 aigc model_deploy 表数据 info {}",res);
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JavaTimeModule());