refactor(module-llm):
- 重构微调任务状态同步服务
- 优化了微调任务状态同步定时任务的执行逻辑
- 增加了日志记录
- 重构了检查点信息获取逻辑
This commit is contained in:
parent
962c31e540
commit
0823596f97
@ -119,5 +119,8 @@ public class LLMBackendProperties {
|
||||
|
||||
private String embedQuery;
|
||||
|
||||
/**
|
||||
* 获取调优检查点列表
|
||||
*/
|
||||
private String checkFileList;
|
||||
}
|
||||
|
@ -7,14 +7,12 @@ import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants;
|
||||
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
|
||||
import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
|
||||
import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.esotericsoftware.minlog.Log;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@ -26,7 +24,6 @@ import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.net.URLConnection;
|
||||
import java.time.Duration;
|
||||
@ -49,46 +46,78 @@ public class FineTuningTaskSyncService {
|
||||
@Resource
|
||||
private FineTuningTaskHttpService fineTuningTaskHttpService;
|
||||
|
||||
@Scheduled(cron = "0 */1 * * * ?")
|
||||
public void updateFineTuningTaskStatus() {
|
||||
Log.info("FineTuningTaskSync 定时任务启动");
|
||||
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
|
||||
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
|
||||
private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : ";
|
||||
|
||||
if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus())
|
||||
|| Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
|
||||
if (fineTuningTaskDO.getJobId() == null){
|
||||
@Scheduled(cron = "0/20 * * * * ?")
|
||||
public void updateFineTuningTaskStatus () {
|
||||
// 方法入口:微调任务状态同步定时任务
|
||||
log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG);
|
||||
|
||||
// 1. 查询所有待处理的微调任务
|
||||
log.debug("正在查询所有微调任务...");
|
||||
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
|
||||
log.info("{} [任务查询] 共获取到{}个微调任务。", FINE_TUNING_LOG, fineTuningTaskDOList.size());
|
||||
|
||||
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
|
||||
log.info("{} [任务处理] 开始处理任务DO:{}", FINE_TUNING_LOG, JSON.toJSONString(fineTuningTaskDO));
|
||||
|
||||
// 检查任务状态是否为训练中或等待中
|
||||
if (Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())) {
|
||||
log.info("{} [状态更新] 检测到需更新状态的任务ID:{},当前状态:{}",
|
||||
FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus());
|
||||
|
||||
// 3. 获取GPU服务器信息
|
||||
|
||||
String hostUrl = getHostUrl(fineTuningTaskDO);
|
||||
|
||||
// 4. 检查任务是否有 Job ID
|
||||
if (fineTuningTaskDO.getJobId() == null) {
|
||||
log.warn("{} 微调任务未关联 Job ID,任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId());
|
||||
continue;
|
||||
}
|
||||
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
|
||||
String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}";
|
||||
String respJobs = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs);
|
||||
|
||||
// 5. 构建API查询参数
|
||||
String queryJobs = "?filter={\"job_id\":\"" + fineTuningTaskDO.getJobId() + "\"}";
|
||||
log.info("{} [API请求] 发起任务查询 job_id={} 查询参数: {}", FINE_TUNING_LOG, fineTuningTaskDO.getJobId(), queryJobs);
|
||||
|
||||
// 6. 调用HTTP服务获取任务详情
|
||||
String respJobs = fineTuningTaskHttpService.modelTableQuery(
|
||||
new HashMap<>(), hostUrl, "fine_tuning_train_job", queryJobs);
|
||||
log.info("{} [API响应] 任务ID:{} 原始响应数据:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), respJobs);
|
||||
|
||||
// 7. 解析API响应数据
|
||||
AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO();
|
||||
|
||||
try {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
List<AigcFineTuningDetailRespVO> aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference<List<AigcFineTuningDetailRespVO>>() {});
|
||||
if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){
|
||||
if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){
|
||||
resp = aigcFineTuningDetailRespVO.get(0);
|
||||
}
|
||||
List<AigcFineTuningDetailRespVO> respList = mapper.readValue(
|
||||
respJobs, new TypeReference<List<AigcFineTuningDetailRespVO>>() {
|
||||
});
|
||||
|
||||
if (respList != null && !respList.isEmpty()) {
|
||||
resp = respList.get(0);
|
||||
log.info("{} [状态更新] 任务ID:{},当前状态:{},API响应状态:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus(), resp.getTrain_status());
|
||||
|
||||
}
|
||||
}catch (Exception e){
|
||||
log.error("获取微调任务状态失败{}",e.getMessage());
|
||||
} catch (Exception e) {
|
||||
log.error("{} 解析微调任务状态时发生异常。任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
if (resp == null){
|
||||
|
||||
if (resp == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
if (ObjectUtil.isAllEmpty(resp.getTrain_status())){
|
||||
if (ObjectUtil.isAllEmpty(resp.getTrain_status())) {
|
||||
continue;
|
||||
}
|
||||
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status());
|
||||
if(status != null){
|
||||
if (status != null) {
|
||||
updateObj.setId(fineTuningTaskDO.getId());
|
||||
}
|
||||
//如果微调任务由训练中变为已完成,则取模型列表中获取该任务生成的模型id
|
||||
if((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)){
|
||||
if ((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)) {
|
||||
String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
/* String resStr = trainHttpService.modelsList("");
|
||||
Log.info("获取aicg模型列表返回数据内容{}",resStr);
|
||||
@ -117,82 +146,121 @@ public class FineTuningTaskSyncService {
|
||||
}*/
|
||||
updateObj.setStatus(2);
|
||||
// 获取模型id
|
||||
String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}";
|
||||
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels);
|
||||
log.info("获取 aigc models 表数据 info {}",resModels);
|
||||
String querModels = "?filter={\"model_name\":\"" + resp.getFine_tuned_model() + "\"}";
|
||||
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl, "models", querModels);
|
||||
log.info("获取 aigc models 表数据 info {}", resModels);
|
||||
JSONArray jsonArrayModels = JSONArray.parseArray(resModels);
|
||||
|
||||
// 时长获取
|
||||
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(),LocalDateTime.now());
|
||||
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(), LocalDateTime.now());
|
||||
long minutes = duration.toMinutes();
|
||||
updateObj.setTrainDuration(String.valueOf(minutes));
|
||||
|
||||
if(jsonArrayModels.size() > 0){
|
||||
if (jsonArrayModels.size() > 0) {
|
||||
JSONObject jsonObjectModels = jsonArrayModels.getJSONObject(0);
|
||||
updateObj.setJobModelListId(jsonObjectModels.getLong("id"));
|
||||
}
|
||||
|
||||
}catch (Exception e){
|
||||
log.error(" error {}",e.getMessage());
|
||||
} catch (Exception e) {
|
||||
log.error(" error {}", e.getMessage());
|
||||
}
|
||||
|
||||
try {
|
||||
//获取检查点信息
|
||||
//todo 模型工厂的功能有问题,暂时写死
|
||||
// jobModelName = "Qwen2.5-0.5B-Instruct-147";
|
||||
String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName);
|
||||
List<String> checkpoints = new ArrayList<>();
|
||||
List<String> fileUrls = new ArrayList<>();
|
||||
List<String> fileList = JSONArray.parseArray(checkFileList,String.class);
|
||||
Map<String,JSONObject> map = new HashMap<>();
|
||||
for (String s : fileList) {
|
||||
if(s.contains("checkpoint")){
|
||||
checkpoints.add(s);
|
||||
}
|
||||
}
|
||||
if(checkpoints.size() > 0){
|
||||
for (String checkpoint : checkpoints) {
|
||||
String filePath = "/" + checkpoint + "/trainer_state.json";
|
||||
fileUrls.add(filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath);
|
||||
try {
|
||||
URL url = new URL(fileUrl);
|
||||
URLConnection urlConnection = url.openConnection();
|
||||
InputStream inputStream = urlConnection.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
String line;
|
||||
StringBuilder content = new StringBuilder();
|
||||
while ((line = reader.readLine()) != null) {
|
||||
content.append(line + "\n");
|
||||
}
|
||||
reader.close();
|
||||
map.put(checkpoint,JSONObject.parseObject(content.toString()));
|
||||
} catch (MalformedURLException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
|
||||
updateObj.setCheckPointData(JSONObject.toJSONString(map));
|
||||
|
||||
}catch (Exception e){
|
||||
log.error(" error {}",e.getMessage());
|
||||
}
|
||||
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
|
||||
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
updateObj.setStatus(status);
|
||||
}
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
} else {
|
||||
// 处理已成功但是没有获取到检查点错误
|
||||
String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
String hostUrl = getHostUrl(fineTuningTaskDO);
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
updateObj.setId(fineTuningTaskDO.getId());
|
||||
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private String getHostUrl (FineTuningTaskDO fineTuningTaskDO) {
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
|
||||
String hostUrl = (serverNameDO != null) ? serverNameDO.getHost() : "";
|
||||
log.info("{} [主机信息] 任务ID:{} GPU 类型: {} 关联主机URL:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getGpuType(), hostUrl);
|
||||
return hostUrl;
|
||||
}
|
||||
|
||||
private void getCheckPoint (FineTuningTaskDO fineTuningTaskDO, String jobModelName, String hostUrl, FineTuningTaskDO updateObj) {
|
||||
// 获取检查点信息
|
||||
try {
|
||||
// String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
log.info("正在获取检查点信息,模型名称: {}", jobModelName);
|
||||
String checkFileList = trainHttpService.getCheckFileList(hostUrl, jobModelName);
|
||||
List<String> checkpoints = new ArrayList<>();
|
||||
List<String> fileUrls = new ArrayList<>();
|
||||
Map<String, JSONObject> map = new HashMap<>();
|
||||
|
||||
// List<String> fileList = JSONArray.parseArray(checkFileList, String.class);
|
||||
// for (String s : fileList) {
|
||||
// if (s.contains("checkpoint")) {
|
||||
// checkpoints.add(s);
|
||||
// }
|
||||
// }
|
||||
// 判断是否是数组
|
||||
if (checkFileList.startsWith("[")) {
|
||||
List<String> fileList = JSONArray.parseArray(checkFileList, String.class);
|
||||
// 处理文件列表
|
||||
for (String s : fileList) {
|
||||
if (s.contains("checkpoint")) {
|
||||
checkpoints.add(s);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO 处理错误信息
|
||||
JSONObject errorObj = JSON.parseObject(checkFileList);
|
||||
String errorMsg = errorObj.getString("error");
|
||||
// 抛出异常或记录日志
|
||||
log.error("获取检查点文件列表时出错:{}", errorMsg);
|
||||
}
|
||||
|
||||
if (!checkpoints.isEmpty()) {
|
||||
log.info("找到 {} 个检查点文件。", checkpoints.size());
|
||||
for (String checkpoint : checkpoints) {
|
||||
String filePath = "/" + checkpoint + "/trainer_state.json";
|
||||
fileUrls.add(filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(hostUrl, jobModelName, filePath);
|
||||
|
||||
try {
|
||||
URL url = new URL(fileUrl);
|
||||
URLConnection urlConnection = url.openConnection();
|
||||
InputStream inputStream = urlConnection.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
StringBuilder content = new StringBuilder();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
content.append(line).append("\n");
|
||||
}
|
||||
reader.close();
|
||||
map.put(checkpoint, JSONObject.parseObject(content.toString()));
|
||||
log.info("检查点文件解析成功。检查点: {}", checkpoint);
|
||||
} catch (IOException e) {
|
||||
log.error("读取检查点文件时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.warn("未找到检查点文件,任务ID: {}", fineTuningTaskDO.getId());
|
||||
}
|
||||
|
||||
updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
|
||||
updateObj.setCheckPointData(JSONObject.toJSONString(map));
|
||||
log.info("检查点信息更新完成。");
|
||||
} catch (Exception e) {
|
||||
log.error("获取检查点信息时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user