diff --git a/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/FinetuningTaskStatusEnum.java b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/FinetuningTaskStatusEnum.java new file mode 100644 index 000000000..ec42360e2 --- /dev/null +++ b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/FinetuningTaskStatusEnum.java @@ -0,0 +1,38 @@ +package cn.iocoder.yudao.module.llm.enums; + +import cn.iocoder.yudao.framework.common.core.IntArrayValuable; +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Arrays; + +/** + * 用户类型的枚举值 + * + * @author 张陶 + */ +@Getter +@AllArgsConstructor +public enum FinetuningTaskStatusEnum implements IntArrayValuable { + + /** 等待中、等待中 */ + WAITING(4), + /** 进行中、训练中 */ + TRAINING(1), + /** 部署完成、训练完成 */ + FINISHED(2), + /** 已停止、已取消 */ + CANCELLED(0); + + public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(FinetuningTaskStatusEnum::getStatus).toArray(); + + /** + * 用户类型 + */ + private final Integer status; + + @Override + public int[] array() { + return ARRAYS; + } +} diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/finetuningtask/FineTuningTaskDO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/finetuningtask/FineTuningTaskDO.java index a81012d64..9e3684299 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/finetuningtask/FineTuningTaskDO.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/dal/dataobject/finetuningtask/FineTuningTaskDO.java @@ -140,4 +140,5 @@ public class FineTuningTaskDO extends BaseDO { */ private Long baseModelId; + private String jobId; } \ No newline at end of file diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncFineTuningTaskService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncFineTuningTaskService.java index aa4041ffb..c83d11100 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncFineTuningTaskService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/async/AsyncFineTuningTaskService.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.llm.service.async; import cn.iocoder.yudao.module.llm.dal.dataobject.finetuningtask.FineTuningTaskDO; import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper; +import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateReqVO; import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningCreateRespVO; @@ -26,29 +27,56 @@ public class AsyncFineTuningTaskService { @Async public void createTuning(FineTuningTaskDO fineTuningTask) { try { - AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO(); + AigcFineTuningCreateReqVO req = getAigcFineTuningCreateReqVO(fineTuningTask); AigcFineTuningCreateRespVO resp = trainHttpService.finetuningCreate(new HashMap<>(), req); + FineTuningTaskDO updateObj = new FineTuningTaskDO(); + updateObj.setId(fineTuningTask.getId()); + if (resp != null) { + updateObj.setJobId(resp.getJobId()); + updateObj.setStatus(FinetuningTaskStatusEnum.WAITING.getStatus()); + } else { + updateObj.setStatus(FinetuningTaskStatusEnum.CANCELLED.getStatus()); + } + fineTuningTaskMapper.updateById(updateObj); } catch (Exception e){ e.printStackTrace(); }; } + private static AigcFineTuningCreateReqVO getAigcFineTuningCreateReqVO(FineTuningTaskDO fineTuningTask) { + AigcFineTuningCreateReqVO req = new AigcFineTuningCreateReqVO(); + req.setBaseModel(fineTuningTask.getBaseModel()); + req.setTrainEpoch(fineTuningTask.getEpoch()); + req.setSuffix(fineTuningTask.getTaskName()); + req.setRemark(fineTuningTask.getTaskIntro()); + req.setTrainBatchSize(fineTuningTask.getBatchSize()); + req.setEvalBatchSize(fineTuningTask.getBatchSize()); + req.setAccumulationSteps(fineTuningTask.getGradientAccumulation()); + req.setProcPerNode(fineTuningTask.getGpuCount()); + req.setLearningRate(fineTuningTask.getLearningRate()); + req.setModelMaxLength(fineTuningTask.getCutoffLen()); + req.setLora(fineTuningTask.getLorayRank() != null); + return req; + } + //调优任务部署 @Async - public void startFineTuningTask() { + public void startFineTuningTask(FineTuningTaskDO fineTuningTask) { try { - Thread.sleep(30000); + createTuning(fineTuningTask); }catch(Exception e){ + e.printStackTrace(); }; } //调优任务停止 @Async - public void stopFineTuningTask() { + public void stopFineTuningTask(FineTuningTaskDO fineTuningTask) { try { - Thread.sleep(30000); + trainHttpService.finetuningDelete(new HashMap<>(), fineTuningTask.getJobId()); }catch(Exception e){ + e.printStackTrace(); }; } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskServiceImpl.java index a8f7db577..3ee19b09e 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/finetuningtask/FineTuningTaskServiceImpl.java @@ -78,10 +78,12 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService { // fineTuningTaskMapper.deleteById(id); } - private void validateFineTuningTaskExists(Long id) { - if (fineTuningTaskMapper.selectById(id) == null) { + private FineTuningTaskDO validateFineTuningTaskExists(Long id) { + FineTuningTaskDO fineTuningTaskDO = fineTuningTaskMapper.selectById(id); + if (fineTuningTaskDO == null) { throw exception(FINE_TUNING_TASK_NOT_EXISTS); } + return fineTuningTaskDO; } @Override @@ -135,9 +137,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService { @Override public void startFineTuningTask(Long id) { + FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id); fineTuningTaskMapper.stopStartTask(id,1); //todo 调用模型服务,开启调优任务 - asyncFineTuningTaskService.startFineTuningTask(); + asyncFineTuningTaskService.startFineTuningTask(fineTuningTaskDO); fineTuningTaskMapper.stopStartTask(id,2); @@ -151,9 +154,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService { @Override public void stopFineTuningTask(Long id) { + FineTuningTaskDO fineTuningTaskDO = validateFineTuningTaskExists(id); fineTuningTaskMapper.stopStartTask(id,1); //todo 调用模型服务,停止调优任务 - asyncFineTuningTaskService.stopFineTuningTask(); + asyncFineTuningTaskService.stopFineTuningTask(fineTuningTaskDO); fineTuningTaskMapper.stopStartTask(id,0);