Merge branch 'master' of https://codeup.aliyun.com/63736f52e9565f4348a4cd42/xnjz-ai/xhllm
This commit is contained in:
commit
d37cfba9f4
@ -40,7 +40,10 @@
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-excel</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||
</dependency>
|
||||
<!-- DB 相关 -->
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
@ -125,6 +128,10 @@
|
||||
<artifactId>poi-scratchpad</artifactId>
|
||||
<version>5.2.3</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-webflux</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
@ -2,8 +2,12 @@ package cn.iocoder.yudao.module.llm.controller.admin.conversation;
|
||||
|
||||
import cn.hutool.json.JSONArray;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import cn.iocoder.yudao.framework.common.exception.ServiceException;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import javax.annotation.Resource;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
@ -15,8 +19,10 @@ import io.swagger.v3.oas.annotations.Operation;
|
||||
import javax.validation.constraints.*;
|
||||
import javax.validation.*;
|
||||
import javax.servlet.http.*;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
@ -32,11 +38,13 @@ import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*;
|
||||
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
|
||||
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
@Tag(name = "管理后台 - 大模型对话记录")
|
||||
@RestController
|
||||
@RequestMapping("/llm/conversation")
|
||||
@Validated
|
||||
@Slf4j
|
||||
public class ConversationController {
|
||||
|
||||
@Resource
|
||||
@ -101,6 +109,31 @@ public class ConversationController {
|
||||
public CommonResult<ChatRespVO> chat(@Valid @RequestBody ChatReqVO chatReqVO) {
|
||||
return success(conversationService.chat(chatReqVO));
|
||||
}
|
||||
|
||||
/**
|
||||
* 对话推理接口,使用 SSE 进行流式响应
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @return SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
@PostMapping("/stream-chat")
|
||||
public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO,HttpServletResponse response) {
|
||||
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
|
||||
SseEmitter emitter = new SseEmitter();
|
||||
try {
|
||||
conversationService.chatStream(chatReqVO, emitter,response);
|
||||
} catch (Exception e) {
|
||||
log.error("处理对话推理请求时发生异常", e);
|
||||
try {
|
||||
emitter.completeWithError(e);
|
||||
} catch (Exception ex) {
|
||||
log.error("无法完成 SseEmitter 错误处理", ex);
|
||||
}
|
||||
}
|
||||
log.info("返回 SseEmitter 对象,准备进行流式响应");
|
||||
return emitter;
|
||||
}
|
||||
|
||||
|
||||
@PostMapping("/text-to-image")
|
||||
@Operation(summary = "文字转图片接口")
|
||||
public CommonResult<JSONArray> textToImage(@Valid @RequestBody TextToImageReqVo req) {
|
||||
|
@ -27,4 +27,8 @@ public class ChatReqVO {
|
||||
private String systemPrompt;
|
||||
@Schema(description = "知识库id")
|
||||
private Long knowledge;
|
||||
@Schema(description = "单次回复限制max_tokens")
|
||||
private Integer maxTokens;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Double temperature;
|
||||
}
|
||||
|
@ -32,4 +32,8 @@ public class DataRefluxDataSaveReqVO {
|
||||
@Schema(description = "Response")
|
||||
private String response;
|
||||
|
||||
@Schema(description = "单次回复限制max_tokens")
|
||||
private Integer maxTokens;
|
||||
@Schema(description = "随机性temperature")
|
||||
private Double temperature;
|
||||
}
|
||||
|
@ -119,5 +119,8 @@ public class LLMBackendProperties {
|
||||
|
||||
private String embedQuery;
|
||||
|
||||
/**
|
||||
* 获取调优检查点列表
|
||||
*/
|
||||
private String checkFileList;
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.llm.service.conversation;
|
||||
|
||||
import java.util.*;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.*;
|
||||
|
||||
import cn.hutool.json.JSONArray;
|
||||
@ -9,7 +10,9 @@ import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
|
||||
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
/**
|
||||
* 大模型对话记录 Service 接口
|
||||
@ -64,4 +67,11 @@ public interface ConversationService {
|
||||
* @return
|
||||
*/
|
||||
JSONArray textToImage(TextToImageReqVo req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 聊天流
|
||||
* @param chatReqVO chatReqVO
|
||||
* @param emitter emitter
|
||||
*/
|
||||
void chatStream (@Valid ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response);
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
package cn.iocoder.yudao.module.llm.service.conversation;
|
||||
|
||||
import cn.hutool.core.bean.BeanUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONArray;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.http.HttpUtils;
|
||||
@ -32,11 +34,17 @@ import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
@ -49,6 +57,7 @@ import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
|
||||
*/
|
||||
@Service
|
||||
@Validated
|
||||
@Slf4j
|
||||
public class ConversationServiceImpl implements ConversationService {
|
||||
|
||||
@Resource
|
||||
@ -130,19 +139,11 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class);
|
||||
applicationService.updateApplication(applicationSaveReqVO);
|
||||
}
|
||||
/*Long promptId = application.getPromptId();
|
||||
PromptTemplatesRespVO promptTemplates = promptTemplatesService.getPromptTemplates(promptId);
|
||||
if(promptTemplates != null){
|
||||
chatReqVO.setSystemPrompt(promptTemplates.getTemplateText());
|
||||
}*/
|
||||
|
||||
chatReqVO.setSystemPrompt(application.getPrompt());
|
||||
}
|
||||
}
|
||||
/* if (Objects.equals(1, chatReqVO.getModelType())){
|
||||
return publicModelChat(chatReqVO);
|
||||
}else {
|
||||
return privateModelChat(chatReqVO);
|
||||
}*/
|
||||
|
||||
return publicModelChat(chatReqVO);
|
||||
}
|
||||
|
||||
@ -273,9 +274,171 @@ public class ConversationServiceImpl implements ConversationService {
|
||||
dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt());
|
||||
dataRefluxDataSaveReqVO.setResponse(modelCompletionsRespVO.getAnswer());
|
||||
dataRefluxDataSaveReqVO.setSystem(modelCompletionsRespVO.getSystem());
|
||||
dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens());
|
||||
dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature());
|
||||
dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO);
|
||||
return chatRespVO;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理对话请求,以流式方式返回结果
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
|
||||
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
|
||||
// 检查系统提示信息,如果为空则尝试从应用中获取
|
||||
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
|
||||
if (chatReqVO.getApplicationId() != null) {
|
||||
log.info("系统提示信息为空,尝试从应用中获取,应用 ID: {}", chatReqVO.getApplicationId());
|
||||
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
|
||||
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
if (CollectionUtils.isEmpty(messageHistoryList)) {
|
||||
log.info("聊天历史记录为空,更新应用聊天计数");
|
||||
application.setChatCount(application.getChatCount() + 1);
|
||||
ApplicationSaveReqVO applicationSaveReqVO = BeanUtil.toBean(application, ApplicationSaveReqVO.class);
|
||||
applicationService.updateApplication(applicationSaveReqVO);
|
||||
}
|
||||
chatReqVO.setSystemPrompt(application.getPrompt());
|
||||
log.info("已更新系统提示信息为: {}", chatReqVO.getSystemPrompt());
|
||||
}
|
||||
}
|
||||
// 调用公共模型聊天流式处理方法
|
||||
publicModelChatStream(chatReqVO, emitter, response);
|
||||
}
|
||||
|
||||
/**
|
||||
* 公共模型聊天流式处理方法
|
||||
* @param chatReqVO 对话请求对象
|
||||
* @param emitter SseEmitter 对象,用于流式发送响应
|
||||
*/
|
||||
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter,HttpServletResponse response) {
|
||||
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
|
||||
// 检查 UUID 是否为空,若为空则生成一个
|
||||
if (StrUtil.isBlank(chatReqVO.getUuid())) {
|
||||
log.info("UUID 为空,生成新的 UUID");
|
||||
chatReqVO.setUuid(UUID.randomUUID().toString());
|
||||
}
|
||||
String model = null;
|
||||
String selfModelUrl = "";
|
||||
// 根据模型类型获取模型信息
|
||||
if (Objects.equals(1, chatReqVO.getModelType())) {
|
||||
log.info("使用预制模型,模型 ID: {}", chatReqVO.getModelId());
|
||||
BaseModelDO baseModelDO = baseModelService.getBaseModel(chatReqVO.getModelId());
|
||||
if (baseModelDO == null) {
|
||||
log.error("预制模型不存在,模型 ID: {}", chatReqVO.getModelId());
|
||||
try {
|
||||
emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS"));
|
||||
} catch (Exception e) {
|
||||
log.error("无法完成 SseEmitter 错误处理", e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
selfModelUrl = baseModelDO.getChatUrl();
|
||||
model = baseModelDO.getAigcModelName();
|
||||
log.info("获取到预制模型信息,模型名称: {}, 聊天 URL: {}", model, selfModelUrl);
|
||||
} else if (Objects.equals(0, chatReqVO.getModelType())) {
|
||||
log.info("使用自定义模型,模型 ID: {}", chatReqVO.getModelId());
|
||||
ModelServiceDO modelServiceDO = modelServiceMapper.selectById(chatReqVO.getModelId());
|
||||
if (modelServiceDO == null) {
|
||||
log.error("自定义模型服务不存在,模型 ID: {}", chatReqVO.getModelId());
|
||||
try {
|
||||
emitter.completeWithError(new RuntimeException("MODEL_SERVICE_NOT_EXISTS"));
|
||||
} catch (Exception e) {
|
||||
log.error("无法完成 SseEmitter 错误处理", e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
model = modelServiceDO.getBaseModelName();
|
||||
selfModelUrl = modelServiceDO.getModelUrl();
|
||||
log.info("获取到自定义模型信息,模型名称: {}, 模型 URL: {}", model, selfModelUrl);
|
||||
} else {
|
||||
log.error("无效的模型类型,模型类型: {}", chatReqVO.getModelType());
|
||||
try {
|
||||
emitter.completeWithError(new RuntimeException("BASE_MODEL_NOT_EXISTS"));
|
||||
} catch (Exception e) {
|
||||
log.error("无法完成 SseEmitter 错误处理", e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
|
||||
|
||||
// 如果知识库 ID 不为空,先调用知识库获取相关信息
|
||||
StringBuilder knowledgeBase = new StringBuilder();
|
||||
if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) {
|
||||
log.info("知识库 ID 不为空,开始查询知识库,知识库 ID: {}", chatReqVO.getKnowledge());
|
||||
LambdaQueryWrapper<KnowledgeDocumentsDO> queryWrapper = new LambdaQueryWrapper<>();
|
||||
queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge());
|
||||
List<KnowledgeDocumentsDO> fileList = knowledgeDocumentsMapper.selectList(queryWrapper);
|
||||
for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) {
|
||||
Long id = knowledgeDocumentsDO.getId();
|
||||
KnowledgeRagEmbedQueryVO knowledgeRagEmbedQueryVO = new KnowledgeRagEmbedQueryVO();
|
||||
knowledgeRagEmbedQueryVO.setFile_id(id.toString());
|
||||
knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt());
|
||||
String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO));
|
||||
com.alibaba.fastjson.JSONArray jsonArray = JSON.parseArray(result);
|
||||
if (jsonArray != null && !jsonArray.isEmpty()) {
|
||||
JSONArray jsonArray1 = (JSONArray) jsonArray.get(0);
|
||||
JSONObject jsonObject = (JSONObject) jsonArray1.get(0);
|
||||
knowledgeBase.append(jsonObject.get("page_content"));
|
||||
}
|
||||
}
|
||||
log.info("知识库查询完成,获取到的信息: {}", knowledgeBase.toString());
|
||||
}
|
||||
String mess = chatReqVO.getSystemPrompt() + "<content>" + knowledgeBase.toString() + "</content>";
|
||||
// 查询历史记录消息,并将查询出来的知识信息放入到 role = system 的消息中
|
||||
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
|
||||
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
|
||||
log.info("存在聊天历史记录,处理历史记录消息");
|
||||
for (String messageHistory : messageHistoryList) {
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
|
||||
if (modelCompletionsMessage.getRole().equals("system")) {
|
||||
modelCompletionsMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
|
||||
}
|
||||
messages.add(modelCompletionsMessage);
|
||||
}
|
||||
} else {
|
||||
log.info("不存在聊天历史记录,创建新的系统消息");
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
systemMessage.setRole("system");
|
||||
systemMessage.setContent(mess);
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
|
||||
// 创建用户消息
|
||||
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
|
||||
message.setRole("user");
|
||||
message.setContent(chatReqVO.getPrompt());
|
||||
messages.add(message);
|
||||
|
||||
// 构建模型补全请求对象
|
||||
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
|
||||
modelCompletionsReqVO.setMessages(messages);
|
||||
modelCompletionsReqVO.setModel(model);
|
||||
log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO);
|
||||
|
||||
// 调用模型服务进行流式处理
|
||||
modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, response);
|
||||
|
||||
// 将用户消息存入缓存
|
||||
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message));
|
||||
|
||||
// 保存数据回流信息
|
||||
DataRefluxDataSaveReqVO dataRefluxDataSaveReqVO = new DataRefluxDataSaveReqVO();
|
||||
dataRefluxDataSaveReqVO.setModelServiceId(chatReqVO.getModelId());
|
||||
dataRefluxDataSaveReqVO.setModelType(chatReqVO.getModelType());
|
||||
dataRefluxDataSaveReqVO.setPrompt(chatReqVO.getPrompt());
|
||||
dataRefluxDataSaveReqVO.setSystem("助手");
|
||||
dataRefluxDataSaveReqVO.setMaxTokens(chatReqVO.getMaxTokens());
|
||||
dataRefluxDataSaveReqVO.setTemperature(chatReqVO.getTemperature());
|
||||
dataRefluxDataService.saveDataRefluxData(dataRefluxDataSaveReqVO);
|
||||
log.info("数据回流信息保存完成");
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 私有模型聊天
|
||||
* @param chatReqVO
|
||||
|
@ -7,14 +7,12 @@ import cn.iocoder.yudao.module.llm.dal.mysql.finetuningtask.FineTuningTaskMapper
|
||||
import cn.iocoder.yudao.module.llm.dal.mysql.servername.ServerNameMapper;
|
||||
import cn.iocoder.yudao.module.llm.enums.FineTuningTaskStatusConstants;
|
||||
import cn.iocoder.yudao.module.llm.enums.FinetuningTaskStatusEnum;
|
||||
import cn.iocoder.yudao.module.llm.service.basemodel.vo.ModelListRes;
|
||||
import cn.iocoder.yudao.module.llm.service.http.FineTuningTaskHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcFineTuningDetailRespVO;
|
||||
import cn.iocoder.yudao.module.llm.service.http.vo.AigcModelDeployVO;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.esotericsoftware.minlog.Log;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@ -26,7 +24,6 @@ import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.net.URLConnection;
|
||||
import java.time.Duration;
|
||||
@ -49,46 +46,78 @@ public class FineTuningTaskSyncService {
|
||||
@Resource
|
||||
private FineTuningTaskHttpService fineTuningTaskHttpService;
|
||||
|
||||
@Scheduled(cron = "0 */1 * * * ?")
|
||||
public void updateFineTuningTaskStatus() {
|
||||
Log.info("FineTuningTaskSync 定时任务启动");
|
||||
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
|
||||
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
|
||||
private static final String FINE_TUNING_LOG = "~~~ fine - tuning ~~~ : ";
|
||||
|
||||
if(Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus())
|
||||
|| Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())){
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
|
||||
if (fineTuningTaskDO.getJobId() == null){
|
||||
@Scheduled(cron = "0 */1 * * * ?")
|
||||
public void updateFineTuningTaskStatus () {
|
||||
// 方法入口:微调任务状态同步定时任务
|
||||
log.info("{} [定时任务] FineTuningTaskSync 启动 - 开始同步微调任务状态", FINE_TUNING_LOG);
|
||||
|
||||
// 1. 查询所有待处理的微调任务
|
||||
log.debug("正在查询所有微调任务...");
|
||||
List<FineTuningTaskDO> fineTuningTaskDOList = fineTuningTaskMapper.selectList();
|
||||
log.info("{} [任务查询] 共获取到{}个微调任务。", FINE_TUNING_LOG, fineTuningTaskDOList.size());
|
||||
|
||||
for (FineTuningTaskDO fineTuningTaskDO : fineTuningTaskDOList) {
|
||||
log.info("{} [任务处理] 开始处理任务DO:{}", FINE_TUNING_LOG, JSON.toJSONString(fineTuningTaskDO));
|
||||
|
||||
// 检查任务状态是否为训练中或等待中
|
||||
if (Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.TRAINING.getStatus()) || Objects.equals(fineTuningTaskDO.getStatus(), FinetuningTaskStatusEnum.WAITING.getStatus())) {
|
||||
log.info("{} [状态更新] 检测到需更新状态的任务ID:{},当前状态:{}",
|
||||
FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus());
|
||||
|
||||
// 3. 获取GPU服务器信息
|
||||
|
||||
String hostUrl = getHostUrl(fineTuningTaskDO);
|
||||
|
||||
// 4. 检查任务是否有 Job ID
|
||||
if (fineTuningTaskDO.getJobId() == null) {
|
||||
log.warn("{} 微调任务未关联 Job ID,任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId());
|
||||
continue;
|
||||
}
|
||||
String hostUrl = serverNameDO!=null ?serverNameDO.getHost():"";
|
||||
String queryJobs = "?filter={\"job_id\":\""+fineTuningTaskDO.getJobId()+"\"}";
|
||||
String respJobs = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl,"fine_tuning_train_job",queryJobs);
|
||||
|
||||
// 5. 构建API查询参数
|
||||
String queryJobs = "?filter={\"job_id\":\"" + fineTuningTaskDO.getJobId() + "\"}";
|
||||
log.info("{} [API请求] 发起任务查询 job_id={} 查询参数: {}", FINE_TUNING_LOG, fineTuningTaskDO.getJobId(), queryJobs);
|
||||
|
||||
// 6. 调用HTTP服务获取任务详情
|
||||
String respJobs = fineTuningTaskHttpService.modelTableQuery(
|
||||
new HashMap<>(), hostUrl, "fine_tuning_train_job", queryJobs);
|
||||
log.info("{} [API响应] 任务ID:{} 原始响应数据:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), respJobs);
|
||||
|
||||
// 7. 解析API响应数据
|
||||
AigcFineTuningDetailRespVO resp = new AigcFineTuningDetailRespVO();
|
||||
|
||||
try {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
List<AigcFineTuningDetailRespVO> aigcFineTuningDetailRespVO = mapper.readValue(respJobs,new TypeReference<List<AigcFineTuningDetailRespVO>>() {});
|
||||
if (aigcFineTuningDetailRespVO != null && aigcFineTuningDetailRespVO.size() > 0){
|
||||
if(aigcFineTuningDetailRespVO.get(0).getJob_id().equals(fineTuningTaskDO.getJobId())){
|
||||
resp = aigcFineTuningDetailRespVO.get(0);
|
||||
}
|
||||
List<AigcFineTuningDetailRespVO> respList = mapper.readValue(
|
||||
respJobs, new TypeReference<List<AigcFineTuningDetailRespVO>>() {
|
||||
});
|
||||
|
||||
if (respList != null && !respList.isEmpty()) {
|
||||
resp = respList.get(0);
|
||||
log.info("{} [状态更新] 任务ID:{},当前状态:{},API响应状态:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getStatus(), resp.getTrain_status());
|
||||
|
||||
}
|
||||
}catch (Exception e){
|
||||
log.error("获取微调任务状态失败{}",e.getMessage());
|
||||
} catch (Exception e) {
|
||||
log.error("{} 解析微调任务状态时发生异常。任务ID: {}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
if (resp == null){
|
||||
|
||||
if (resp == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
if (ObjectUtil.isAllEmpty(resp.getTrain_status())){
|
||||
if (ObjectUtil.isAllEmpty(resp.getTrain_status())) {
|
||||
continue;
|
||||
}
|
||||
Integer status = FineTuningTaskStatusConstants.getStatus(resp.getTrain_status());
|
||||
if(status != null){
|
||||
if (status != null) {
|
||||
updateObj.setId(fineTuningTaskDO.getId());
|
||||
}
|
||||
//如果微调任务由训练中变为已完成,则取模型列表中获取该任务生成的模型id
|
||||
if((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)){
|
||||
if ((fineTuningTaskDO.getStatus() == 1 && status == 2) || (fineTuningTaskDO.getStatus() == 3 && status == 2)) {
|
||||
String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
/* String resStr = trainHttpService.modelsList("");
|
||||
Log.info("获取aicg模型列表返回数据内容{}",resStr);
|
||||
@ -117,82 +146,121 @@ public class FineTuningTaskSyncService {
|
||||
}*/
|
||||
updateObj.setStatus(2);
|
||||
// 获取模型id
|
||||
String querModels = "?filter={\"model_name\":\""+resp.getFine_tuned_model()+"\"}";
|
||||
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(),hostUrl, "models",querModels);
|
||||
log.info("获取 aigc models 表数据 info {}",resModels);
|
||||
String querModels = "?filter={\"model_name\":\"" + resp.getFine_tuned_model() + "\"}";
|
||||
String resModels = fineTuningTaskHttpService.modelTableQuery(new HashMap<>(), hostUrl, "models", querModels);
|
||||
log.info("获取 aigc models 表数据 info {}", resModels);
|
||||
JSONArray jsonArrayModels = JSONArray.parseArray(resModels);
|
||||
|
||||
// 时长获取
|
||||
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(),LocalDateTime.now());
|
||||
Duration duration = Duration.between(fineTuningTaskDO.getCreateTime(), LocalDateTime.now());
|
||||
long minutes = duration.toMinutes();
|
||||
updateObj.setTrainDuration(String.valueOf(minutes));
|
||||
|
||||
if(jsonArrayModels.size() > 0){
|
||||
if (jsonArrayModels.size() > 0) {
|
||||
JSONObject jsonObjectModels = jsonArrayModels.getJSONObject(0);
|
||||
updateObj.setJobModelListId(jsonObjectModels.getLong("id"));
|
||||
}
|
||||
|
||||
}catch (Exception e){
|
||||
log.error(" error {}",e.getMessage());
|
||||
} catch (Exception e) {
|
||||
log.error(" error {}", e.getMessage());
|
||||
}
|
||||
|
||||
try {
|
||||
//获取检查点信息
|
||||
//todo 模型工厂的功能有问题,暂时写死
|
||||
// jobModelName = "Qwen2.5-0.5B-Instruct-147";
|
||||
String checkFileList = trainHttpService.getCheckFileList(hostUrl,jobModelName);
|
||||
List<String> checkpoints = new ArrayList<>();
|
||||
List<String> fileUrls = new ArrayList<>();
|
||||
List<String> fileList = JSONArray.parseArray(checkFileList,String.class);
|
||||
Map<String,JSONObject> map = new HashMap<>();
|
||||
for (String s : fileList) {
|
||||
if(s.contains("checkpoint")){
|
||||
checkpoints.add(s);
|
||||
}
|
||||
}
|
||||
if(checkpoints.size() > 0){
|
||||
for (String checkpoint : checkpoints) {
|
||||
String filePath = "/" + checkpoint + "/trainer_state.json";
|
||||
fileUrls.add(filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(hostUrl,jobModelName, filePath);
|
||||
try {
|
||||
URL url = new URL(fileUrl);
|
||||
URLConnection urlConnection = url.openConnection();
|
||||
InputStream inputStream = urlConnection.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
String line;
|
||||
StringBuilder content = new StringBuilder();
|
||||
while ((line = reader.readLine()) != null) {
|
||||
content.append(line + "\n");
|
||||
}
|
||||
reader.close();
|
||||
map.put(checkpoint,JSONObject.parseObject(content.toString()));
|
||||
} catch (MalformedURLException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
|
||||
updateObj.setCheckPointData(JSONObject.toJSONString(map));
|
||||
|
||||
}catch (Exception e){
|
||||
log.error(" error {}",e.getMessage());
|
||||
}
|
||||
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
|
||||
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
updateObj.setStatus(status);
|
||||
}
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
} else {
|
||||
// 处理已成功但是没有获取到检查点错误
|
||||
String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
String hostUrl = getHostUrl(fineTuningTaskDO);
|
||||
FineTuningTaskDO updateObj = new FineTuningTaskDO();
|
||||
updateObj.setId(fineTuningTaskDO.getId());
|
||||
getCheckPoint(fineTuningTaskDO, jobModelName, hostUrl, updateObj);
|
||||
fineTuningTaskMapper.updateById(updateObj);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private String getHostUrl (FineTuningTaskDO fineTuningTaskDO) {
|
||||
ServerNameDO serverNameDO = serverNameMapper.selectById(fineTuningTaskDO.getGpuType());
|
||||
String hostUrl = (serverNameDO != null) ? serverNameDO.getHost() : "";
|
||||
log.info("{} [主机信息] 任务ID:{} GPU 类型: {} 关联主机URL:{}", FINE_TUNING_LOG, fineTuningTaskDO.getId(), fineTuningTaskDO.getGpuType(), hostUrl);
|
||||
return hostUrl;
|
||||
}
|
||||
|
||||
private void getCheckPoint (FineTuningTaskDO fineTuningTaskDO, String jobModelName, String hostUrl, FineTuningTaskDO updateObj) {
|
||||
// 获取检查点信息
|
||||
try {
|
||||
// String jobModelName = fineTuningTaskDO.getJobModelName();
|
||||
log.info("正在获取检查点信息,模型名称: {}", jobModelName);
|
||||
String checkFileList = trainHttpService.getCheckFileList(hostUrl, jobModelName);
|
||||
List<String> checkpoints = new ArrayList<>();
|
||||
List<String> fileUrls = new ArrayList<>();
|
||||
Map<String, JSONObject> map = new HashMap<>();
|
||||
|
||||
// List<String> fileList = JSONArray.parseArray(checkFileList, String.class);
|
||||
// for (String s : fileList) {
|
||||
// if (s.contains("checkpoint")) {
|
||||
// checkpoints.add(s);
|
||||
// }
|
||||
// }
|
||||
// 判断是否是数组
|
||||
if (checkFileList.startsWith("[")) {
|
||||
List<String> fileList = JSONArray.parseArray(checkFileList, String.class);
|
||||
// 处理文件列表
|
||||
for (String s : fileList) {
|
||||
if (s.contains("checkpoint")) {
|
||||
checkpoints.add(s);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO 处理错误信息
|
||||
JSONObject errorObj = JSON.parseObject(checkFileList);
|
||||
String errorMsg = errorObj.getString("error");
|
||||
// 抛出异常或记录日志
|
||||
log.error("获取检查点文件列表时出错:{}", errorMsg);
|
||||
}
|
||||
|
||||
if (!checkpoints.isEmpty()) {
|
||||
log.info("找到 {} 个检查点文件。", checkpoints.size());
|
||||
for (String checkpoint : checkpoints) {
|
||||
String filePath = "/" + checkpoint + "/trainer_state.json";
|
||||
fileUrls.add(filePath);
|
||||
String fileUrl = trainHttpService.getCheckFile(hostUrl, jobModelName, filePath);
|
||||
|
||||
try {
|
||||
URL url = new URL(fileUrl);
|
||||
URLConnection urlConnection = url.openConnection();
|
||||
InputStream inputStream = urlConnection.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
StringBuilder content = new StringBuilder();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
content.append(line).append("\n");
|
||||
}
|
||||
reader.close();
|
||||
map.put(checkpoint, JSONObject.parseObject(content.toString()));
|
||||
log.info("检查点文件解析成功。检查点: {}", checkpoint);
|
||||
} catch (IOException e) {
|
||||
log.error("读取检查点文件时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.warn("未找到检查点文件,任务ID: {}", fineTuningTaskDO.getId());
|
||||
}
|
||||
|
||||
updateObj.setCheckPointFilePath(JSONObject.toJSONString(fileUrls));
|
||||
updateObj.setCheckPointData(JSONObject.toJSONString(map));
|
||||
log.info("检查点信息更新完成。");
|
||||
} catch (Exception e) {
|
||||
log.error("获取检查点信息时发生异常。任务ID: {}", fineTuningTaskDO.getId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -11,12 +11,18 @@ import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.client.HttpClient;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.StringEntity;
|
||||
import org.apache.http.impl.client.HttpClients;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
@ -131,14 +137,136 @@ public class ModelService {
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 GPU 类型获取对应的主机地址
|
||||
* <p>ModelService GpuType 是 ServerName 表ID
|
||||
* 模型补全流式处理方法
|
||||
*
|
||||
* @param gpuType gpuType
|
||||
* @return host
|
||||
* @param url 模型服务的 URL
|
||||
* @param req 模型补全请求对象
|
||||
*/
|
||||
public String getHostByType (Long gpuType) {
|
||||
return serverNameService.getServerName(gpuType).getHost();
|
||||
public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, HttpServletResponse response) {
|
||||
req.setStream(true);
|
||||
log.info("开始处理模型补全请求,参数: {}", req);
|
||||
try {
|
||||
log.info("开始处理模型补全请求,参数: {}", JSON.toJSONString(req));
|
||||
sendPostRequest(url, JSON.toJSONString(req), emitter);
|
||||
} catch (Exception e) {
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送 POST 请求并处理响应
|
||||
*
|
||||
* @param apiUrl 目标 API 的 URL
|
||||
* @param requestBody 请求体内容
|
||||
* @throws IOException 发送请求或处理响应时可能抛出的 IO 异常
|
||||
*/
|
||||
private void sendPostRequest (String apiUrl, String requestBody, SseEmitter emitter) throws IOException {
|
||||
// 创建 HttpClient 实例
|
||||
HttpClient httpClient = HttpClients.createDefault();
|
||||
// 创建 HttpPost 请求对象
|
||||
HttpPost httpPost = new HttpPost(apiUrl);
|
||||
|
||||
// 设置请求体和请求头
|
||||
setupRequest(httpPost, requestBody);
|
||||
|
||||
// 执行 POST 请求并获取响应
|
||||
HttpResponse response = httpClient.execute(httpPost);
|
||||
|
||||
// 处理响应实体
|
||||
handleResponseEntity(response, emitter);
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求体和请求头
|
||||
*
|
||||
* @param httpPost HttpPost 请求对象
|
||||
* @param requestBody 请求体内容
|
||||
* @throws IOException 创建 StringEntity 时可能抛出的 IO 异常
|
||||
*/
|
||||
private void setupRequest (HttpPost httpPost, String requestBody) throws IOException {
|
||||
// 创建 StringEntity 对象,用于封装请求体
|
||||
StringEntity entity = new StringEntity(requestBody);
|
||||
// 设置请求体
|
||||
httpPost.setEntity(entity);
|
||||
// 设置请求头,指定请求体的内容类型为 JSON
|
||||
httpPost.setHeader("Content-Type", "application/json");
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理响应实体
|
||||
*
|
||||
* @param response HttpResponse 对象
|
||||
* @throws IOException 读取响应实体输入流时可能抛出的 IO 异常
|
||||
*/
|
||||
private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException {
|
||||
// 获取响应实体
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
if (responseEntity != null) {
|
||||
// 使用 try-with-resources 语句自动关闭输入流和 BufferedReader
|
||||
try (InputStream inputStream = responseEntity.getContent();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
|
||||
String line;
|
||||
// 逐行读取响应内容并处理
|
||||
while ((line = reader.readLine()) != null) {
|
||||
System.out.println("接收到的响应行数据: " + line);
|
||||
String content = parseStreamLine(line);
|
||||
if (content != null) {
|
||||
emitter.send(SseEmitter.event().data(content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析流式返回的单行数据,提取有效内容并清理特定标记
|
||||
*
|
||||
* @param line 流式响应中的单行JSON数据
|
||||
* @return 处理后的文本内容(若无有效内容返回null)
|
||||
*/
|
||||
private String parseStreamLine (String line) {
|
||||
if (StringUtils.isNotBlank(line)) {
|
||||
if (line.startsWith("data: ")) {
|
||||
String dataString = extractJsonFromDataString(line);
|
||||
if (!dataString.contains("[DONE]")) {
|
||||
JSONObject jsonObject = JSON.parseObject(dataString);
|
||||
// 获取 choices 数组
|
||||
JSONArray choicesArray = jsonObject.getJSONArray("choices");
|
||||
if (choicesArray != null && !choicesArray.isEmpty()) {
|
||||
// 获取第一个 choice 对象
|
||||
JSONObject firstChoice = choicesArray.getJSONObject(0);
|
||||
// 获取 delta 对象
|
||||
JSONObject delta = firstChoice.getJSONObject("delta");
|
||||
if (delta != null) {
|
||||
// 获取 content 的值
|
||||
String content = delta.getString("content");
|
||||
return "{\"content\":\"" + content + "\",\"finish_reason\":\"" + false + "\"}";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "{\"content\":\"" + "\",\"finish_reason\":\"" + true + "\"}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从包含 "data: " 前缀的字符串中提取 JSON 字符串
|
||||
*
|
||||
* @param input 包含 "data: " 前缀的原始字符串
|
||||
* @return 提取出的 JSON 字符串,如果未找到 "data: " 则返回 null
|
||||
*/
|
||||
public static String extractJsonFromDataString (String input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
int index = input.indexOf("data: ");
|
||||
if (index != -1) {
|
||||
return input.substring(index + "data: ".length()).trim();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -254,4 +382,5 @@ public class ModelService {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -415,24 +415,24 @@ public class RagHttpService {
|
||||
Path tempFilePath = downloadFileToTemp(fileUrl, fileName);
|
||||
log.info("文件已下载到临时目录: {}", tempFilePath);
|
||||
|
||||
String fileSuffix = getFileSuffix(fileName);
|
||||
if ("doc".equals(fileSuffix)) {
|
||||
log.info("正在处理 doc 文件");
|
||||
try {
|
||||
tempFilePath= converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
// String fileSuffix = getFileSuffix(fileName);
|
||||
// if ("doc".equals(fileSuffix)) {
|
||||
// log.info("正在处理 doc 文件");
|
||||
// try {
|
||||
// tempFilePath = converterDocToDocx(tempFilePath.toString(), tempFilePath.toString().replace(".doc", ".docx"));
|
||||
// } catch (Exception e) {
|
||||
// throw new RuntimeException(e);
|
||||
// }
|
||||
// }
|
||||
|
||||
if ("md".equals(fileSuffix)) {
|
||||
log.info("正在处理 md 文件");
|
||||
try {
|
||||
tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".docx"));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
// if ("md".equals(fileSuffix)) {
|
||||
// log.info("正在处理 md 文件");
|
||||
// try {
|
||||
// tempFilePath= converterMdToTxt(tempFilePath.toString(), tempFilePath.toString().replace(".md", ".txt"));
|
||||
// } catch (Exception e) {
|
||||
// throw new RuntimeException(e);
|
||||
// }
|
||||
// }
|
||||
|
||||
// 创建 OkHttpClient 实例
|
||||
log.info("创建 OkHttpClient 实例,设置超时时间为 3 分钟");
|
||||
@ -626,7 +626,7 @@ public class RagHttpService {
|
||||
return path;
|
||||
}
|
||||
|
||||
public static Path converterDocToDocx(String inputPath, String outputPath) throws Exception {
|
||||
public static Path converterDocToDocx (String inputPath, String outputPath) throws Exception {
|
||||
// 读取DOC文档
|
||||
try (HWPFDocument doc = new HWPFDocument(Files.newInputStream(Paths.get(inputPath)))) {
|
||||
XWPFDocument docx = new XWPFDocument();
|
||||
@ -650,6 +650,7 @@ public class RagHttpService {
|
||||
return Paths.get(outputPath);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理响应结果
|
||||
*/
|
||||
|
@ -17,6 +17,9 @@ public class ModelCompletionsReqVO {
|
||||
private List<ModelCompletionsMessage> messages;
|
||||
private Integer max_tokens = 4000;
|
||||
private Double temperature = 0.7;
|
||||
// private Integer max_tokens;
|
||||
// private Double temperature;
|
||||
private Boolean stream;
|
||||
// private Integer max_length = 120000;
|
||||
|
||||
@Data
|
||||
|
Loading…
x
Reference in New Issue
Block a user