diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java index 881681371..e8c07f524 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/http/RagHttpService.java @@ -33,9 +33,11 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.*; +import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import java.util.Objects; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -316,8 +318,9 @@ public class RagHttpService { String fileId = reqVO.getFileId(); String fileName = reqVO.getFileName(); String fileUrl = reqVO.getFileUrl(); + String mediaType = getMediaType(fileName); - log.info("URL: {}, fileId: {} ,fileNam: {}, fileUrl: {}, ", ragEmbed, fileId, fileName, fileUrl); + log.info("URL: {}, fileId: {} ,fileName: {}, fileUrl: {}, mediaType: {} ", ragEmbed, fileId, fileName, fileUrl, mediaType); // 获取知识库文档 KnowledgeDocumentsDO documents = getKnowledgeDocuments(id, fileId); @@ -328,83 +331,109 @@ public class RagHttpService { // 更新文件状态为上传中 updateFileState(documents, KnowledgeStatusEnum.UPLOADING); - // 创建HTTP客户端 - CloseableHttpClient httpClient = HttpClients.createDefault(); - // 创建HTTP GET请求,获取文件内容 - HttpGet request = new HttpGet(fileUrl); - try (CloseableHttpResponse response = httpClient.execute(request)) { - HttpEntity entity = response.getEntity(); - if (entity != null) { - try (InputStream inputStream = entity.getContent(); - BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream)) { + byte[] fileBytes = Objects.requireNonNull(getFileByte(fileUrl)); - // 标记流以便后续重置 - bufferedInputStream.mark(Integer.MAX_VALUE); - // 检测文件编码 - String encoding = detectCharset(bufferedInputStream); + // 创建 OkHttpClient 实例 + OkHttpClient client = new OkHttpClient(); - // 重置流以便重新读取 - bufferedInputStream.reset(); + // 创建 MultipartBody + RequestBody requestBody = new MultipartBody.Builder() + .setType(MultipartBody.FORM) + .addFormDataPart("file_id", fileId) + .addFormDataPart("file", fileName, + RequestBody.create(fileBytes, MediaType.parse(mediaType)) + ) + .build(); - // 使用检测到的编码读取文件内容 - try (InputStreamReader reader = new InputStreamReader(bufferedInputStream, encoding); - BufferedReader bufferedReader = new BufferedReader(reader)) { - StringBuilder fileContentBuilder = new StringBuilder(); - String line; - while ((line = bufferedReader.readLine()) != null) { - fileContentBuilder.append(line).append(System.lineSeparator()); - } - String fileContent = fileContentBuilder.toString(); + // 创建请求 + Request sendRequest = new Request.Builder() + .url(ragEmbed) + .post(requestBody) + .addHeader("accept", "application/json") + .build(); - // 将文件内容转换为UTF-8编码的字节数组 - byte[] utf8Bytes = fileContent.getBytes(StandardCharsets.UTF_8); + // 发送请求 + try (Response sendResponse = client.newCall(sendRequest).execute()) { + if (sendResponse.body() != null) { + String body = sendResponse.body().string(); + log.info("!!!!!!!!!! 响应原始内容 Response: {}", body); - try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(utf8Bytes); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { + JSONObject resJson = JSONObject.parseObject(body); - int bufferSize = 1024; - byte[] byteArray = new byte[bufferSize]; - int bytesRead; - - // 读取字节数组并写入输出流 - while ((bytesRead = byteArrayInputStream.read(byteArray)) != -1) { - outputStream.write(byteArray, 0, bytesRead); - } - - // 将输出流转换为字节数组 - byte[] result = outputStream.toByteArray(); - - // 发送HTTP POST请求,上传文件内容 - String body = HttpRequest.post(ragEmbed) - .form("file", result, fileName) - .form("file_id", fileId) - .execute() - .body(); - - // 打印响应内容 - log.info("响应原始内容 String: {}", body); - - // 解析响应内容 - RagEmbedRespVO ragEmbedRespVO = JSON.parseObject(body, RagEmbedRespVO.class); - log.info("解析响应原始内容 ragEmbedRespVO:{}", ragEmbedRespVO); - - // 根据响应状态更新文件状态 - if (ragEmbedRespVO.isStatus()) { - updateFileState(documents, KnowledgeStatusEnum.UPLOAD_SUCCESS); - } else { - updateFileState(documents, KnowledgeStatusEnum.UPLOAD_FAILED); - throw new RuntimeException("文件上传失败:" + ragEmbedRespVO.getMessage()); - } - } catch (UnirestException e) { - throw new RuntimeException("文件上传失败: " + e.getMessage()); - } + // 1: 先判断是否存在 detail, + String detail = resJson.getString("detail"); + if (StringUtils.isNotEmpty(detail)) { + handleFailure(documents, detail); + } else { + Boolean status = resJson.getBoolean("status"); + if (!status) { + handleFailure(documents, resJson.getString("message")); + } else { + processResponse(body, documents); } } + } else { + handleFailure(documents, FILE_UPLOAD_FAILED_MSG); } + } catch (IOException e) { + handleFailure(documents, FILE_UPLOAD_FAILED_MSG, e); + } + + + } + + /** + * 获取文件字节数组 + * + * @param fileUrl 文件地址 + * @return 文件字节数组 + */ + public static byte[] getFileByte (String fileUrl) { + try (InputStream inputStream = new URL(fileUrl).openStream(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { + + // 缓冲区大小 + byte[] buffer = new byte[1024]; + int bytesRead; + + // 读取文件内容并写入 ByteArrayOutputStream + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } + + // 返回字节数组 + return outputStream.toByteArray(); + + } catch (IOException e) { + log.error("Failed to read remote file: {}", e.getMessage()); + + throw exception(new ErrorCode(10001_001, "文件读取错误")); } } + /** + * 获取文件类型 + * + * @param fileName 文件名 + * @return 文件类型 + */ + private static String getMediaType(String fileName) { + String fileSuffix = fileName.substring(fileName.lastIndexOf(".") + 1); + switch (fileSuffix) { + case "pdf": + return "application/pdf"; + case "md": + return "text/x-markdown"; + case "docx": + return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"; + case "txt": + return "text/plain"; + } + return "application/octet-stream"; + } + + /** * 处理响应结果 */