feat(module-llm):增加根据 GPU 类型获取模型下载列表功能
- 新增 getHostByType 方法,通过 GPU 类型获取对应的主机地址 - 修改 getFileList 方法,增加 GPU 类型参数,使用 getHostByType 获取主机地址 - 更新 getDownLoadList 方法,传入 GPU 类型到 getFileList 方法
This commit is contained in:
parent
2766385bcb
commit
9bbe8fa8ff
@ -4,20 +4,18 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatApiReqVO
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatReqVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.ChatRespVO;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.finetuningtask.vo.FineTuningTaskRespVO;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.application.ApplicationDO;
|
||||
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
|
||||
import cn.iocoder.yudao.module.llm.service.finetuningtask.FineTuningTaskService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.ModelService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import javax.annotation.Resource;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
|
||||
import javax.annotation.security.PermitAll;
|
||||
import javax.validation.constraints.*;
|
||||
import javax.validation.*;
|
||||
import javax.servlet.http.*;
|
||||
import java.util.*;
|
||||
|
@ -42,9 +42,8 @@ public class ModelServiceDO extends BaseDO {
|
||||
*/
|
||||
private Long checkPoint;
|
||||
/**
|
||||
* GPU,使用字典(llm_gpu_type)
|
||||
*
|
||||
* 枚举 {@link TODO llm_gpu_type 对应的类}
|
||||
* GPU
|
||||
* <p>
|
||||
*/
|
||||
private Long gpuType;
|
||||
/**
|
||||
|
@ -4,11 +4,10 @@ import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.service.servername.ServerNameService;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
@ -33,10 +32,13 @@ public class ModelService {
|
||||
@Resource
|
||||
private TrainHttpService trainHttpService;
|
||||
|
||||
@Resource
|
||||
private ServerNameService serverNameService;
|
||||
|
||||
/**
|
||||
* 获取模型列表
|
||||
*/
|
||||
public ModelListRespVO modelList(){
|
||||
public ModelListRespVO modelList () {
|
||||
String result = HttpUtils.get(llmBackendProperties.getModelCompletions(), null);
|
||||
if (StringUtils.isBlank(result)) {
|
||||
return null;
|
||||
@ -57,20 +59,21 @@ public class ModelService {
|
||||
|
||||
/**
|
||||
* 对话聊天
|
||||
*
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
public ModelCompletionsRespVO modelCompletions(String url,ModelCompletionsReqVO req) {
|
||||
public ModelCompletionsRespVO modelCompletions (String url, ModelCompletionsReqVO req) {
|
||||
if (StringUtils.isBlank(req.getModel())) {
|
||||
req.setModel(DEFAULT_MODEL_ID);
|
||||
}
|
||||
|
||||
log.info("request: {}", req);
|
||||
String result;
|
||||
if (StringUtils.isBlank(url)){
|
||||
if (StringUtils.isBlank(url)) {
|
||||
log.info("url: {}", llmBackendProperties.getModelCompletions());
|
||||
result = HttpUtils.post(llmBackendProperties.getModelCompletions(), null, JSON.toJSONString(req));
|
||||
}else {
|
||||
} else {
|
||||
log.info("url: {}", url);
|
||||
result = HttpUtils.post(url, null, JSON.toJSONString(req));
|
||||
}
|
||||
@ -90,15 +93,28 @@ public class ModelService {
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 GPU 类型获取对应的主机地址
|
||||
* <p>ModelService GpuType 是 ServerName 表ID
|
||||
*
|
||||
* @param gpuType gpuType
|
||||
* @return host
|
||||
*/
|
||||
public String getHostByType (Long gpuType) {
|
||||
return serverNameService.getServerName(gpuType).getHost();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型下载列表
|
||||
*/
|
||||
public List<String> getFileList(String fileName){
|
||||
String url = llmBackendProperties.getModelFileList() + fileName;
|
||||
public List<String> getFileList (Long gpuType, String fileName) {
|
||||
String baseUrl = getHostByType(gpuType);
|
||||
String url = baseUrl + llmBackendProperties.getModelFileList() + fileName;
|
||||
|
||||
String res = HttpUtils.get(url, null);
|
||||
log.info(" getFileList:{}", res);
|
||||
try {
|
||||
@ -109,8 +125,8 @@ public class ModelService {
|
||||
return null;
|
||||
}
|
||||
|
||||
public ModelCompletionsRespVO modelPrivateCompletions(Map<String, String> headers,ModelCompletionsReqVO req) {
|
||||
trainHttpService.login(headers);
|
||||
public ModelCompletionsRespVO modelPrivateCompletions (Map<String, String> headers, ModelCompletionsReqVO req) {
|
||||
trainHttpService.login(headers);
|
||||
if (StringUtils.isBlank(req.getModel())) {
|
||||
req.setModel(PRIVATE_MODEL_ID);
|
||||
}
|
||||
@ -148,18 +164,18 @@ public class ModelService {
|
||||
}
|
||||
}
|
||||
|
||||
public TextToImageRespVo textToImage(TextToImageReqVo req) {
|
||||
public TextToImageRespVo textToImage (TextToImageReqVo req) {
|
||||
log.info("url: {}", llmBackendProperties.getTextToImage());
|
||||
log.info("request: {}", req);
|
||||
String result = HttpUtils.post(llmBackendProperties.getTextToImage(),new HashMap<>(), JSON.toJSONString(req));
|
||||
String result = HttpUtils.post(llmBackendProperties.getTextToImage(), new HashMap<>(), JSON.toJSONString(req));
|
||||
log.info("response: {}", result);
|
||||
if (StringUtils.isBlank(result)) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return JSON.parseObject(result, TextToImageRespVo.class);
|
||||
}catch (Exception e){
|
||||
log.error("text to image error : {}",e.getMessage());
|
||||
} catch (Exception e) {
|
||||
log.error("text to image error : {}", e.getMessage());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
@ -307,9 +307,11 @@ public class ModelServiceServiceImpl implements ModelServiceService {
|
||||
@Override
|
||||
public List<String> getDownLoadList (Long id) {
|
||||
ModelServiceDO modelServiceDO = modelServiceMapper.selectById(id);
|
||||
String baseModelName = modelServiceDO.getBaseModelName();
|
||||
|
||||
List<String> fileList = modelService.getFileList(baseModelName);
|
||||
String baseModelName = modelServiceDO.getBaseModelName();
|
||||
Long type = modelServiceDO.getGpuType();
|
||||
|
||||
List<String> fileList = modelService.getFileList(type, baseModelName);
|
||||
String modelFileDownload = llmBackendProperties.getModelFileDownload();
|
||||
List<String> res = new ArrayList<>();
|
||||
if (fileList != null) {
|
||||
@ -321,6 +323,7 @@ public class ModelServiceServiceImpl implements ModelServiceService {
|
||||
res.add(modelFileDownload + baseModelName + "/" + fileName);
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user