diff --git a/yudao-module-infra/yudao-module-infra-api/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApi.java b/yudao-module-infra/yudao-module-infra-api/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApi.java
index c41c6e039..ce029018d 100644
--- a/yudao-module-infra/yudao-module-infra-api/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApi.java
+++ b/yudao-module-infra/yudao-module-infra-api/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApi.java
@@ -1,5 +1,7 @@
 package cn.iocoder.yudao.module.infra.api.file;
 
+import java.util.Map;
+
 /**
  * 文件 API 接口
  *
@@ -38,4 +40,13 @@ public interface FileApi {
      */
     String createFile(String name, String path, byte[] content);
 
+    /**
+     * Stores a file and reports where it ended up.
+     *
+     * @param originalFilename file name forwarded to the file service
+     * @param bytes            file content
+     * @return map with the stored file's URL under {@code "url"} and its id under {@code "id"}
+     */
+    Map<String, Object> llmCreateFile(String originalFilename, byte[] bytes);
+
 }
diff --git a/yudao-module-infra/yudao-module-infra-biz/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApiImpl.java b/yudao-module-infra/yudao-module-infra-biz/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApiImpl.java
index 05fb946fe..c0aebb3ca 100644
--- a/yudao-module-infra/yudao-module-infra-biz/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApiImpl.java
+++ b/yudao-module-infra/yudao-module-infra-biz/src/main/java/cn/iocoder/yudao/module/infra/api/file/FileApiImpl.java
@@ -1,10 +1,13 @@
 package cn.iocoder.yudao.module.infra.api.file;
 
+import cn.iocoder.yudao.module.infra.controller.admin.file.vo.file.FileCreateRespVo;
 import cn.iocoder.yudao.module.infra.service.file.FileService;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
 import javax.annotation.Resource;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * 文件 API 实现类
@@ -23,4 +26,14 @@ public class FileApiImpl implements FileApi {
         return fileService.createFile(name, path, content);
     }
 
+    @Override
+    public Map<String, Object> llmCreateFile(String originalFilename, byte[] bytes) {
+        // The path argument is passed as null; presumably the file service then chooses the path - TODO confirm
+        FileCreateRespVo created = fileService.llmcreateFile(originalFilename, null, bytes);
+        Map<String, Object> result = new HashMap<>();
+        result.put("url", created.getUrl());
+        result.put("id", created.getId());
+        return result;
+    }
+
 }
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
index ccaec93e5..6ddc3e361 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
@@ -1,9 +1,11 @@
 package cn.iocoder.yudao.module.llm.service.dataset;
 
 
+import cn.hutool.core.io.IoUtil;
 import cn.iocoder.yudao.framework.common.exception.ErrorCode;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import cn.iocoder.yudao.module.llm.constant.DataConstants;
 import cn.iocoder.yudao.module.llm.controller.admin.dataset.dto.DataJsonTemplate;
 import cn.iocoder.yudao.module.llm.controller.admin.dataset.vo.*;
@@ -25,6 +27,8 @@ import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.opencsv.exceptions.CsvValidationException;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.compress.archivers.ArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
 import org.apache.poi.ss.usermodel.Cell;
 import org.apache.poi.ss.usermodel.Sheet;
 import org.apache.poi.ss.usermodel.Workbook;
@@ -36,7 +40,10 @@ import java.io.*;
 import java.net.HttpURLConnection;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
@@ -59,6 +66,8 @@ public class DatasetServiceImpl implements DatasetService {
     private DatasetQuestionMapper datasetQuestionMapper;
     @Resource
     private DatasetAnswerMapper datasetAnswerMapper;
+    @Resource
+    private FileApi fileApi;
 
     @Override
     public Long createDataset (DatasetSaveReqVO createReqVO) {
@@ -219,6 +228,100 @@ public class DatasetServiceImpl implements DatasetService {
         return root;
     }
 
+    /**
+     * Unpacks each uploaded zip archive, stores every contained file through the file API and hands the
+     * resulting dataset-file requests to {@link #parseFile}. An unreadable archive raises error 11000.
+     *
+     * @param zipFiles     dataset-file rows whose URL ends with {@code .zip}
+     * @param datasetFiles original requests; the first supplies the dataset id and id copied to extracted files
+     */
+    public void readZipFile (List<DatasetFilesDO> zipFiles, List<DatasetFilesSaveReqVO> datasetFiles) {
+        DatasetFilesSaveReqVO template = datasetFiles.get(0);
+        List<DatasetFilesSaveReqVO> extracted = new ArrayList<>();
+        for (DatasetFilesDO zipFile : zipFiles) {
+            HttpURLConnection connection = DataSetReadFileUtils.readFile(zipFile.getDatasetFileUrl());
+            if (connection == null) {
+                continue;
+            }
+            // try-with-resources closes the archive stream on every path, not only after an IOException
+            try (ZipInputStream zipIn = new ZipInputStream(connection.getInputStream())) {
+                ZipEntry entry;
+                while ((entry = zipIn.getNextEntry()) != null) {
+                    // Directory entries carry no content to store
+                    if (!entry.isDirectory()) {
+                        storeArchiveEntry(entry.getName(), zipIn, template, extracted);
+                    }
+                }
+            } catch (IOException e) {
+                log.error("Failed to read zip archive {}", zipFile.getDatasetFileUrl(), e);
+                throw exception(new ErrorCode(
+                        11000, "请正确上传zip格式得数据!!!"));
+            } finally {
+                connection.disconnect();
+            }
+        }
+        parseFile(extracted);
+    }
+
+    /**
+     * Unpacks each uploaded tar archive, stores every contained file through the file API and hands the
+     * resulting dataset-file requests to {@link #parseFile}. An unreadable archive raises error 11000.
+     *
+     * @param tarFiles     dataset-file rows whose URL ends with {@code .tar}
+     * @param datasetFiles original requests; the first supplies the dataset id and id copied to extracted files
+     */
+    public void readTarFile (List<DatasetFilesDO> tarFiles, List<DatasetFilesSaveReqVO> datasetFiles) {
+        DatasetFilesSaveReqVO template = datasetFiles.get(0);
+        List<DatasetFilesSaveReqVO> extracted = new ArrayList<>();
+        for (DatasetFilesDO tarFile : tarFiles) {
+            HttpURLConnection connection = DataSetReadFileUtils.readFile(tarFile.getDatasetFileUrl());
+            if (connection == null) {
+                continue;
+            }
+            // try-with-resources closes the archive stream on every path, not only after an IOException
+            try (TarArchiveInputStream tarIn = new TarArchiveInputStream(connection.getInputStream())) {
+                ArchiveEntry entry;
+                while ((entry = tarIn.getNextEntry()) != null) {
+                    // Directory entries carry no content to store
+                    if (!entry.isDirectory()) {
+                        storeArchiveEntry(entry.getName(), tarIn, template, extracted);
+                    }
+                }
+            } catch (IOException e) {
+                log.error("Failed to read tar archive {}", tarFile.getDatasetFileUrl(), e);
+                throw exception(new ErrorCode(
+                        11000, "请正确上传tar格式得数据!!!"));
+            } finally {
+                connection.disconnect();
+            }
+        }
+        parseFile(extracted);
+    }
+
+    /**
+     * Stores the archive entry currently being read from {@code entryStream} via the file API and appends a
+     * dataset-file request for it to {@code sink}. A failure on one entry is logged and skipped so the
+     * remaining entries are still processed.
+     */
+    private void storeArchiveEntry (String entryName, InputStream entryStream,
+                                    DatasetFilesSaveReqVO template, List<DatasetFilesSaveReqVO> sink) {
+        try {
+            // isClose=false is essential: the single-argument IoUtil.readBytes closes the stream it reads,
+            // which here is the archive stream itself, so only the first entry would ever be extracted.
+            byte[] content = IoUtil.readBytes(entryStream, false);
+            Map<String, Object> stored = fileApi.llmCreateFile(entryName, content);
+            Long fileId = Long.parseLong(stored.get("id").toString());
+            DatasetFilesSaveReqVO vo = new DatasetFilesSaveReqVO();
+            vo.setDatasetId(template.getDatasetId());
+            vo.setDatasetFile(fileId);
+            vo.setDatasetFileUrl(stored.get("url").toString());
+            vo.setId(template.getId());
+            sink.add(vo);
+        } catch (RuntimeException e) {
+            log.warn("Skipping archive entry {} that could not be stored", entryName, e);
+        }
+    }
+
     public void readJsonFile (List jsonFiles) {
 
         jsonFiles.forEach(datasetFilesDO -> {
@@ -401,6 +504,23 @@ public class DatasetServiceImpl implements DatasetService {
     public void parseFile (List datasetFiles) {
         List insertDatasetFiles = BeanUtils.toBean(datasetFiles, DatasetFilesDO.class);
         datasetFilesMapper.insertBatch(insertDatasetFiles, 100);
+
+        // Zip archives: unpack them and register the contained files
+        List<DatasetFilesDO> zipFiles = insertDatasetFiles.stream()
+                .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".zip"))
+                .collect(Collectors.toList());
+        if (CollectionUtils.isNotEmpty(zipFiles)) {
+            readZipFile(zipFiles,datasetFiles);
+        }
+
+        // Tar archives: matched on ".tar" (the filter previously repeated ".zip", so tar uploads were never read)
+        List<DatasetFilesDO> tarFiles = insertDatasetFiles.stream()
+                .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".tar"))
+                .collect(Collectors.toList());
+        if (CollectionUtils.isNotEmpty(tarFiles)) {
+            readTarFile(tarFiles,datasetFiles);
+        }
+
         // 提取文件
         List jsonFiles = insertDatasetFiles.stream()
                 .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".json"))