微调任务API对接

This commit is contained in:
zhangtao 2025-01-06 16:56:10 +08:00
parent d287405afe
commit 8c015ab526
4 changed files with 80 additions and 9 deletions

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.module.llm.enums;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* 用户类型的枚举值
*
* @author 张陶
*/
@Getter
@AllArgsConstructor
public enum FinetuningTaskStatusEnum implements IntArrayValuable {
/** 等待中、等待中 */
WAITING(4),
/** 进行中、训练中 */
TRAINING(1),
/** 部署完成、训练完成 */
FINISHED(2),
/** 已停止、已取消 */
CANCELLED(0);
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(FinetuningTaskStatusEnum::getStatus).toArray();
/**
* 用户类型
*/
private final Integer status;
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -140,4 +140,5 @@ public class FineTuningTaskDO extends BaseDO {
*/
private Long baseModelId;
private String jobId;
}

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async;
import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO;
import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper;
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateRespVO;
@ -26,29 +27,56 @@ public class AsyncFineTuningTaskService {
@Async
public void createTuning(FineTuningTaskDO fineTuningTask) {
try {
AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO();
AigcFineTuningCreateReqVO req = getAigcFineTuningCreateReqVO(fineTuningTask);
AigcFineTuningCreateRespVO resp = trainHttpService.finetuningCreate(new HashMap<>(), req);
FineTuningTaskDO updateObj = new FineTuningTaskDO();
updateObj.setId(fineTuningTask.getId());
if (resp != null) {
updateObj.setJobId(resp.getJobId());
updateObj.setStatus(FinetuningTaskStatusEnum.WAITING.getStatus());
} else {
updateObj.setStatus(FinetuningTaskStatusEnum.CANCELLED.getStatus());
}
fineTuningTaskMapper.updateById(updateObj);
} catch (Exception e){
e.printStackTrace();
};
}
private static AigcFineTuningCreateReqVO getAigcFineTuningCreateReqVO(FineTuningTaskDO fineTuningTask) {
AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO();
req.setBaseModel(fineTuningTask.getBaseModel());
req.setTrainEpoch(fineTuningTask.getEpoch());
req.setSuffix(fineTuningTask.getTaskName());
req.setRemark(fineTuningTask.getTaskIntro());
req.setTrainBatchSize(fineTuningTask.getBatchSize());
req.setEvalBatchSize(fineTuningTask.getBatchSize());
req.setAccumulationSteps(fineTuningTask.getGradientAccumulation());
req.setProcPerNode(fineTuningTask.getGpuCount());
req.setLearningRate(fineTuningTask.getLearningRate());
req.setModelMaxLength(fineTuningTask.getCutoffLen());
req.setLora(fineTuningTask.getLorayRank() != null);
return req;
}
//调优任务部署
@Async
public void startFineTuningTask() {
public void startFineTuningTask(FineTuningTaskDO fineTuningTask) {
try {
Thread.sleep(30000);
createTuning(fineTuningTask);
}catch(Exception e){
e.printStackTrace();
};
}
//调优任务停止
@Async
public void stopFineTuningTask() {
public void stopFineTuningTask(FineTuningTaskDO fineTuningTask) {
try {
Thread.sleep(30000);
trainHttpService.finetuningDelete(new HashMap<>(), fineTuningTask.getJobId());
}catch(Exception e){
e.printStackTrace();
};
}

View File

@ -78,10 +78,12 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
// fineTuningTaskMapper.deleteById(id);
}
private void validateFineTuningTaskExists(Long id) {
if (fineTuningTaskMapper.selectById(id) == null) {
private FineTuningTaskDO validateFineTuningTaskExists(Long id) {
FineTuningTaskDO fineTuningTaskDO = fineTuningTaskMapper.selectById(id);
if (fineTuningTaskDO == null) {
throw exception(FINE_TUNING_TASK_NOT_EXISTS);
}
return fineTuningTaskDO;
}
@Override
@ -135,9 +137,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
@Override
public void startFineTuningTask(Long id) {
FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id);
fineTuningTaskMapper.stopStartTask(id,1);
//todo 调用模型服务开启调优任务
asyncFineTuningTaskService.startFineTuningTask();
asyncFineTuningTaskService.startFineTuningTask(fineTuningTaskDO);
fineTuningTaskMapper.stopStartTask(id,2);
@ -151,9 +154,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
@Override
public void stopFineTuningTask(Long id) {
FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id);
fineTuningTaskMapper.stopStartTask(id,1);
//todo 调用模型服务停止调优任务
asyncFineTuningTaskService.stopFineTuningTask();
asyncFineTuningTaskService.stopFineTuningTask(fineTuningTaskDO);
fineTuningTaskMapper.stopStartTask(id,0);