refactor(llm): 优化代码结构和流式接口设计

- 重构了 ConversationController 和 ConversationServiceImpl 中的方法
- 优化了代码格式和命名规范
-调整了部分方法的参数和返回类型
- 重构了 ModelService 中的流式处理逻辑- 增加了心跳检测机制,保证长连接稳定
This commit is contained in:
Liuyang 2025-03-02 10:54:10 +08:00
parent 6370cb223e
commit e95e42da1d
3 changed files with 110 additions and 104 deletions

View File

@ -1,45 +1,34 @@
package cn.iocoder.yudao.module.llm.controller.admin.conversation;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.iocoder.yudao.framework.common.exception.ServiceException;
import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import org.springframework.validation.annotation.Validated;
import org.springframework.security.access.prepost.PreAuthorize;
import io.swagger.v3.oas.annotations.tags.Tag;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import javax.validation.constraints.*;
import javax.validation.*;
import javax.servlet.http.*;
import java.time.LocalDateTime;
import java.util.*;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog;
import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*;
import cn.iocoder.yudao.module.llm.controller.admin.conversation.vo.*;
import cn.iocoder.yudao.module.llm.dal.dataobject.conversation.ConversationDO;
import cn.iocoder.yudao.module.llm.service.conversation.ConversationService;
import cn.iocoder.yudao.module.llm.service.http.vo.TextToImageReqVo;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import javax.validation.Valid;
import java.io.IOException;
import java.util.List;
import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
@Tag(name = "管理后台 - 大模型对话记录")
@RestController
@RequestMapping("/llm/conversation")
@ -53,14 +42,14 @@ public class ConversationController {
@PostMapping("/create")
@Operation(summary = "创建大模型对话记录")
@PreAuthorize("@ss.hasPermission('llm:conversation:create')")
public CommonResult<Integer> createConversation(@Valid @RequestBody ConversationSaveReqVO createReqVO) {
public CommonResult<Integer> createConversation (@Valid @RequestBody ConversationSaveReqVO createReqVO) {
return success(conversationService.createConversation(createReqVO));
}
@PutMapping("/update")
@Operation(summary = "更新大模型对话记录")
@PreAuthorize("@ss.hasPermission('llm:conversation:update')")
public CommonResult<Boolean> updateConversation(@Valid @RequestBody ConversationSaveReqVO updateReqVO) {
public CommonResult<Boolean> updateConversation (@Valid @RequestBody ConversationSaveReqVO updateReqVO) {
conversationService.updateConversation(updateReqVO);
return success(true);
}
@ -69,7 +58,7 @@ public class ConversationController {
@Operation(summary = "删除大模型对话记录")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('llm:conversation:delete')")
public CommonResult<Boolean> deleteConversation(@RequestParam("id") Integer id) {
public CommonResult<Boolean> deleteConversation (@RequestParam("id") Integer id) {
conversationService.deleteConversation(id);
return success(true);
}
@ -78,7 +67,7 @@ public class ConversationController {
@Operation(summary = "获得大模型对话记录")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('llm:conversation:query')")
public CommonResult<ConversationRespVO> getConversation(@RequestParam("id") Integer id) {
public CommonResult<ConversationRespVO> getConversation (@RequestParam("id") Integer id) {
ConversationDO conversation = conversationService.getConversation(id);
return success(BeanUtils.toBean(conversation, ConversationRespVO.class));
}
@ -86,7 +75,7 @@ public class ConversationController {
@GetMapping("/page")
@Operation(summary = "获得大模型对话记录分页")
@PreAuthorize("@ss.hasPermission('llm:conversation:query')")
public CommonResult<PageResult<ConversationRespVO>> getConversationPage(@Valid ConversationPageReqVO pageReqVO) {
public CommonResult<PageResult<ConversationRespVO>> getConversationPage (@Valid ConversationPageReqVO pageReqVO) {
PageResult<ConversationDO> pageResult = conversationService.getConversationPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, ConversationRespVO.class));
}
@ -95,32 +84,33 @@ public class ConversationController {
@Operation(summary = "导出大模型对话记录 Excel")
@PreAuthorize("@ss.hasPermission('llm:conversation:export')")
@ApiAccessLog(operateType = EXPORT)
public void exportConversationExcel(@Valid ConversationPageReqVO pageReqVO,
HttpServletResponse response) throws IOException {
public void exportConversationExcel (@Valid ConversationPageReqVO pageReqVO,
HttpServletResponse response) throws IOException {
pageReqVO.setPageSize(PageParam.PAGE_SIZE_NONE);
List<ConversationDO> list = conversationService.getConversationPage(pageReqVO).getList();
// 导出 Excel
ExcelUtils.write(response, "大模型对话记录.xls", "数据", ConversationRespVO.class,
BeanUtils.toBean(list, ConversationRespVO.class));
BeanUtils.toBean(list, ConversationRespVO.class));
}
@PostMapping("/chat")
@Operation(summary = "对话推理接口")
public CommonResult<ChatRespVO> chat(@Valid @RequestBody ChatReqVO chatReqVO) {
public CommonResult<ChatRespVO> chat (@Valid @RequestBody ChatReqVO chatReqVO) {
return success(conversationService.chat(chatReqVO));
}
/**
* 对话推理接口使用 SSE 进行流式响应
*
* @param chatReqVO 对话请求对象
* @return SseEmitter 对象用于流式发送响应
*/
@PostMapping("/stream-chat")
public SseEmitter streamChat(@Valid @RequestBody ChatReqVO chatReqVO,HttpServletResponse response) {
public SseEmitter streamChat (@Valid @RequestBody ChatReqVO chatReqVO, HttpServletResponse response) {
log.info("收到对话推理请求,请求参数: {}", chatReqVO);
SseEmitter emitter = new SseEmitter();
SseEmitter emitter = new SseEmitter(60_000L);
try {
conversationService.chatStream(chatReqVO, emitter,response);
conversationService.chatStream(chatReqVO, emitter, response);
} catch (Exception e) {
log.error("处理对话推理请求时发生异常", e);
try {
@ -136,7 +126,7 @@ public class ConversationController {
@PostMapping("/text-to-image")
@Operation(summary = "文字转图片接口")
public CommonResult<JSONArray> textToImage(@Valid @RequestBody TextToImageReqVo req) {
public CommonResult<JSONArray> textToImage (@Valid @RequestBody TextToImageReqVo req) {
return success(conversationService.textToImage(req));
}
}

View File

@ -28,7 +28,6 @@ import cn.iocoder.yudao.module.llm.service.datarefluxdata.DataRefluxDataService;
import cn.iocoder.yudao.module.llm.service.http.ModelService;
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import cn.iocoder.yudao.module.llm.service.prompttemplates.PromptTemplatesService;
import cn.iocoder.yudao.module.llm.service.servername.ServerNameService;
import com.alibaba.excel.util.StringUtils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
@ -36,15 +35,12 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -87,7 +83,7 @@ public class ConversationServiceImpl implements ConversationService {
private final static Long CHAT_HISTORY_REDIS_EXPIRE_SECONDS = 60 * 60 * 24L;
@Override
public Integer createConversation(ConversationSaveReqVO createReqVO) {
public Integer createConversation (ConversationSaveReqVO createReqVO) {
// 插入
ConversationDO conversation = BeanUtils.toBean(createReqVO, ConversationDO.class);
conversationMapper.insert(conversation);
@ -96,7 +92,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public void updateConversation(ConversationSaveReqVO updateReqVO) {
public void updateConversation (ConversationSaveReqVO updateReqVO) {
// 校验存在
validateConversationExists(updateReqVO.getId());
// 更新
@ -105,36 +101,36 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public void deleteConversation(Integer id) {
public void deleteConversation (Integer id) {
// 校验存在
validateConversationExists(id);
// 删除
conversationMapper.deleteById(id);
}
private void validateConversationExists(Integer id) {
private void validateConversationExists (Integer id) {
if (conversationMapper.selectById(id) == null) {
throw exception(CONVERSATION_NOT_EXISTS);
}
}
@Override
public ConversationDO getConversation(Integer id) {
public ConversationDO getConversation (Integer id) {
return conversationMapper.selectById(id);
}
@Override
public PageResult<ConversationDO> getConversationPage(ConversationPageReqVO pageReqVO) {
public PageResult<ConversationDO> getConversationPage (ConversationPageReqVO pageReqVO) {
return conversationMapper.selectPage(pageReqVO);
}
@Override
public ChatRespVO chat(ChatReqVO chatReqVO) {
if(chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")){
if(chatReqVO.getApplicationId() != null){
public ChatRespVO chat (ChatReqVO chatReqVO) {
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
if (chatReqVO.getApplicationId() != null) {
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
if (CollectionUtils.isEmpty(messageHistoryList)){
if (CollectionUtils.isEmpty(messageHistoryList)) {
application.setChatCount(application.getChatCount() + 1);
ApplicationSaveReqVO applicationSaveReqVO = BeanUtils.toBean(application, ApplicationSaveReqVO.class);
applicationService.updateApplication(applicationSaveReqVO);
@ -148,7 +144,7 @@ public class ConversationServiceImpl implements ConversationService {
}
@Override
public JSONArray textToImage(TextToImageReqVo req) {
public JSONArray textToImage (TextToImageReqVo req) {
TextToImageRespVo textToImageRespVo = modelService.textToImage(req);
return textToImageRespVo.getData();
}
@ -156,10 +152,11 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 公共模型聊天
*
* @param chatReqVO
* @return
*/
public ChatRespVO publicModelChat(ChatReqVO chatReqVO) {
public ChatRespVO publicModelChat (ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());
@ -174,7 +171,7 @@ public class ConversationServiceImpl implements ConversationService {
}
selfModelUrl = baseModelDO.getChatUrl();
model = baseModelDO.getAigcModelName();
}else if (Objects.equals(0, chatReqVO.getModelType())) {
} else if (Objects.equals(0, chatReqVO.getModelType())) {
// 自定义模型
ModelServiceDO modelServiceDO = modelServiceMapper.selectById(chatReqVO.getModelId());
if (modelServiceDO == null) {
@ -182,17 +179,17 @@ public class ConversationServiceImpl implements ConversationService {
}
model = modelServiceDO.getBaseModelName();
selfModelUrl = modelServiceDO.getModelUrl();
}else {
} else {
throw exception(BASE_MODEL_NOT_EXISTS);
}
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
//如果知识库id不为null先去调用知识库
StringBuilder knowledgeBase = new StringBuilder();
if(chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0){
StringBuilder knowledgeBase = new StringBuilder();
if (chatReqVO.getKnowledge() != null && chatReqVO.getKnowledge() != 0) {
LambdaQueryWrapper<KnowledgeDocumentsDO> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId,chatReqVO.getKnowledge());
queryWrapper.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, chatReqVO.getKnowledge());
List<KnowledgeDocumentsDO> fileList = knowledgeDocumentsMapper.selectList(queryWrapper);
for (KnowledgeDocumentsDO knowledgeDocumentsDO : fileList) {
Long id = knowledgeDocumentsDO.getId();
@ -201,26 +198,26 @@ public class ConversationServiceImpl implements ConversationService {
knowledgeRagEmbedQueryVO.setQuery(chatReqVO.getPrompt());
String result = HttpUtils.post(llmBackendProperties.getEmbedQuery(), null, JSON.toJSONString(knowledgeRagEmbedQueryVO));
com.alibaba.fastjson.JSONArray jsonArray = JSON.parseArray(result);
if (jsonArray != null && !jsonArray.isEmpty()){
if (jsonArray != null && !jsonArray.isEmpty()) {
com.alibaba.fastjson.JSONArray jsonArray1 = (com.alibaba.fastjson.JSONArray) jsonArray.get(0);
JSONObject jsonObject = (JSONObject) jsonArray1.get(0);
knowledgeBase.append(jsonObject.get("page_content"));
}
}
}
String mess = chatReqVO.getSystemPrompt()+"<content>"+knowledgeBase.toString()+"</content>";
String mess = chatReqVO.getSystemPrompt() + "<content>" + knowledgeBase.toString() + "</content>";
// 查询历史记录消息 将查询出来的知识信息放入到 role = system 里面
List<String> messageHistoryList = stringRedisTemplate.opsForList().range(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, -1);
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
for (String messageHistory : messageHistoryList) {
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
if (modelCompletionsMessage.getRole().equals("system")){
if (modelCompletionsMessage.getRole().equals("system")) {
modelCompletionsMessage.setContent(mess);
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY+ ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
}
messages.add(modelCompletionsMessage);
}
}else {
} else {
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(mess);
@ -244,8 +241,6 @@ public class ConversationServiceImpl implements ConversationService {
}*/
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
message.setRole("user");
message.setContent(chatReqVO.getPrompt());
@ -256,7 +251,7 @@ public class ConversationServiceImpl implements ConversationService {
modelCompletionsReqVO.setMessages(messages);
// baseModel aigcModelName 为aigc中的模型名称
modelCompletionsReqVO.setModel(model);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletions(selfModelUrl,modelCompletionsReqVO);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletions(selfModelUrl, modelCompletionsReqVO);
if (modelCompletionsRespVO == null) {
throw exception(MODEL_COMPLETIONS_ERROR);
}
@ -282,13 +277,15 @@ public class ConversationServiceImpl implements ConversationService {
/**
* 处理对话请求以流式方式返回结果
*
* @param chatReqVO 对话请求对象
* @param emitter SseEmitter 对象用于流式发送响应
* @param emitter SseEmitter 对象用于流式发送响应
*/
public void chatStream(ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
@Override
public void chatStream (ChatReqVO chatReqVO, SseEmitter emitter, HttpServletResponse response) {
log.info("开始处理对话请求,请求参数: {}", chatReqVO);
// 检查系统提示信息如果为空则尝试从应用中获取
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().equals("")) {
if (chatReqVO.getSystemPrompt() == null || chatReqVO.getSystemPrompt().isEmpty()) {
if (chatReqVO.getApplicationId() != null) {
log.info("系统提示信息为空,尝试从应用中获取,应用 ID: {}", chatReqVO.getApplicationId());
ApplicationRespVO application = applicationService.getApplication(chatReqVO.getApplicationId());
@ -304,15 +301,16 @@ public class ConversationServiceImpl implements ConversationService {
}
}
// 调用公共模型聊天流式处理方法
publicModelChatStream(chatReqVO, emitter, response);
publicModelChatStream(chatReqVO, emitter);
}
/**
* 公共模型聊天流式处理方法
*
* @param chatReqVO 对话请求对象
* @param emitter SseEmitter 对象用于流式发送响应
* @param emitter SseEmitter 对象用于流式发送响应
*/
public void publicModelChatStream(ChatReqVO chatReqVO, SseEmitter emitter,HttpServletResponse response) {
public void publicModelChatStream (ChatReqVO chatReqVO, SseEmitter emitter) {
log.info("开始公共模型聊天流式处理,请求参数: {}", chatReqVO);
// 检查 UUID 是否为空若为空则生成一个
if (StrUtil.isBlank(chatReqVO.getUuid())) {
@ -393,7 +391,7 @@ public class ConversationServiceImpl implements ConversationService {
log.info("存在聊天历史记录,处理历史记录消息");
for (String messageHistory : messageHistoryList) {
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
if (modelCompletionsMessage.getRole().equals("system")) {
if ("system".equals(modelCompletionsMessage.getRole())) {
modelCompletionsMessage.setContent(mess);
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
}
@ -421,7 +419,7 @@ public class ConversationServiceImpl implements ConversationService {
log.info("构建模型补全请求对象,请求参数: {}", modelCompletionsReqVO);
// 调用模型服务进行流式处理
modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter, response);
modelService.modelCompletionsStream(selfModelUrl, modelCompletionsReqVO, emitter);
// 将用户消息存入缓存
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(message));
@ -438,13 +436,14 @@ public class ConversationServiceImpl implements ConversationService {
log.info("数据回流信息保存完成");
}
/**
* 私有模型聊天
*
* @param chatReqVO
* @return
*/
private ChatRespVO privateModelChat(ChatReqVO chatReqVO) {
private ChatRespVO privateModelChat (ChatReqVO chatReqVO) {
if (StringUtils.isBlank(chatReqVO.getUuid())) {
// 如果没有uuid就生成一个
chatReqVO.setUuid(UUID.randomUUID().toString());
@ -467,16 +466,16 @@ public class ConversationServiceImpl implements ConversationService {
if (messageHistoryList != null && !messageHistoryList.isEmpty()) {
for (String messageHistory : messageHistoryList) {
ModelCompletionsReqVO.ModelCompletionsMessage modelCompletionsMessage = JsonUtils.parseObject(messageHistory, ModelCompletionsReqVO.ModelCompletionsMessage.class);
if (modelCompletionsMessage.getRole().equals("system")){
modelCompletionsMessage.setContent(StringUtils.isNotBlank(chatReqVO.getSystemPrompt())? chatReqVO.getSystemPrompt():"");
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY+ ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
if (modelCompletionsMessage.getRole().equals("system")) {
modelCompletionsMessage.setContent(StringUtils.isNotBlank(chatReqVO.getSystemPrompt()) ? chatReqVO.getSystemPrompt() : "");
stringRedisTemplate.opsForList().set(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), 0, JsonUtils.toJsonString(modelCompletionsMessage));
}
messages.add(modelCompletionsMessage);
}
}else {
} else {
ModelCompletionsReqVO.ModelCompletionsMessage systemMessage = new ModelCompletionsReqVO.ModelCompletionsMessage();
systemMessage.setRole("system");
systemMessage.setContent(StringUtils.isNotBlank(chatReqVO.getSystemPrompt())? chatReqVO.getSystemPrompt():"");
systemMessage.setContent(StringUtils.isNotBlank(chatReqVO.getSystemPrompt()) ? chatReqVO.getSystemPrompt() : "");
stringRedisTemplate.opsForList().rightPush(CHAT_HIStORY_REDIS_KEY + ":" + chatReqVO.getUuid(), JsonUtils.toJsonString(systemMessage));
messages.add(systemMessage);
}
@ -485,7 +484,7 @@ public class ConversationServiceImpl implements ConversationService {
modelCompletionsReqVO.setMessages(messages);
// TODO 先传固定的内容 后期和后端调通直接修改成 model 1.23 已修改
modelCompletionsReqVO.setModel(model);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelPrivateCompletions(new HashMap<>(),modelCompletionsReqVO);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelPrivateCompletions(new HashMap<>(), modelCompletionsReqVO);
if (modelCompletionsRespVO == null) {
throw exception(MODEL_COMPLETIONS_ERROR);
}

View File

@ -17,12 +17,15 @@ import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClients;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@ -142,7 +145,7 @@ public class ModelService {
* @param url 模型服务的 URL
* @param req 模型补全请求对象
*/
public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter, HttpServletResponse response) {
public void modelCompletionsStream (String url, ModelCompletionsReqVO req, SseEmitter emitter) {
req.setStream(true);
log.info("开始处理模型补全请求,参数: {}", req);
try {
@ -190,6 +193,9 @@ public class ModelService {
httpPost.setEntity(entity);
// 设置请求头指定请求体的内容类型为 JSON
httpPost.setHeader("Content-Type", "application/json");
httpPost.setHeader("Accept-Encoding", "gzip, deflate, br");
httpPost.setHeader("Accept", "*/*");
httpPost.setHeader("Connection", "keep-alive");
}
/**
@ -201,20 +207,31 @@ public class ModelService {
private void handleResponseEntity (HttpResponse response, SseEmitter emitter) throws IOException {
// 获取响应实体
HttpEntity responseEntity = response.getEntity();
if (responseEntity != null) {
// 使用 try-with-resources 语句自动关闭输入流和 BufferedReader
try (InputStream inputStream = responseEntity.getContent();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
String line;
// 逐行读取响应内容并处理
while ((line = reader.readLine()) != null) {
System.out.println("接收到的响应行数据: " + line);
String content = parseStreamLine(line);
if (content != null) {
emitter.send(SseEmitter.event().data(content));
}
long lastSendTime = System.currentTimeMillis();
try (InputStream inputStream = responseEntity.getContent();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
String line;
while ((line = reader.readLine()) != null) {
if (StringUtils.isBlank(line)) {
continue;
}
log.info("接收到的响应行数据: {}", line);
String content = parseStreamLine(line);
if (content != null) {
emitter.send(SseEmitter.event()
.data(content, MediaType.TEXT_EVENT_STREAM)
);
}
// 心跳检测
if (System.currentTimeMillis() - lastSendTime > 15_000) {
emitter.send(SseEmitter.event().comment("heartbeat"));
lastSendTime = System.currentTimeMillis();
}
}
emitter.complete();
} catch (IOException e) {
emitter.completeWithError(e);
}
}