微调任务API对接
This commit is contained in:
parent
d287405afe
commit
8c015ab526
@ -0,0 +1,38 @@
|
||||
package cn.iocoder.yudao.module.llm.enums;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* 用户类型的枚举值
|
||||
*
|
||||
* @author 张陶
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum FinetuningTaskStatusEnum implements IntArrayValuable {
|
||||
|
||||
/** 等待中、等待中 */
|
||||
WAITING(4),
|
||||
/** 进行中、训练中 */
|
||||
TRAINING(1),
|
||||
/** 部署完成、训练完成 */
|
||||
FINISHED(2),
|
||||
/** 已停止、已取消 */
|
||||
CANCELLED(0);
|
||||
|
||||
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(FinetuningTaskStatusEnum::getStatus).toArray();
|
||||
|
||||
/**
|
||||
* 用户类型
|
||||
*/
|
||||
private final Integer status;
|
||||
|
||||
@Override
|
||||
public int[] array() {
|
||||
return ARRAYS;
|
||||
}
|
||||
}
|
@ -140,4 +140,5 @@ public class FineTuningTaskDO extends BaseDO {
|
||||
*/
|
||||
private Long baseModelId;
|
||||
|
||||
private String jobId;
|
||||
}
|
@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async;
|
||||
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO;
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
|
||||
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateRespVO;
|
||||
@ -26,29 +27,56 @@ public class AsyncFineTuningTaskService {
|
||||
@Async
|
||||
public void createTuning(FineTuningTaskDO fineTuningTask) {
|
||||
try {
|
||||
AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO();
|
||||
AigcFineTuningCreateReqVO req = getAigcFineTuningCreateReqVO(fineTuningTask);
|
||||
AigcFineTuningCreateRespVO resp = trainHttpService.finetuningCreate(new HashMap<>(), req);
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
updateObj.setId(fineTuningTask.getId());
|
||||
if (resp != null) {
|
||||
updateObj.setJobId(resp.getJobId());
|
||||
updateObj.setStatus(FinetuningTaskStatusEnum.WAITING.getStatus());
|
||||
} else {
|
||||
updateObj.setStatus(FinetuningTaskStatusEnum.CANCELLED.getStatus());
|
||||
}
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
};
|
||||
}
|
||||
|
||||
private static AigcFineTuningCreateReqVO getAigcFineTuningCreateReqVO(FineTuningTaskDO fineTuningTask) {
|
||||
AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO();
|
||||
req.setBaseModel(fineTuningTask.getBaseModel());
|
||||
req.setTrainEpoch(fineTuningTask.getEpoch());
|
||||
req.setSuffix(fineTuningTask.getTaskName());
|
||||
req.setRemark(fineTuningTask.getTaskIntro());
|
||||
req.setTrainBatchSize(fineTuningTask.getBatchSize());
|
||||
req.setEvalBatchSize(fineTuningTask.getBatchSize());
|
||||
req.setAccumulationSteps(fineTuningTask.getGradientAccumulation());
|
||||
req.setProcPerNode(fineTuningTask.getGpuCount());
|
||||
req.setLearningRate(fineTuningTask.getLearningRate());
|
||||
req.setModelMaxLength(fineTuningTask.getCutoffLen());
|
||||
req.setLora(fineTuningTask.getLorayRank() != null);
|
||||
return req;
|
||||
}
|
||||
|
||||
//调优任务部署
|
||||
@Async
|
||||
public void startFineTuningTask() {
|
||||
public void startFineTuningTask(FineTuningTaskDO fineTuningTask) {
|
||||
try {
|
||||
Thread.sleep(30000);
|
||||
createTuning(fineTuningTask);
|
||||
}catch(Exception e){
|
||||
e.printStackTrace();
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
//调优任务停止
|
||||
@Async
|
||||
public void stopFineTuningTask() {
|
||||
public void stopFineTuningTask(FineTuningTaskDO fineTuningTask) {
|
||||
try {
|
||||
Thread.sleep(30000);
|
||||
trainHttpService.finetuningDelete(new HashMap<>(), fineTuningTask.getJobId());
|
||||
}catch(Exception e){
|
||||
e.printStackTrace();
|
||||
|
||||
};
|
||||
}
|
||||
|
@ -78,10 +78,12 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
|
||||
// fineTuningTaskMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private void validateFineTuningTaskExists(Long id) {
|
||||
if (fineTuningTaskMapper.selectById(id) == null) {
|
||||
private FineTuningTaskDO validateFineTuningTaskExists(Long id) {
|
||||
FineTuningTaskDO fineTuningTaskDO = fineTuningTaskMapper.selectById(id);
|
||||
if (fineTuningTaskDO == null) {
|
||||
throw exception(FINE_TUNING_TASK_NOT_EXISTS);
|
||||
}
|
||||
return fineTuningTaskDO;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -135,9 +137,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
|
||||
|
||||
@Override
|
||||
public void startFineTuningTask(Long id) {
|
||||
FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id);
|
||||
fineTuningTaskMapper.stopStartTask(id,1);
|
||||
//todo 调用模型服务,开启调优任务
|
||||
asyncFineTuningTaskService.startFineTuningTask();
|
||||
asyncFineTuningTaskService.startFineTuningTask(fineTuningTaskDO);
|
||||
|
||||
fineTuningTaskMapper.stopStartTask(id,2);
|
||||
|
||||
@ -151,9 +154,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
|
||||
|
||||
@Override
|
||||
public void stopFineTuningTask(Long id) {
|
||||
FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id);
|
||||
fineTuningTaskMapper.stopStartTask(id,1);
|
||||
//todo 调用模型服务,停止调优任务
|
||||
asyncFineTuningTaskService.stopFineTuningTask();
|
||||
asyncFineTuningTaskService.stopFineTuningTask(fineTuningTaskDO);
|
||||
|
||||
fineTuningTaskMapper.stopStartTask(id,0);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user