修改模型调优
This commit is contained in:
parent
fded71ba43
commit
1ca41b7cf6
@ -44,7 +44,7 @@ public class BaseModelTaskService {
|
||||
Long modelId = baseModelDO.getModelId();
|
||||
|
||||
String query = "?filter={\"id\":" + modelId + "}";
|
||||
String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy", query);
|
||||
String res = trainHttpService.modelTableQuery(new HashMap<>(), "","model_deploy", query);
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.registerModule(new JavaTimeModule());
|
||||
|
@ -1,15 +1,20 @@
|
||||
package cn.iocoder.yudao.module.llm.service.finetuningtask;
|
||||
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants;
|
||||
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
|
||||
import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
|
||||
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.esotericsoftware.minlog.Log;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
@ -35,6 +40,8 @@ public class FineTuningTaskSyncService {
|
||||
|
||||
@Resource
|
||||
FineTuningTaskMapper fineTuningTaskMapper;
|
||||
@Resource
|
||||
ServerNameMapper serverNameMapper;
|
||||
|
||||
@Scheduled(cron = "0 */3 * * * ?")
|
||||
public void updateFineTuningTaskStatus() {
|
||||
@ -44,12 +51,27 @@ public class FineTuningTaskSyncService {
|
||||
|
||||
if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus())
|
||||
|| Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){
|
||||
AigcFineTuningDetailRespVO resp = trainHttpService.finetuningDetail(new HashMap<>(), fineTuningTaskDO.getJobId());
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
|
||||
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
|
||||
String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}";
|
||||
String respJobs = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs);
|
||||
AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO();
|
||||
try {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
List<AigcFineTuningDetailRespVO> aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference<List<AigcFineTuningDetailRespVO>>() {});
|
||||
if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){
|
||||
if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){
|
||||
resp = aigcFineTuningDetailRespVO.get(0);
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
log.error("获取微调任务状态失败{}",e.getMessage());
|
||||
}
|
||||
if (resp == null){
|
||||
continue;
|
||||
}
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrainStatus());
|
||||
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status());
|
||||
if(status != null){
|
||||
updateObj.setId(fineTuningTaskDO.getId());
|
||||
}
|
||||
@ -83,8 +105,8 @@ public class FineTuningTaskSyncService {
|
||||
}*/
|
||||
updateObj.setStatus(2);
|
||||
// 获取模型id
|
||||
String querModels = "?filter={\"model_name\":\""+resp.getFineTunedModel()+"\"}";
|
||||
String resModels = trainHttpService.modelTableQuery(new HashMap<>(), "models",querModels);
|
||||
String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}";
|
||||
String resModels = trainHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels);
|
||||
log.info("获取 aigc models 表数据 info {}",resModels);
|
||||
JSONArray jsonArrayModels = JSONArray.parseArray(resModels);
|
||||
|
||||
@ -106,7 +128,7 @@ public class FineTuningTaskSyncService {
|
||||
//获取检查点信息
|
||||
//todo 模型工厂的功能有问题,暂时写死
|
||||
jobModelName = "Qwen2.5-0.5B-Instruct-147";
|
||||
String checkFileList = trainHttpService.getCheckFileList(jobModelName);
|
||||
String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName);
|
||||
List<String> checkpoints = new ArrayList<>();
|
||||
List<String> fileUrls = new ArrayList<>();
|
||||
List<String> fileList = JSONArray.parseArray(checkFileList,String.class);
|
||||
@ -120,7 +142,7 @@ public class FineTuningTaskSyncService {
|
||||
for (String checkpoint : checkpoints) {
|
||||
String filePath = "/" + checkpoint + "/trainer_state.json";
|
||||
fileUrls.add(filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(jobModelName, filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath);
|
||||
try {
|
||||
URL url = new URL(fileUrl);
|
||||
URLConnection urlConnection = url.openConnection();
|
||||
|
@ -165,8 +165,8 @@ public class TrainHttpService {
|
||||
/**
|
||||
* 根据表名称查询数据
|
||||
*/
|
||||
public String modelTableQuery(Map<String, String> headers, String tableName,String query){
|
||||
String url = String.format(llmBackendProperties.getTableDataQuery(),tableName);
|
||||
public String modelTableQuery(Map<String, String> headers, String urlHost,String tableName,String query){
|
||||
String url = String.format(urlHost+llmBackendProperties.getTableDataQuery(),tableName);
|
||||
url = url+query;
|
||||
String res = HttpUtils.get(url, headers);
|
||||
log.info(" model query info :{}", res);
|
||||
@ -337,14 +337,14 @@ public class TrainHttpService {
|
||||
return HttpUtils.get(url+finetuningLog+params, null);
|
||||
}
|
||||
|
||||
public String getCheckFileList(String name){
|
||||
public String getCheckFileList(String url,String name){
|
||||
String checkFileList = llmBackendProperties.getCheckFileList();
|
||||
return HttpUtils.get(checkFileList + name, null);
|
||||
return HttpUtils.get(url+checkFileList + name, null);
|
||||
}
|
||||
|
||||
public String getCheckFile(String name,String path){
|
||||
public String getCheckFile(String url,String name,String path){
|
||||
String checkFileList = llmBackendProperties.getCheckFileList();
|
||||
return checkFileList + name + "&&file_path=" + path;
|
||||
return url+checkFileList + name + "&&file_path=" + path;
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -15,33 +15,58 @@ import java.util.List;
|
||||
public class AigcFineTuningDetailRespVO {
|
||||
|
||||
private int id;
|
||||
private String jobId;
|
||||
private String baseModel;
|
||||
private int trainEpoch;
|
||||
private String trainStatus;
|
||||
private String trainDuration;
|
||||
private String job_id;
|
||||
private String base_model;
|
||||
private int train_epoch;
|
||||
private String train_status;
|
||||
private String train_duration;
|
||||
private int process;
|
||||
private String fineTunedModel;
|
||||
private String fine_tuned_model;
|
||||
private String remark;
|
||||
private String finishedAt;
|
||||
private ZonedDateTime createdAt;
|
||||
private String trainPublisher;
|
||||
private String trainLog;
|
||||
private String errorMessage;
|
||||
private boolean lora;
|
||||
private TrainAnalysis trainAnalysis;
|
||||
private String finished_at;
|
||||
private String created_at;
|
||||
private String train_publisher;
|
||||
private String train_log;
|
||||
private String error_message;
|
||||
private String lora;
|
||||
private TrainAnalysis train_analysis;
|
||||
private String suffix;
|
||||
private int modelMaxLength;
|
||||
private int trainBatchSize;
|
||||
private String learningRate;
|
||||
private String fileUrl;
|
||||
private String fileId;
|
||||
private String startTrainTime;
|
||||
private int procPerNode;
|
||||
private int evalBatchSize;
|
||||
private int accumulationSteps;
|
||||
private int model_max_length;
|
||||
private int train_batch_size;
|
||||
private String learning_rate;
|
||||
private String file_url;
|
||||
private String file_id;
|
||||
private String start_train_time;
|
||||
private int proc_per_node;
|
||||
private int eval_batch_size;
|
||||
private int accumulation_steps;
|
||||
private String scenario;
|
||||
private Diagnosis diagnosis;
|
||||
private String deleted_at;
|
||||
private String paas_job_name;
|
||||
private String weight_decay;
|
||||
private String data_path;
|
||||
private String progress;
|
||||
private String warmup_ratio;
|
||||
private String output_dir;
|
||||
private String progress_epochs;
|
||||
private String channel_id;
|
||||
private String logging_steps;
|
||||
private String script_file;
|
||||
private String progress_loss;
|
||||
private String template_id;
|
||||
private String master_port;
|
||||
private String progress_learning_rate;
|
||||
private String eval_steps;
|
||||
private String tenant_id;
|
||||
private String save_steps;
|
||||
private String base_model_path;
|
||||
private String train_script;
|
||||
private String updated_at;
|
||||
private String save_total_limit;
|
||||
private String validation_file;
|
||||
|
||||
|
||||
|
||||
|
||||
// Nested classes for TrainAnalysis and Diagnosis
|
||||
|
@ -4,7 +4,9 @@ import cn.hutool.json.JSONObject;
|
||||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.dataset.dto.DataJsonTemplate;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.modelservice.ModelServiceDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.servername.ServerNameDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.modelservice.ModelServiceMapper;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.ModelDeployConstantEnum;
|
||||
import cn.iocoder.yudao.module.llm.handler.AigcCustomDateTimeDeserializer;
|
||||
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
|
||||
@ -44,6 +46,9 @@ public class ModelServiceTaskSyncService {
|
||||
@Resource
|
||||
private ModelServiceMapper modelServiceMapper;
|
||||
|
||||
@Resource
|
||||
ServerNameMapper serverNameMapper;
|
||||
|
||||
|
||||
|
||||
@Scheduled(cron = "0 */2 * * * ?")
|
||||
@ -63,7 +68,9 @@ public class ModelServiceTaskSyncService {
|
||||
// 使用 TypeReference 解析 JSON 字符串为 List<String>
|
||||
try {
|
||||
String query = "?filter={\"id\":"+jobid+"}";
|
||||
String res = trainHttpService.modelTableQuery(new HashMap<>(), "model_deploy",query);
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(modelServiceDO.getGpuType());
|
||||
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
|
||||
String res = trainHttpService.modelTableQuery(new HashMap<>(), hostUrl,"model_deploy",query);
|
||||
log.info("获取 aigc model_deploy 表数据 info {}",res);
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.registerModule(new JavaTimeModule());
|
||||
|
Loading…
x
Reference in New Issue
Block a user