refactor(module-llm):

- 重构微调任务状态同步服务
- 优化了微调任务状态同步定时任务的执行逻辑
- 增加了日志记录
- 重构了检查点信息获取逻辑
This commit is contained in:
Liuyang 2025-03-01 13:17:39 +08:00
parent 962c31e540
commit 0823596f97
2 changed files with 153 additions and 82 deletions

View File

@ -119,5 +119,8 @@ public class LLMBackendProperties {
private String embedQuery;
/**
* 获取调优检查点列表
*/
private String checkFileList;
}

View File

@ -7,14 +7,12 @@ import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants;
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService;
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.esotericsoftware.minlog.Log;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
@ -26,7 +24,6 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.time.Duration;
@ -49,46 +46,78 @@ public class FineTuningTaskSyncService {
@Resource
private FineTuningTaskHttpService fineTuningTaskHttpService;
@Scheduled(cron = "0 */1 * * * ?")
public void updateFineTuningTaskStatus() {
Log.info("FineTuningTaskSync 定时任务启动");
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : ";
if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus())
|| Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
if (fineTuningTaskDO.getJobId() == null){
@Scheduled(cron = "0/20 * * * * ?")
public void updateFineTuningTaskStatus () {
// 方法入口微调任务状态同步定时任务
log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG);
// 1. 查询所有待处理的微调任务
log.debug("正在查询所有微调任务...");
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
log.info("{} [任务查询] 共获取到{}个微调任务。", FINE_TUNING_LOG, fineTuningTaskDOList.size());
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
log.info("{} [任务处理] 开始处理任务DO:{}", FINE_TUNING_LOG, JSON.toJSONString(fineTuningTaskDO));
// 检查任务状态是否为训练中或等待中
if (Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())) {
log.info("{} [状态更新] 检测到需更新状态的任务ID:{},当前状态:{}",
FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus());
// 3. 获取GPU服务器信息
String hostUrl = getHostUrl(fineTuningTaskDO);
// 4. 检查任务是否有 Job ID
if (fineTuningTaskDO.getJobId() == null) {
log.warn("{} 微调任务未关联 Job ID任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId());
continue;
}
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}";
String respJobs = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs);
// 5. 构建API查询参数
String queryJobs = "?filter={\"job_id\":\"" + fineTuningTaskDO.getJobId() + "\"}";
log.info("{} [API请求] 发起任务查询 job_id={} 查询参数: {}", FINE_TUNING_LOG, fineTuningTaskDO.getJobId(), queryJobs);
// 6. 调用HTTP服务获取任务详情
String respJobs = fineTuningTaskHttpService.modelTableQuery(
new HashMap<>(), hostUrl, "fine_tuning_train_job", queryJobs);
log.info("{} [API响应] 任务ID:{} 原始响应数据:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), respJobs);
// 7. 解析API响应数据
AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO();
try {
ObjectMapper mapper = new ObjectMapper();
List<AigcFineTuningDetailRespVO> aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference<List<AigcFineTuningDetailRespVO>>() {});
if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){
if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){
resp = aigcFineTuningDetailRespVO.get(0);
}
List<AigcFineTuningDetailRespVO> respList = mapper.readValue(
respJobs, new TypeReference<List<AigcFineTuningDetailRespVO>>() {
});
if (respList != null && !respList.isEmpty()) {
resp = respList.get(0);
log.info("{} [状态更新] 任务ID:{},当前状态:{}API响应状态:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus(), resp.getTrain_status());
}
}catch (Exception e){
log.error("获取微调任务状态失败{}",e.getMessage());
} catch (Exception e) {
log.error("{} 解析微调任务状态时发生异常。任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), e);
}
if (resp == null){
if (resp == null) {
continue;
}
FineTuningTaskDO updateObj = new FineTuningTaskDO();
if (ObjectUtil.isAllEmpty(resp.getTrain_status())){
if (ObjectUtil.isAllEmpty(resp.getTrain_status())) {
continue;
}
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status());
if(status != null){
if (status != null) {
updateObj.setId(fineTuningTaskDO.getId());
}
//如果微调任务由训练中变为已完成则取模型列表中获取该任务生成的模型id
if((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)){
if ((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)) {
String jobModelName = fineTuningTaskDO.getJobModelName();
/* String resStr = trainHttpService.modelsList("");
Log.info("获取aicg模型列表返回数据内容{}",resStr);
@ -117,82 +146,121 @@ public class FineTuningTaskSyncService {
}*/
updateObj.setStatus(2);
// 获取模型id
String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}";
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels);
log.info("获取 aigc models 表数据 info {}",resModels);
String querModels = "?filter={\"model_name\":\"" + resp.getFine_tuned_model() + "\"}";
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl, "models", querModels);
log.info("获取 aigc models 表数据 info {}", resModels);
JSONArray jsonArrayModels = JSONArray.parseArray(resModels);
// 时长获取
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(),LocalDateTime.now());
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(), LocalDateTime.now());
long minutes = duration.toMinutes();
updateObj.setTrainDuration(String.valueOf(minutes));
if(jsonArrayModels.size() > 0){
if (jsonArrayModels.size() > 0) {
JSONObject jsonObjectModels = jsonArrayModels.getJSONObject(0);
updateObj.setJobModelListId(jsonObjectModels.getLong("id"));
}
}catch (Exception e){
log.error(" error {}",e.getMessage());
} catch (Exception e) {
log.error(" error {}", e.getMessage());
}
try {
//获取检查点信息
//todo 模型工厂的功能有问题暂时写死
// jobModelName = "Qwen2.5-0.5B-Instruct-147";
String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName);
List<String> checkpoints = new ArrayList<>();
List<String> fileUrls = new ArrayList<>();
List<String> fileList = JSONArray.parseArray(checkFileList,String.class);
Map<String,JSONObject> map = new HashMap<>();
for (String s : fileList) {
if(s.contains("checkpoint")){
checkpoints.add(s);
}
}
if(checkpoints.size() > 0){
for (String checkpoint : checkpoints) {
String filePath = "/" + checkpoint + "/trainer_state.json";
fileUrls.add(filePath);
String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath);
try {
URL url = new URL(fileUrl);
URLConnection urlConnection = url.openConnection();
InputStream inputStream = urlConnection.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
String line;
StringBuilder content = new StringBuilder();
while ((line = reader.readLine()) != null) {
content.append(line + "\n");
}
reader.close();
map.put(checkpoint,JSONObject.parseObject(content.toString()));
} catch (MalformedURLException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
}
updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
updateObj.setCheckPointData(JSONObject.toJSONString(map));
}catch (Exception e){
log.error(" error {}",e.getMessage());
}
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
fineTuningTaskMapper.updateById(updateObj);
}
else {
} else {
updateObj.setStatus(status);
}
fineTuningTaskMapper.updateById(updateObj);
fineTuningTaskMapper.updateById(updateObj);
} else {
// 处理已成功但是没有获取到检查点错误
String jobModelName = fineTuningTaskDO.getJobModelName();
String hostUrl = getHostUrl(fineTuningTaskDO);
FineTuningTaskDO updateObj = new FineTuningTaskDO();
updateObj.setId(fineTuningTaskDO.getId());
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
fineTuningTaskMapper.updateById(updateObj);
}
}
}
/**
 * Resolves the host URL to use for a fine-tuning task.
 *
 * <p>The task's {@code gpuType} value is used as the key for
 * {@code serverNameMapper.selectById}. When a record is found its
 * {@code getHost()} value is returned; when no record is found an empty
 * string is returned. The resolved value is logged at INFO level.
 *
 * @param fineTuningTaskDO task whose {@code gpuType} identifies the server record
 * @return the host of the matching server record, or {@code ""} if there is none
 */
private String getHostUrl (FineTuningTaskDO fineTuningTaskDO) {
    ServerNameDO server = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());

    // Fall back to an empty host when the lookup yields no server record.
    String resolvedHost;
    if (server == null) {
        resolvedHost = "";
    } else {
        resolvedHost = server.getHost();
    }

    log.info("{} [主机信息] 任务ID:{} GPU 类型: {} 关联主机URL:{}",
            FINE_TUNING_LOG,
            fineTuningTaskDO.getId(),
            fineTuningTaskDO.getGpuType(),
            resolvedHost);
    return resolvedHost;
}
/**
 * Collects checkpoint information for a model and stores it on {@code updateObj}.
 *
 * <p>Steps:
 * <ol>
 *   <li>Fetch the model's file list through {@code trainHttpService.getCheckFileList}.</li>
 *   <li>If the response is a JSON array, keep the entries whose text contains
 *       {@code "checkpoint"}; otherwise treat the response as a JSON error object
 *       and log its {@code "error"} field.</li>
 *   <li>For every checkpoint, resolve the URL of {@code /<checkpoint>/trainer_state.json}
 *       through {@code trainHttpService.getCheckFile}, download it and parse it as JSON.</li>
 *   <li>Write the list of relative file paths and the map of parsed contents to
 *       {@code updateObj} as JSON strings.</li>
 * </ol>
 *
 * <p>This method is best effort and never throws: an I/O failure on a single
 * checkpoint file is logged and that checkpoint is left out of the data map;
 * any other failure is logged and leaves {@code updateObj} unchanged.
 *
 * @param fineTuningTaskDO task being processed; only its id is read, for log messages
 * @param jobModelName     model name passed to the file-list and file-URL lookups
 * @param hostUrl          host passed to the file-list and file-URL lookups
 * @param updateObj        object that receives the checkpoint file paths and data
 */
private void getCheckPoint (FineTuningTaskDO fineTuningTaskDO, String jobModelName, String hostUrl, FineTuningTaskDO updateObj) {
    try {
        log.info("正在获取检查点信息,模型名称: {}", jobModelName);
        String checkFileList = trainHttpService.getCheckFileList(hostUrl, jobModelName);

        // A null/blank response used to surface as a NullPointerException swallowed by the
        // catch below. The outcome is the same (updateObj untouched) but the cause is explicit.
        if (checkFileList == null || checkFileList.trim().isEmpty()) {
            log.warn("检查点文件列表响应为空。任务ID: {}", fineTuningTaskDO.getId());
            return;
        }
        String listing = checkFileList.trim();

        List<String> checkpoints = new ArrayList<>();
        if (listing.startsWith("[")) {
            // Normal case: a JSON array of file names.
            List<String> fileList = JSONArray.parseArray(listing, String.class);
            for (String fileName : fileList) {
                // JSON arrays may contain null elements; skip them instead of failing.
                if (fileName != null && fileName.contains("checkpoint")) {
                    checkpoints.add(fileName);
                }
            }
        } else {
            // Anything that is not an array is treated as an error object with an "error" field.
            JSONObject errorObj = JSON.parseObject(listing);
            String errorMsg = (errorObj != null) ? errorObj.getString("error") : listing;
            log.error("获取检查点文件列表时出错:{}", errorMsg);
        }

        List<String> fileUrls = new ArrayList<>();
        Map<String, JSONObject> map = new HashMap<>();
        if (!checkpoints.isEmpty()) {
            log.info("找到 {} 个检查点文件。", checkpoints.size());
            for (String checkpoint : checkpoints) {
                String filePath = "/" + checkpoint + "/trainer_state.json";
                // The relative path is recorded even if the download below fails.
                fileUrls.add(filePath);
                String fileUrl = trainHttpService.getCheckFile(hostUrl, jobModelName, filePath);
                try {
                    map.put(checkpoint, JSONObject.parseObject(readCheckPointFile(fileUrl)));
                    log.info("检查点文件解析成功。检查点: {}", checkpoint);
                } catch (IOException e) {
                    // Includes MalformedURLException; continue with the remaining checkpoints.
                    log.error("读取检查点文件时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
                }
            }
        } else {
            log.warn("未找到检查点文件任务ID: {}", fineTuningTaskDO.getId());
        }

        updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
        updateObj.setCheckPointData(JSONObject.toJSONString(map));
        log.info("检查点信息更新完成。");
    } catch (Exception e) {
        log.error("获取检查点信息时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
    }
}

/**
 * Downloads the resource at {@code fileUrl} and returns its text content.
 *
 * <p>The stream is decoded as UTF-8 and is always closed, including when reading
 * fails part-way. Every line is terminated with {@code "\n"} in the result.
 *
 * @param fileUrl absolute URL of the file to read
 * @return the file content with normalised line terminators
 * @throws IOException if the URL is malformed or the resource cannot be opened or read
 */
private String readCheckPointFile (String fileUrl) throws IOException {
    URLConnection urlConnection = new URL(fileUrl).openConnection();
    try (InputStream inputStream = urlConnection.getInputStream();
         BufferedReader reader = new BufferedReader(
                 new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
        StringBuilder content = new StringBuilder();
        String line;
        while ((line = reader.readLine()) != null) {
            content.append(line).append("\n");
        }
        return content.toString();
    }
}
}