feat(module-llm):增加根据 GPU 类型获取模型下载列表功能

- 新增 getHostByType 方法,通过 GPU 类型获取对应的主机地址
- 修改 getFileList 方法,增加 GPU 类型参数,使用 getHostByType 获取主机地址
- 更新 getDownLoadList 方法,传入 GPU 类型到 getFileList 方法
This commit is contained in:
Liuyang 2025-02-18 11:22:04 +08:00
parent 2766385bcb
commit 9bbe8fa8ff
4 changed files with 39 additions and 23 deletions

View File

@ -4,20 +4,18 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatApiReqVO
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.finetuningtask.vo.FineTuningTaskRespVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.application.ApplicationDO;
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
import cn.iocoder.yudao.module.llm.service.finetuningtask.FineTuningTaskService;
import cn.iocoder.yudao.module.llm.service.http.ModelService;
import org.springframework.context.annotation.Lazy;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import org.springframework.validation.annotation.Validated;
import org.springframework.security.access.prepost.PreAuthorize;
import io.swagger.v3.oas.annotations.tags.Tag;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import javax.annotation.security.PermitAll;
import javax.validation.constraints.*;
import javax.validation.*;
import javax.servlet.http.*;
import java.util.*;

View File

@ -42,9 +42,8 @@ public class ModelServiceDO extends BaseDO {
*/
private Long checkPoint;
/**
* GPU使用字典llm_gpu_type
*
* 枚举 {@link TODO llm_gpu_type 对应的类}
* GPU type; holds the ID of the ServerName record used to resolve the backend host.
*/
private Long gpuType;
/**

View File

@ -4,11 +4,10 @@ import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.servername.ServerNameService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
@ -33,10 +32,13 @@ public class ModelService {
@Resource
private TrainHttpService trainHttpService;
@Resource
private ServerNameService serverNameService;
/**
* 获取模型列表
*/
public ModelListRespVO modelList(){
public ModelListRespVO modelList () {
String result = HttpUtils.get(llmBackendProperties.getModelCompletions(), null);
if (StringUtils.isBlank(result)) {
return null;
@ -57,20 +59,21 @@ public class ModelService {
/**
* 对话聊天
*
* @param req
* @return
*/
public ModelCompletionsRespVO modelCompletions(String url,ModelCompletionsReqVO req) {
public ModelCompletionsRespVO modelCompletions (String url, ModelCompletionsReqVO req) {
if (StringUtils.isBlank(req.getModel())) {
req.setModel(DEFAULT_MODEL_ID);
}
log.info("request: {}", req);
String result;
if (StringUtils.isBlank(url)){
if (StringUtils.isBlank(url)) {
log.info("url: {}", llmBackendProperties.getModelCompletions());
result = HttpUtils.post(llmBackendProperties.getModelCompletions(), null, JSON.toJSONString(req));
}else {
} else {
log.info("url: {}", url);
result = HttpUtils.post(url, null, JSON.toJSONString(req));
}
@ -90,15 +93,28 @@ public class ModelService {
}
return null;
} catch (Exception e) {
throw new RuntimeException(e);
throw new RuntimeException(e);
}
}
/**
 * Resolves the host address of the server associated with a GPU type.
 *
 * <p>The {@code gpuType} value is handed to {@code serverNameService.getServerName} as the
 * lookup key (the model service's gpuType is described as the ServerName table ID).
 *
 * @param gpuType GPU type stored on the model service; used as the server lookup key
 * @return the value of {@code getHost()} on the server record found for {@code gpuType}
 * @throws NullPointerException if the lookup returns no server record for {@code gpuType}
 */
public String getHostByType (Long gpuType) {
    // Fail with a descriptive message instead of an anonymous NPE when the lookup yields nothing.
    // The exception type is the same one the unguarded dereference would have produced.
    return java.util.Objects.requireNonNull(
            serverNameService.getServerName(gpuType),
            "No server found for gpuType " + gpuType).getHost();
}
/**
* 获取模型下载列表
*/
public List<String> getFileList(String fileName){
String url = llmBackendProperties.getModelFileList() + fileName;
public List<String> getFileList (Long gpuType, String fileName) {
String baseUrl = getHostByType(gpuType);
String url = baseUrl + llmBackendProperties.getModelFileList() + fileName;
String res = HttpUtils.get(url, null);
log.info(" getFileList:{}", res);
try {
@ -109,8 +125,8 @@ public class ModelService {
return null;
}
public ModelCompletionsRespVO modelPrivateCompletions(Map<String, String> headers,ModelCompletionsReqVO req) {
trainHttpService.login(headers);
public ModelCompletionsRespVO modelPrivateCompletions (Map<String, String> headers, ModelCompletionsReqVO req) {
trainHttpService.login(headers);
if (StringUtils.isBlank(req.getModel())) {
req.setModel(PRIVATE_MODEL_ID);
}
@ -148,18 +164,18 @@ public class ModelService {
}
}
public TextToImageRespVo textToImage(TextToImageReqVo req) {
public TextToImageRespVo textToImage (TextToImageReqVo req) {
log.info("url: {}", llmBackendProperties.getTextToImage());
log.info("request: {}", req);
String result = HttpUtils.post(llmBackendProperties.getTextToImage(),new HashMap<>(), JSON.toJSONString(req));
String result = HttpUtils.post(llmBackendProperties.getTextToImage(), new HashMap<>(), JSON.toJSONString(req));
log.info("response: {}", result);
if (StringUtils.isBlank(result)) {
return null;
}
try {
return JSON.parseObject(result, TextToImageRespVo.class);
}catch (Exception e){
log.error("text to image error : {}",e.getMessage());
} catch (Exception e) {
log.error("text to image error : {}", e.getMessage());
}
return null;
}

View File

@ -307,9 +307,11 @@ public class ModelServiceServiceImpl implements ModelServiceService {
@Override
public List<String> getDownLoadList (Long id) {
ModelServiceDO modelServiceDO = modelServiceMapper.selectById(id);
String baseModelName = modelServiceDO.getBaseModelName();
List<String> fileList = modelService.getFileList(baseModelName);
String baseModelName = modelServiceDO.getBaseModelName();
Long type = modelServiceDO.getGpuType();
List<String> fileList = modelService.getFileList(type, baseModelName);
String modelFileDownload = llmBackendProperties.getModelFileDownload();
List<String> res = new ArrayList<>();
if (fileList != null) {
@ -321,6 +323,7 @@ public class ModelServiceServiceImpl implements ModelServiceService {
res.add(modelFileDownload + baseModelName + "/" + fileName);
}
}
return res;
}