From 24ddfe0264b19015908ffbff530a5e22b52ec7e2 Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Fri, 14 Mar 2025 18:06:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(llm):=20=E6=B7=BB=E5=8A=A0=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=90=8D=E7=A7=B0=E4=B8=8E=E5=BE=AE=E8=B0=83=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E6=9B=B4=E6=96=B0=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 replaceActiveGroups 方法,用于处理模型名称中的 active 分组 - 在定时任务中调用该方法,更新基础模型的微调状态 - 优化了代码结构,提高了可读性和可维护性 --- .../basemodel/BaseModelTaskService.java | 184 ++++++++++++------ 1 file changed, 129 insertions(+), 55 deletions(-) diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java index e3e495663..0af1f2f5a 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/basemodel/BaseModelTaskService.java @@ -15,9 +15,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes; import cn.iocoder.yudao.module.llm.service.basemodel.vo.PedestalModelVO; import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService; import cn.iocoder.yudao.module.llm.service.http.TrainHttpService; -import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeploySaveReq; import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO; -import cn.iocoder.yudao.module.llm.service.http.vo.ModelDeployRespVO; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; @@ -28,6 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; @@ -36,6 +35,8 @@ import java.time.LocalDateTime; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; @Component @@ -57,10 +58,12 @@ public class BaseModelTaskService { private FineTuningTaskHttpService fineTuningTaskHttpService; @Resource private LLMBackendProperties llmBackendProperties; + @Value("${spring.profiles.active}") + private String active; // 减少维护 先注释掉 @Scheduled(cron = "0 0/1 * * * ?") - public void synchronous() throws JsonProcessingException { + public void synchronous () throws JsonProcessingException { try { log.info("开始同步基础模型信息..."); @@ -106,33 +109,33 @@ public class BaseModelTaskService { String status = latestRecord.getStatus(); log.info("最新模型部署记录状态: {}", status); -// // 如果模型状态为 "stop",则重新部署 -// if ("stop".equals(status)) { -// log.info("模型状态为 'stop',正在重新部署模型..."); -// -// // 构建模型部署请求 -// AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq( -// baseModelDO.getAigcModelName(), "gpu"); -// log.info("模型部署请求参数: {}", JSON.toJSONString(aigcModelDeploySaveReq)); -// -// // 发起模型部署请求 -// ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy( -// new HashMap<>(), serverName.getHost(), aigcModelDeploySaveReq); -// log.info("模型部署请求完成。响应内容: {}", JSON.toJSONString(modelDeployRespVO)); -// -// // 更新基础模型信息 -// if (!"error".equals(modelDeployRespVO.getMessage())) { -// log.info("模型部署成功。正在更新基础模型信息..."); -// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO(); -// baseModelSaveReqVO.setId(baseModelDO.getId()); -// baseModelSaveReqVO.setModelId(modelDeployRespVO.getId()); -// baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX); -// baseModelService.updateBaseModel(baseModelSaveReqVO); -// log.info("基础模型信息更新完成。模型ID: {}", baseModelDO.getId()); -// } else { -// log.error("模型部署失败。模型ID: {}", baseModelDO.getId()); -// } -// } + // // 如果模型状态为 "stop",则重新部署 + // if ("stop".equals(status)) { + // log.info("模型状态为 'stop',正在重新部署模型..."); + // + // // 构建模型部署请求 + // AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq( + // baseModelDO.getAigcModelName(), "gpu"); + // log.info("模型部署请求参数: {}", JSON.toJSONString(aigcModelDeploySaveReq)); + // + // // 发起模型部署请求 + // ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy( + // new HashMap<>(), serverName.getHost(), aigcModelDeploySaveReq); + // log.info("模型部署请求完成。响应内容: {}", JSON.toJSONString(modelDeployRespVO)); + // + // // 更新基础模型信息 + // if (!"error".equals(modelDeployRespVO.getMessage())) { + // log.info("模型部署成功。正在更新基础模型信息..."); + // BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO(); + // baseModelSaveReqVO.setId(baseModelDO.getId()); + // baseModelSaveReqVO.setModelId(modelDeployRespVO.getId()); + // baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX); + // baseModelService.updateBaseModel(baseModelSaveReqVO); + // log.info("基础模型信息更新完成。模型ID: {}", baseModelDO.getId()); + // } else { + // log.error("模型部署失败。模型ID: {}", baseModelDO.getId()); + // } + // } } else { log.warn("未找到模型部署记录。模型ID: {}", modelId); } @@ -147,7 +150,7 @@ public class BaseModelTaskService { // @Scheduled(cron ="0 0/1 * * * ?") - public void updateBaseModel() { + public void updateBaseModel () { Log.info("定时任务启动"); String resStr = trainHttpService.modelsList(""); Log.info("获取aicg模型列表返回数据内容{}", resStr); @@ -200,7 +203,7 @@ public class BaseModelTaskService { } @Scheduled(cron = "0 0/5 * * * ?") - public void updateTheBaseModelState() { + public void updateTheBaseModelState () { try { // 获取所有基础模型列表 List baseModelList = baseModelService.getAllModels(); @@ -225,7 +228,7 @@ public class BaseModelTaskService { List differentModels = baseModelList.stream() .filter(baseModel -> !remoteModelNames.contains(baseModel.getModelName())) .collect(Collectors.toList()); - if (differentModels.size()>0) { + if (differentModels.size() > 0) { baseModelService.deletebyIds(differentModels); } @@ -233,10 +236,13 @@ public class BaseModelTaskService { List uniqueRemoteModelNames = remoteModelNames.stream() .filter(remoteModelName -> !differentModelNames.contains(remoteModelName)) .collect(Collectors.toList()); - if (uniqueRemoteModelNames.size()>0) { + if (uniqueRemoteModelNames.size() > 0) { for (String remoteModelName : uniqueRemoteModelNames) { BaseModelDO baseModelDO = new BaseModelDO(); baseModelDO.setModelName(remoteModelName); + // 模型类型 + // 微调状态 + baseModelDO.setIsFinetuned(replaceActiveGroups(remoteModelName, active)); baseModelMapper.insert(baseModelDO); } } @@ -246,8 +252,8 @@ public class BaseModelTaskService { } -// @Scheduled(cron = "0 0/1 * * * ?") - public void refreshTheModelService() { + // @Scheduled(cron = "0 0/1 * * * ?") + public void refreshTheModelService () { try { // 获取所有基础模型列表 List modelServiceDOS = modelServiceMapper.selectList(); @@ -269,20 +275,20 @@ public class BaseModelTaskService { List differentModels = modelServiceDOS.stream() .filter(baseModel -> !remoteModelNames.contains(baseModel.getBaseModelName()) && (baseModel.getStatus() == 4 || baseModel.getStatus() == 2)) .collect(Collectors.toList()); - for (ModelServiceDO baseModel : differentModels){ + for (ModelServiceDO baseModel : differentModels) { baseModel.setStatus(3); } - if (differentModels.size()>0) { + if (differentModels.size() > 0) { modelServiceMapper.updateById(differentModels); } for (String name : remoteModelNames) { -// JSONObject jsonObject = (JSONObject) object; -// String string = JSON.toJSONString(jsonObject); -// PedestalModelVO pedestalModelVo = JSON.parseObject(string, PedestalModelVO.class); + // JSONObject jsonObject = (JSONObject) object; + // String string = JSON.toJSONString(jsonObject); + // PedestalModelVO pedestalModelVo = JSON.parseObject(string, PedestalModelVO.class); List collect = modelListRes.stream() .filter(pedestalModelVO -> pedestalModelVO.getDeploymentName() - .equals(name) && "running" - .equals(pedestalModelVO.getStatus())) + .equals(name) && "running" + .equals(pedestalModelVO.getStatus())) .collect(Collectors.toList()); if (collect.size() > 0) { PedestalModelVO pedestalModelVo = collect.get(0); @@ -303,23 +309,23 @@ public class BaseModelTaskService { String string1 = pedestalModelVo.getHost() + "/v1/chat/completions"; localModel.setStatus(4); localModel.setModelUrl(string1); -// localModel.setApiUrl(string1); + // localModel.setApiUrl(string1); localModel.setJobId((long) pedestalModelVo.getId()); modelServiceMapper.updateById(localModel); log.info("模型 {} 状态为 running,无需更新", pedestalModelVo.getDeploymentName()); } } -// else { -// //新增基座模型 -// if ("running".equals(pedestalModelVo.getStatus())){ -// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO(); -// baseModelSaveReqVO.setModelName(pedestalModelVo.getDeploymentName()); -// baseModelSaveReqVO.setIsActive(1); -// baseModelSaveReqVO.setAigcModelName(pedestalModelVo.getDeploymentName()); -// baseModelSaveReqVO.setChatUrl(pedestalModelVo.getHost() + "/v1/chat/completions"); -// baseModelService.createBaseModel(baseModelSaveReqVO); -// } -// } + // else { + // //新增基座模型 + // if ("running".equals(pedestalModelVo.getStatus())){ + // BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO(); + // baseModelSaveReqVO.setModelName(pedestalModelVo.getDeploymentName()); + // baseModelSaveReqVO.setIsActive(1); + // baseModelSaveReqVO.setAigcModelName(pedestalModelVo.getDeploymentName()); + // baseModelSaveReqVO.setChatUrl(pedestalModelVo.getHost() + "/v1/chat/completions"); + // baseModelService.createBaseModel(baseModelSaveReqVO); + // } + // } } else { List localModels = modelServiceDOS.stream() .filter(baseModel -> name.equals(baseModel.getBaseModelName())) @@ -338,4 +344,72 @@ public class BaseModelTaskService { } } + + private boolean replaceActiveGroups (String remoteModelName, String active) { + // 定义正则表达式,匹配 -active-数字-数字 这样的模式 + String regex = "(-" + active + "-\\d+-\\d+)"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(remoteModelName); + + // 用于存储替换后的字符串 + StringBuilder result = new StringBuilder(remoteModelName); + int groupCount = 0; + + // 查找所有匹配的分组并替换 + while (matcher.find()) { + groupCount++; + // 替换匹配的分组为 ${groupCount} + result.replace(matcher.start(), matcher.end(), "${" + groupCount + "}"); + // 重置 Matcher 的位置,因为字符串已被修改 + matcher.reset(result.toString()); + } + + // 如果找到匹配的分组,输出结果并返回 true + if (groupCount > 0) { + log.info("Modified string: {}", result); + return true; + } else { + log.info("No matching groups found."); + return false; + } + } + + + // public static void main (String[] args) { + // String s="Qwen2.5-0.5B-Instruct-dev-64-1"; + // + // String active="dev"; + // // 找到第一个 active 后的两个下划线,如-64-1,然后将 -dev-64-1 分为一组, + // // 然后将原始字符串变为 "Qwen2.5-0.5B-Instruct-${1}",有几个分组就命名几个 如Qwen2.5-0.5B-Instruct-dev-64-1-dev-64-1 变为Qwen2.5-0.5B-Instruct-${1}-${2} + // } + + public static void main (String[] args) { + String s = "Qwen2.5-0.5B-Instruct"; + String active = "dev"; + + // 定义正则表达式,匹配 -dev-64-1 这样的模式 + String regex = "(-" + active + "-\\d+-\\d+)"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(s); + + // 用于存储替换后的字符串 + StringBuilder result = new StringBuilder(s); + int groupCount = 0; + + // 查找所有匹配的分组并替换 + while (matcher.find()) { + groupCount++; + // 替换匹配的分组为 ${groupCount} + result.replace(matcher.start(), matcher.end(), "${" + groupCount + "}"); + // 重置 Matcher 的位置,因为字符串已被修改 + matcher.reset(result.toString()); + } + + // 输出结果 + if (groupCount > 0) { + System.out.println("Modified string: " + result); + } else { + System.out.println("No matching groups found."); + } + } }