refactor(yudao-module-llm): 重构文件上传功能

- 优化了文件上传的实现方式,使用 OkHttpClient 替代原有的 CloseableHttpClient
- 新增 getFileByte 方法获取文件字节数组
- 增加了对文件类型的处理,支持 PDF
This commit is contained in:
Liuyang 2025-02-21 10:03:14 +08:00
parent 53305996f6
commit 3a41fb9100

View File

@ -33,9 +33,11 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.*;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -316,8 +318,9 @@ public class RagHttpService {
String fileId = reqVO.getFileId();
String fileName = reqVO.getFileName();
String fileUrl = reqVO.getFileUrl();
String mediaType = getMediaType(fileName);
log.info("URL: {}, fileId: {} ,fileNam: {}, fileUrl: {}, ", ragEmbed, fileId, fileName, fileUrl);
log.info("URL: {}, fileId: {} ,fileName: {}, fileUrl: {}, mediaType: {} ", ragEmbed, fileId, fileName, fileUrl, mediaType);
// 获取知识库文档
KnowledgeDocumentsDO documents = getKnowledgeDocuments(id, fileId);
@ -328,83 +331,109 @@ public class RagHttpService {
// 更新文件状态为上传中
updateFileState(documents, KnowledgeStatusEnum.UPLOADING);
// 创建HTTP客户端
CloseableHttpClient httpClient = HttpClients.createDefault();
// 创建HTTP GET请求获取文件内容
HttpGet request = new HttpGet(fileUrl);
try (CloseableHttpResponse response = httpClient.execute(request)) {
HttpEntity entity = response.getEntity();
if (entity != null) {
try (InputStream inputStream = entity.getContent();
BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream)) {
byte[] fileBytes = Objects.requireNonNull(getFileByte(fileUrl));
// 标记流以便后续重置
bufferedInputStream.mark(Integer.MAX_VALUE);
// 检测文件编码
String encoding = detectCharset(bufferedInputStream);
// 创建 OkHttpClient 实例
OkHttpClient client = new OkHttpClient();
// 重置流以便重新读取
bufferedInputStream.reset();
// 创建 MultipartBody
RequestBody requestBody = new MultipartBody.Builder()
.setType(MultipartBody.FORM)
.addFormDataPart("file_id", fileId)
.addFormDataPart("file", fileName,
RequestBody.create(fileBytes, MediaType.parse(mediaType))
)
.build();
// 使用检测到的编码读取文件内容
try (InputStreamReader reader = new InputStreamReader(bufferedInputStream, encoding);
BufferedReader bufferedReader = new BufferedReader(reader)) {
StringBuilder fileContentBuilder = new StringBuilder();
String line;
while ((line = bufferedReader.readLine()) != null) {
fileContentBuilder.append(line).append(System.lineSeparator());
}
String fileContent = fileContentBuilder.toString();
// 创建请求
Request sendRequest = new Request.Builder()
.url(ragEmbed)
.post(requestBody)
.addHeader("accept", "application/json")
.build();
// 将文件内容转换为UTF-8编码的字节数组
byte[] utf8Bytes = fileContent.getBytes(StandardCharsets.UTF_8);
// 发送请求
try (Response sendResponse = client.newCall(sendRequest).execute()) {
if (sendResponse.body() != null) {
String body = sendResponse.body().string();
log.info("!!!!!!!!!! 响应原始内容 Response: {}", body);
try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(utf8Bytes);
ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
JSONObject resJson = JSONObject.parseObject(body);
int bufferSize = 1024;
byte[] byteArray = new byte[bufferSize];
int bytesRead;
// 读取字节数组并写入输出流
while ((bytesRead = byteArrayInputStream.read(byteArray)) != -1) {
outputStream.write(byteArray, 0, bytesRead);
}
// 将输出流转换为字节数组
byte[] result = outputStream.toByteArray();
// 发送HTTP POST请求上传文件内容
String body = HttpRequest.post(ragEmbed)
.form("file", result, fileName)
.form("file_id", fileId)
.execute()
.body();
// 打印响应内容
log.info("响应原始内容 String: {}", body);
// 解析响应内容
RagEmbedRespVO ragEmbedRespVO = JSON.parseObject(body, RagEmbedRespVO.class);
log.info("解析响应原始内容 ragEmbedRespVO:{}", ragEmbedRespVO);
// 根据响应状态更新文件状态
if (ragEmbedRespVO.isStatus()) {
updateFileState(documents, KnowledgeStatusEnum.UPLOAD_SUCCESS);
} else {
updateFileState(documents, KnowledgeStatusEnum.UPLOAD_FAILED);
throw new RuntimeException("文件上传失败:" + ragEmbedRespVO.getMessage());
}
} catch (UnirestException e) {
throw new RuntimeException("文件上传失败: " + e.getMessage());
}
// 1: 先判断是否存在 detail
String detail = resJson.getString("detail");
if (StringUtils.isNotEmpty(detail)) {
handleFailure(documents, detail);
} else {
Boolean status = resJson.getBoolean("status");
if (!status) {
handleFailure(documents, resJson.getString("message"));
} else {
processResponse(body, documents);
}
}
} else {
handleFailure(documents, FILE_UPLOAD_FAILED_MSG);
}
} catch (IOException e) {
handleFailure(documents, FILE_UPLOAD_FAILED_MSG, e);
}
}
/**
 * Downloads the resource at {@code fileUrl} and returns its full content as a byte array.
 *
 * <p>The whole payload is buffered in memory before being returned. Connect and read
 * timeouts are set on the connection so that an unresponsive file server cannot block
 * the calling thread indefinitely.
 *
 * <p>Any {@link IOException} (malformed URL, connection failure, timeout, read error) is
 * logged together with the URL and its stack trace, and is then converted into the
 * exception produced by {@code exception(new ErrorCode(10001_001, ...))}.
 *
 * @param fileUrl URL of the file to download
 * @return the complete content of the file; never {@code null}
 */
public static byte[] getFileByte(String fileUrl) {
    // Timeouts in milliseconds. The read timeout bounds each blocking read, not the whole transfer,
    // so large files on a slow but live connection still complete.
    final int connectTimeoutMillis = 10_000;
    final int readTimeoutMillis = 60_000;
    try {
        java.net.URLConnection connection = new URL(fileUrl).openConnection();
        connection.setConnectTimeout(connectTimeoutMillis);
        connection.setReadTimeout(readTimeoutMillis);
        try (InputStream inputStream = connection.getInputStream();
             ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[8192];
            int bytesRead;
            // Copy the remote stream into memory chunk by chunk until end of stream.
            while ((bytesRead = inputStream.read(buffer)) != -1) {
                outputStream.write(buffer, 0, bytesRead);
            }
            return outputStream.toByteArray();
        }
    } catch (IOException e) {
        // Pass the throwable as the last argument so the stack trace is kept in the log.
        log.error("Failed to read remote file: {}", fileUrl, e);
        throw exception(new ErrorCode(10001_001, "文件读取错误"));
    }
}
/**
 * Resolves the MIME type to use for a file from its file-name extension.
 *
 * <p>The extension is the text after the last {@code '.'} in the name and is compared
 * case-insensitively, so {@code report.PDF} and {@code report.pdf} resolve to the same type.
 * Recognised extensions are {@code pdf}, {@code md}, {@code docx} and {@code txt}.
 * Everything else, including a {@code null} name or a name without any {@code '.'},
 * resolves to {@code application/octet-stream}.
 *
 * @param fileName file name including its extension; may be {@code null}
 * @return the MIME type string for the file; never {@code null}
 */
private static String getMediaType(String fileName) {
    final String fallbackType = "application/octet-stream";
    if (fileName == null) {
        return fallbackType;
    }
    int dotIndex = fileName.lastIndexOf('.');
    if (dotIndex < 0) {
        // No extension at all: do not mistake the whole name (e.g. "pdf") for an extension.
        return fallbackType;
    }
    // Locale.ROOT keeps the lower-casing independent of the server's default locale.
    String fileSuffix = fileName.substring(dotIndex + 1).toLowerCase(java.util.Locale.ROOT);
    switch (fileSuffix) {
        case "pdf":
            return "application/pdf";
        case "md":
            return "text/x-markdown";
        case "docx":
            return "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
        case "txt":
            return "text/plain";
        default:
            return fallbackType;
    }
}
/**
* 处理响应结果
*/