refactor(yudao-module-llm): 重构文件上传功能
- 优化了文件上传的实现方式，使用 OkHttpClient 替代原有的 CloseableHttpClient
- 新增 getFileByte 方法获取文件字节数组
- 增加了对文件类型的处理，支持 pdf
This commit is contained in:
parent
53305996f6
commit
3a41fb9100
@ -33,9 +33,11 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.*;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
|
||||
@ -316,8 +318,9 @@ public class RagHttpService {
|
||||
String fileId = reqVO.getFileId();
|
||||
String fileName = reqVO.getFileName();
|
||||
String fileUrl = reqVO.getFileUrl();
|
||||
String mediaType = getMediaType(fileName);
|
||||
|
||||
log.info("URL: {}, fileId: {} ,fileNam: {}, fileUrl: {}, ", ragEmbed, fileId, fileName, fileUrl);
|
||||
log.info("URL: {}, fileId: {} ,fileName: {}, fileUrl: {}, mediaType: {} ", ragEmbed, fileId, fileName, fileUrl, mediaType);
|
||||
|
||||
// 获取知识库文档
|
||||
KnowledgeDocumentsDO documents = getKnowledgeDocuments(id, fileId);
|
||||
@ -328,83 +331,109 @@ public class RagHttpService {
|
||||
// 更新文件状态为上传中
|
||||
updateFileState(documents, KnowledgeStatusEnum.UPLOADING);
|
||||
|
||||
// 创建HTTP客户端
|
||||
CloseableHttpClient httpClient = HttpClients.createDefault();
|
||||
|
||||
// 创建HTTP GET请求,获取文件内容
|
||||
HttpGet request = new HttpGet(fileUrl);
|
||||
try (CloseableHttpResponse response = httpClient.execute(request)) {
|
||||
HttpEntity entity = response.getEntity();
|
||||
if (entity != null) {
|
||||
try (InputStream inputStream = entity.getContent();
|
||||
BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream)) {
|
||||
byte[] fileBytes = Objects.requireNonNull(getFileByte(fileUrl));
|
||||
|
||||
// 标记流以便后续重置
|
||||
bufferedInputStream.mark(Integer.MAX_VALUE);
|
||||
// 检测文件编码
|
||||
String encoding = detectCharset(bufferedInputStream);
|
||||
// 创建 OkHttpClient 实例
|
||||
OkHttpClient client = new OkHttpClient();
|
||||
|
||||
// 重置流以便重新读取
|
||||
bufferedInputStream.reset();
|
||||
// 创建 MultipartBody
|
||||
RequestBody requestBody = new MultipartBody.Builder()
|
||||
.setType(MultipartBody.FORM)
|
||||
.addFormDataPart("file_id", fileId)
|
||||
.addFormDataPart("file", fileName,
|
||||
RequestBody.create(fileBytes, MediaType.parse(mediaType))
|
||||
)
|
||||
.build();
|
||||
|
||||
// 使用检测到的编码读取文件内容
|
||||
try (InputStreamReader reader = new InputStreamReader(bufferedInputStream, encoding);
|
||||
BufferedReader bufferedReader = new BufferedReader(reader)) {
|
||||
StringBuilder fileContentBuilder = new StringBuilder();
|
||||
String line;
|
||||
while ((line = bufferedReader.readLine()) != null) {
|
||||
fileContentBuilder.append(line).append(System.lineSeparator());
|
||||
}
|
||||
String fileContent = fileContentBuilder.toString();
|
||||
// 创建请求
|
||||
Request sendRequest = new Request.Builder()
|
||||
.url(ragEmbed)
|
||||
.post(requestBody)
|
||||
.addHeader("accept", "application/json")
|
||||
.build();
|
||||
|
||||
// 将文件内容转换为UTF-8编码的字节数组
|
||||
byte[] utf8Bytes = fileContent.getBytes(StandardCharsets.UTF_8);
|
||||
// 发送请求
|
||||
try (Response sendResponse = client.newCall(sendRequest).execute()) {
|
||||
if (sendResponse.body() != null) {
|
||||
String body = sendResponse.body().string();
|
||||
log.info("!!!!!!!!!! 响应原始内容 Response: {}", body);
|
||||
|
||||
try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(utf8Bytes);
|
||||
ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
|
||||
JSONObject resJson = JSONObject.parseObject(body);
|
||||
|
||||
int bufferSize = 1024;
|
||||
byte[] byteArray = new byte[bufferSize];
|
||||
int bytesRead;
|
||||
|
||||
// 读取字节数组并写入输出流
|
||||
while ((bytesRead = byteArrayInputStream.read(byteArray)) != -1) {
|
||||
outputStream.write(byteArray, 0, bytesRead);
|
||||
}
|
||||
|
||||
// 将输出流转换为字节数组
|
||||
byte[] result = outputStream.toByteArray();
|
||||
|
||||
// 发送HTTP POST请求,上传文件内容
|
||||
String body = HttpRequest.post(ragEmbed)
|
||||
.form("file", result, fileName)
|
||||
.form("file_id", fileId)
|
||||
.execute()
|
||||
.body();
|
||||
|
||||
// 打印响应内容
|
||||
log.info("响应原始内容 String: {}", body);
|
||||
|
||||
// 解析响应内容
|
||||
RagEmbedRespVO ragEmbedRespVO = JSON.parseObject(body, RagEmbedRespVO.class);
|
||||
log.info("解析响应原始内容 ragEmbedRespVO:{}", ragEmbedRespVO);
|
||||
|
||||
// 根据响应状态更新文件状态
|
||||
if (ragEmbedRespVO.isStatus()) {
|
||||
updateFileState(documents, KnowledgeStatusEnum.UPLOAD_SUCCESS);
|
||||
} else {
|
||||
updateFileState(documents, KnowledgeStatusEnum.UPLOAD_FAILED);
|
||||
throw new RuntimeException("文件上传失败:" + ragEmbedRespVO.getMessage());
|
||||
}
|
||||
} catch (UnirestException e) {
|
||||
throw new RuntimeException("文件上传失败: " + e.getMessage());
|
||||
}
|
||||
// 1: 先判断是否存在 detail,
|
||||
String detail = resJson.getString("detail");
|
||||
if (StringUtils.isNotEmpty(detail)) {
|
||||
handleFailure(documents, detail);
|
||||
} else {
|
||||
Boolean status = resJson.getBoolean("status");
|
||||
if (!status) {
|
||||
handleFailure(documents, resJson.getString("message"));
|
||||
} else {
|
||||
processResponse(body, documents);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handleFailure(documents, FILE_UPLOAD_FAILED_MSG);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
handleFailure(documents, FILE_UPLOAD_FAILED_MSG, e);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取文件字节数组
|
||||
*
|
||||
* @param fileUrl 文件地址
|
||||
* @return 文件字节数组
|
||||
*/
|
||||
public static byte[] getFileByte (String fileUrl) {
|
||||
try (InputStream inputStream = new URL(fileUrl).openStream();
|
||||
ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
|
||||
|
||||
// 缓冲区大小
|
||||
byte[] buffer = new byte[1024];
|
||||
int bytesRead;
|
||||
|
||||
// 读取文件内容并写入 ByteArrayOutputStream
|
||||
while ((bytesRead = inputStream.read(buffer)) != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead);
|
||||
}
|
||||
|
||||
// 返回字节数组
|
||||
return outputStream.toByteArray();
|
||||
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to read remote file: {}", e.getMessage());
|
||||
|
||||
throw exception(new ErrorCode(10001_001, "文件读取错误"));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取文件类型
|
||||
*
|
||||
* @param fileName 文件名
|
||||
* @return 文件类型
|
||||
*/
|
||||
private static String getMediaType(String fileName) {
|
||||
String fileSuffix = fileName.substring(fileName.lastIndexOf(".") + 1);
|
||||
switch (fileSuffix) {
|
||||
case "pdf":
|
||||
return "application/pdf";
|
||||
case "md":
|
||||
return "text/x-markdown";
|
||||
case "docx":
|
||||
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
|
||||
case "txt":
|
||||
return "text/plain";
|
||||
}
|
||||
return "application/octet-stream";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 处理响应结果
|
||||
*/
|
||||
|
Loading…
x
Reference in New Issue
Block a user