This commit is contained in:
leon 2025-02-25 14:55:41 +08:00
commit 30be835063
3 changed files with 84 additions and 31 deletions

View File

@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeploySaveReq;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
import cn.iocoder.yudao.module.llm.service.http.vo.ModelDeployRespVO;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.esotericsoftware.minlog.Log;
@ -19,6 +20,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@ -29,6 +31,7 @@ import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
public class BaseModelTaskService {
@Resource
@ -44,44 +47,90 @@ public class BaseModelTaskService {
// 减少维护 先注释掉
@Scheduled(cron ="0 0/1 * * * ?")
public void synchronous() throws JsonProcessingException {
List<BaseModelDO> baseModelList = baseModelService.getBaseModelList();
for (BaseModelDO baseModelDO : baseModelList) {
Long modelId = baseModelDO.getModelId();
Long gpuId = baseModelDO.getGpuId();
ServerNameDO serverNameDO1 = serverNameMapper.selectById(gpuId);
String query = "?filter={\"id\":" + modelId + "}";
String res = trainHttpService.modelTableQuery(new HashMap<>(), serverNameDO1.getHost(),"model_deploy", query);
try {
log.info("开始同步基础模型信息...");
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JavaTimeModule());
SimpleModule module = new SimpleModule();
module.addDeserializer(LocalDateTime.class, new AigcCustomDateTimeDeserializer());
mapper.registerModule(module);
// 获取gpuhost主机
ServerNameDO serverNameDO = serverNameMapper.selectById(baseModelDO.getGpuId());
List<AigcModelDeployVO> aigcModelDeployVOS = mapper.readValue(res,new TypeReference<List<AigcModelDeployVO>>() {});
if (!aigcModelDeployVOS.isEmpty()) {
AigcModelDeployVO latestRecord = aigcModelDeployVOS.get(0);
String status = latestRecord.getStatus();
if(status.equals("stop")){
// 获取所有基础模型列表
log.debug("正在查询所有基础模型列表...");
List<BaseModelDO> baseModelList = baseModelService.getBaseModelList();
log.info("成功查询到 {} 个基础模型。", baseModelList.size());
AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq(baseModelDO.getAigcModelName(),
"gpu");
ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy(new HashMap<>(),serverNameDO.getHost(), aigcModelDeploySaveReq);
if (!modelDeployRespVO.getMessage().equals("error")) {
BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
baseModelSaveReqVO.setId(baseModelDO.getId());
baseModelSaveReqVO.setModelId(modelDeployRespVO.getId());
baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX);
baseModelService.updateBaseModel(new BaseModelSaveReqVO());
// 遍历每个基础模型
for (BaseModelDO baseModelDO : baseModelList) {
Long modelId = baseModelDO.getModelId();
Long gpuId = baseModelDO.getGpuId();
log.debug("正在处理基础模型模型ID: {}, GPU ID: {}", modelId, gpuId);
// 查询 GPU 服务器信息
log.debug("正在查询 GPU 服务器信息GPU ID: {}", gpuId);
ServerNameDO serverName = serverNameMapper.selectById(gpuId);
if (serverName == null) {
log.error("未找到 GPU 服务器信息GPU ID: {}", gpuId);
continue;
}
log.debug("GPU 服务器信息查询成功。主机地址: {}", serverName.getHost());
// 构建查询参数并查询模型部署信息
String query = "?filter={\"id\":" + modelId + "}";
log.debug("正在查询模型部署信息,查询参数: {}", query);
String res = trainHttpService.modelTableQuery(new HashMap<>(), serverName.getHost(), "model_deploy", query);
log.debug("模型部署信息查询成功。响应内容: {}", res);
// 解析响应内容
log.debug("正在解析模型部署信息...");
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JavaTimeModule());
SimpleModule module = new SimpleModule();
module.addDeserializer(LocalDateTime.class, new AigcCustomDateTimeDeserializer());
mapper.registerModule(module);
List<AigcModelDeployVO> aigcModelDeploys = mapper.readValue(res, new TypeReference<List<AigcModelDeployVO>>() {});
log.debug("模型部署信息解析完成。记录数量: {}", aigcModelDeploys.size());
if (!aigcModelDeploys.isEmpty()) {
AigcModelDeployVO latestRecord = aigcModelDeploys.get(0);
String status = latestRecord.getStatus();
log.debug("最新模型部署记录状态: {}", status);
// 如果模型状态为 "stop"则重新部署
if ("stop".equals(status)) {
log.info("模型状态为 'stop',正在重新部署模型...");
// 构建模型部署请求
AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq(
baseModelDO.getAigcModelName(), "gpu");
log.debug("模型部署请求参数: {}", JSON.toJSONString(aigcModelDeploySaveReq));
// 发起模型部署请求
ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy(
new HashMap<>(), serverName.getHost(), aigcModelDeploySaveReq);
log.debug("模型部署请求完成。响应内容: {}", JSON.toJSONString(modelDeployRespVO));
// 更新基础模型信息
if (!"error".equals(modelDeployRespVO.getMessage())) {
log.info("模型部署成功。正在更新基础模型信息...");
BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
baseModelSaveReqVO.setId(baseModelDO.getId());
baseModelSaveReqVO.setModelId(modelDeployRespVO.getId());
baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX);
baseModelService.updateBaseModel(baseModelSaveReqVO);
log.info("基础模型信息更新完成。模型ID: {}", baseModelDO.getId());
} else {
log.error("模型部署失败。模型ID: {}", baseModelDO.getId());
}
}
} else {
log.warn("未找到模型部署记录。模型ID: {}", modelId);
}
}
log.info("基础模型信息同步完成。");
} catch (Exception e) {
log.error("同步基础模型信息时发生异常。", e);
throw e;
}
}
// @Scheduled(cron ="0 0/1 * * * ?")
public void updateBaseModel() {
Log.info("定时任务启动");

View File

@ -263,6 +263,10 @@ public class FineTuningTaskServiceImpl implements FineTuningTaskService {
.map(FineTuningTaskDO::getGpuType)
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(gpuTypeIds)){
return respVOS;
}
List<ServerNameDO> serverNameDOS = serverNameMapper.selectList(new LambdaQueryWrapper<ServerNameDO>()
.in(ServerNameDO::getId, gpuTypeIds));
Map<Long, ServerNameDO> longServerNameDOMap = cn.iocoder.yudao.framework.common.util.collection.CollectionUtils

View File

@ -157,8 +157,8 @@ public class TrainHttpService {
try {
// 记录请求信息
log.info("开始创建微调任务请求URL: {}", url + llmBackendProperties.getFinetuningCreate());
log.debug("请求头: {}", headers);
log.debug("请求体: {}", JSON.toJSONString(req));
log.info("请求头: {}", headers);
log.info("请求体: {}", JSON.toJSONString(req));
// 发起 HTTP 请求
log.debug("正在发起 HTTP POST 请求...");