feat(llm): 添加模型名称与微调状态更新功能

- 新增 replaceActiveGroups 方法,用于处理模型名称中的 active 分组
- 在定时任务中调用该方法,更新基础模型的微调状态
- 优化了代码结构,提高了可读性和可维护性
This commit is contained in:
Liuyang 2025-03-14 18:06:19 +08:00
parent 60ff7ace82
commit 24ddfe0264

View File

@ -15,9 +15,7 @@ import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
import cn.iocoder.yudao.module.llm.service.basemodel.vo.PedestalModelVO;
import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService;
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeploySaveReq;
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
import cn.iocoder.yudao.module.llm.service.http.vo.ModelDeployRespVO;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
@ -28,6 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
@ -36,6 +35,8 @@ import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@Component
@ -57,10 +58,12 @@ public class BaseModelTaskService {
private FineTuningTaskHttpService fineTuningTaskHttpService;
@Resource
private LLMBackendProperties llmBackendProperties;
@Value("${spring.profiles.active}")
private String active;
// 减少维护 先注释掉
@Scheduled(cron = "0 0/1 * * * ?")
public void synchronous() throws JsonProcessingException {
public void synchronous () throws JsonProcessingException {
try {
log.info("开始同步基础模型信息...");
@ -106,33 +109,33 @@ public class BaseModelTaskService {
String status = latestRecord.getStatus();
log.info("最新模型部署记录状态: {}", status);
// // 如果模型状态为 "stop"则重新部署
// if ("stop".equals(status)) {
// log.info("模型状态为 'stop',正在重新部署模型...");
//
// // 构建模型部署请求
// AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq(
// baseModelDO.getAigcModelName(), "gpu");
// log.info("模型部署请求参数: {}", JSON.toJSONString(aigcModelDeploySaveReq));
//
// // 发起模型部署请求
// ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy(
// new HashMap<>(), serverName.getHost(), aigcModelDeploySaveReq);
// log.info("模型部署请求完成。响应内容: {}", JSON.toJSONString(modelDeployRespVO));
//
// // 更新基础模型信息
// if (!"error".equals(modelDeployRespVO.getMessage())) {
// log.info("模型部署成功。正在更新基础模型信息...");
// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
// baseModelSaveReqVO.setId(baseModelDO.getId());
// baseModelSaveReqVO.setModelId(modelDeployRespVO.getId());
// baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX);
// baseModelService.updateBaseModel(baseModelSaveReqVO);
// log.info("基础模型信息更新完成。模型ID: {}", baseModelDO.getId());
// } else {
// log.error("模型部署失败。模型ID: {}", baseModelDO.getId());
// }
// }
// // 如果模型状态为 "stop"则重新部署
// if ("stop".equals(status)) {
// log.info("模型状态为 'stop',正在重新部署模型...");
//
// // 构建模型部署请求
// AigcModelDeploySaveReq aigcModelDeploySaveReq = new AigcModelDeploySaveReq(
// baseModelDO.getAigcModelName(), "gpu");
// log.info("模型部署请求参数: {}", JSON.toJSONString(aigcModelDeploySaveReq));
//
// // 发起模型部署请求
// ModelDeployRespVO modelDeployRespVO = trainHttpService.modelDeploy(
// new HashMap<>(), serverName.getHost(), aigcModelDeploySaveReq);
// log.info("模型部署请求完成。响应内容: {}", JSON.toJSONString(modelDeployRespVO));
//
// // 更新基础模型信息
// if (!"error".equals(modelDeployRespVO.getMessage())) {
// log.info("模型部署成功。正在更新基础模型信息...");
// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
// baseModelSaveReqVO.setId(baseModelDO.getId());
// baseModelSaveReqVO.setModelId(modelDeployRespVO.getId());
// baseModelSaveReqVO.setChatUrl(modelDeployRespVO.getPort() + DEFAULT_MODEL_URL_SUFFIX);
// baseModelService.updateBaseModel(baseModelSaveReqVO);
// log.info("基础模型信息更新完成。模型ID: {}", baseModelDO.getId());
// } else {
// log.error("模型部署失败。模型ID: {}", baseModelDO.getId());
// }
// }
} else {
log.warn("未找到模型部署记录。模型ID: {}", modelId);
}
@ -147,7 +150,7 @@ public class BaseModelTaskService {
// @Scheduled(cron ="0 0/1 * * * ?")
public void updateBaseModel() {
public void updateBaseModel () {
Log.info("定时任务启动");
String resStr = trainHttpService.modelsList("");
Log.info("获取aicg模型列表返回数据内容{}", resStr);
@ -200,7 +203,7 @@ public class BaseModelTaskService {
}
@Scheduled(cron = "0 0/5 * * * ?")
public void updateTheBaseModelState() {
public void updateTheBaseModelState () {
try {
// 获取所有基础模型列表
List<BaseModelDO> baseModelList = baseModelService.getAllModels();
@ -225,7 +228,7 @@ public class BaseModelTaskService {
List<BaseModelDO> differentModels = baseModelList.stream()
.filter(baseModel -> !remoteModelNames.contains(baseModel.getModelName()))
.collect(Collectors.toList());
if (differentModels.size()>0) {
if (differentModels.size() > 0) {
baseModelService.deletebyIds(differentModels);
}
@ -233,10 +236,13 @@ public class BaseModelTaskService {
List<String> uniqueRemoteModelNames = remoteModelNames.stream()
.filter(remoteModelName -> !differentModelNames.contains(remoteModelName))
.collect(Collectors.toList());
if (uniqueRemoteModelNames.size()>0) {
if (uniqueRemoteModelNames.size() > 0) {
for (String remoteModelName : uniqueRemoteModelNames) {
BaseModelDO baseModelDO = new BaseModelDO();
baseModelDO.setModelName(remoteModelName);
// 模型类型
// 微调状态
baseModelDO.setIsFinetuned(replaceActiveGroups(remoteModelName, active));
baseModelMapper.insert(baseModelDO);
}
}
@ -246,8 +252,8 @@ public class BaseModelTaskService {
}
// @Scheduled(cron = "0 0/1 * * * ?")
public void refreshTheModelService() {
// @Scheduled(cron = "0 0/1 * * * ?")
public void refreshTheModelService () {
try {
// 获取所有基础模型列表
List<ModelServiceDO> modelServiceDOS = modelServiceMapper.selectList();
@ -269,20 +275,20 @@ public class BaseModelTaskService {
List<ModelServiceDO> differentModels = modelServiceDOS.stream()
.filter(baseModel -> !remoteModelNames.contains(baseModel.getBaseModelName()) && (baseModel.getStatus() == 4 || baseModel.getStatus() == 2))
.collect(Collectors.toList());
for (ModelServiceDO baseModel : differentModels){
for (ModelServiceDO baseModel : differentModels) {
baseModel.setStatus(3);
}
if (differentModels.size()>0) {
if (differentModels.size() > 0) {
modelServiceMapper.updateById(differentModels);
}
for (String name : remoteModelNames) {
// JSONObject jsonObject = (JSONObject) object;
// String string = JSON.toJSONString(jsonObject);
// PedestalModelVO pedestalModelVo = JSON.parseObject(string, PedestalModelVO.class);
// JSONObject jsonObject = (JSONObject) object;
// String string = JSON.toJSONString(jsonObject);
// PedestalModelVO pedestalModelVo = JSON.parseObject(string, PedestalModelVO.class);
List<PedestalModelVO> collect = modelListRes.stream()
.filter(pedestalModelVO -> pedestalModelVO.getDeploymentName()
.equals(name) && "running"
.equals(pedestalModelVO.getStatus()))
.equals(name) && "running"
.equals(pedestalModelVO.getStatus()))
.collect(Collectors.toList());
if (collect.size() > 0) {
PedestalModelVO pedestalModelVo = collect.get(0);
@ -303,23 +309,23 @@ public class BaseModelTaskService {
String string1 = pedestalModelVo.getHost() + "/v1/chat/completions";
localModel.setStatus(4);
localModel.setModelUrl(string1);
// localModel.setApiUrl(string1);
// localModel.setApiUrl(string1);
localModel.setJobId((long) pedestalModelVo.getId());
modelServiceMapper.updateById(localModel);
log.info("模型 {} 状态为 running无需更新", pedestalModelVo.getDeploymentName());
}
}
// else {
// //新增基座模型
// if ("running".equals(pedestalModelVo.getStatus())){
// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
// baseModelSaveReqVO.setModelName(pedestalModelVo.getDeploymentName());
// baseModelSaveReqVO.setIsActive(1);
// baseModelSaveReqVO.setAigcModelName(pedestalModelVo.getDeploymentName());
// baseModelSaveReqVO.setChatUrl(pedestalModelVo.getHost() + "/v1/chat/completions");
// baseModelService.createBaseModel(baseModelSaveReqVO);
// }
// }
// else {
// //新增基座模型
// if ("running".equals(pedestalModelVo.getStatus())){
// BaseModelSaveReqVO baseModelSaveReqVO = new BaseModelSaveReqVO();
// baseModelSaveReqVO.setModelName(pedestalModelVo.getDeploymentName());
// baseModelSaveReqVO.setIsActive(1);
// baseModelSaveReqVO.setAigcModelName(pedestalModelVo.getDeploymentName());
// baseModelSaveReqVO.setChatUrl(pedestalModelVo.getHost() + "/v1/chat/completions");
// baseModelService.createBaseModel(baseModelSaveReqVO);
// }
// }
} else {
List<ModelServiceDO> localModels = modelServiceDOS.stream()
.filter(baseModel -> name.equals(baseModel.getBaseModelName()))
@ -338,4 +344,72 @@ public class BaseModelTaskService {
}
}
private boolean replaceActiveGroups (String remoteModelName, String active) {
// 定义正则表达式匹配 -active-数字-数字 这样的模式
String regex = "(-" + active + "-\\d+-\\d+)";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(remoteModelName);
// 用于存储替换后的字符串
StringBuilder result = new StringBuilder(remoteModelName);
int groupCount = 0;
// 查找所有匹配的分组并替换
while (matcher.find()) {
groupCount++;
// 替换匹配的分组为 ${groupCount}
result.replace(matcher.start(), matcher.end(), "${" + groupCount + "}");
// 重置 Matcher 的位置因为字符串已被修改
matcher.reset(result.toString());
}
// 如果找到匹配的分组输出结果并返回 true
if (groupCount > 0) {
log.info("Modified string: {}", result);
return true;
} else {
log.info("No matching groups found.");
return false;
}
}
// public static void main (String[] args) {
// String s="Qwen2.5-0.5B-Instruct-dev-64-1";
//
// String active="dev";
// // 找到第一个 active 后的两个下划线-64-1然后将 -dev-64-1 分为一组
// // 然后将原始字符串变为 "Qwen2.5-0.5B-Instruct-${1}",有几个分组就命名几个 如Qwen2.5-0.5B-Instruct-dev-64-1-dev-64-1 变为Qwen2.5-0.5B-Instruct-${1}-${2}
// }
public static void main (String[] args) {
String s = "Qwen2.5-0.5B-Instruct";
String active = "dev";
// 定义正则表达式匹配 -dev-64-1 这样的模式
String regex = "(-" + active + "-\\d+-\\d+)";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(s);
// 用于存储替换后的字符串
StringBuilder result = new StringBuilder(s);
int groupCount = 0;
// 查找所有匹配的分组并替换
while (matcher.find()) {
groupCount++;
// 替换匹配的分组为 ${groupCount}
result.replace(matcher.start(), matcher.end(), "${" + groupCount + "}");
// 重置 Matcher 的位置因为字符串已被修改
matcher.reset(result.toString());
}
// 输出结果
if (groupCount > 0) {
System.out.println("Modified string: " + result);
} else {
System.out.println("No matching groups found.");
}
}
}