diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java index 0842208da..ec4fb0d0a 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.llm.service.dataset; import cn.hutool.core.io.IoUtil; import cn.iocoder.yudao.framework.common.exception.ErrorCode; +import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.infra.api.file.FileApi; @@ -29,10 +30,13 @@ import com.opencsv.exceptions.CsvValidationException; import lombok.extern.slf4j.Slf4j; import org.apache.commons.compress.archivers.ArchiveEntry; import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; import org.apache.poi.ss.usermodel.Cell; import org.apache.poi.ss.usermodel.Sheet; import org.apache.poi.ss.usermodel.Workbook; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.validation.annotation.Validated; import javax.annotation.Resource; @@ -70,6 +74,8 @@ public class DatasetServiceImpl implements DatasetService { @Resource private FileApi fileApi; + + @Transactional @Override public Long createDataset (DatasetSaveReqVO createReqVO) { // 校验 @@ -270,7 +276,53 @@ public class DatasetServiceImpl implements DatasetService { }); parseFile(res); } - + public void readTarGzFile(List tarGzFiles, List datasetFiles) { + DatasetFilesSaveReqVO datasetFilesSaveReqVO = datasetFiles.get(0); + List res = new ArrayList<>(); + tarGzFiles.forEach(datasetFilesDO -> { + HttpURLConnection connection = DataSetReadFileUtils.readFile(datasetFilesDO.getDatasetFileUrl()); + if (connection != null) { + try { + InputStream inputStream = connection.getInputStream(); + GzipCompressorInputStream gzipInputStream = new GzipCompressorInputStream(inputStream); + TarArchiveInputStream tarArchiveInputStream = new TarArchiveInputStream(gzipInputStream, Charset.forName("GBK").name()); + TarArchiveEntry tarEntry; + try { + while ((tarEntry = tarArchiveInputStream.getNextTarEntry()) != null) { + if (!tarEntry.isDirectory()) { + try { + final String name = tarEntry.getName(); + byte[] fileBytes = IoUtil.readBytes(tarArchiveInputStream, false); + Map map = fileApi.llmCreateFile(name, fileBytes); + String url = map.get("url").toString(); + Long id = Long.parseLong(map.get("id").toString()); + DatasetFilesSaveReqVO vo = new DatasetFilesSaveReqVO(); + vo.setDatasetId(datasetFilesSaveReqVO.getDatasetId()); + vo.setDatasetFile(id); + vo.setDatasetFileUrl(url); + vo.setId(datasetFilesSaveReqVO.getId()); + res.add(vo); + } catch (Exception e) { + // Handle exception for individual file + } + } + } + } catch (IOException e) { + tarArchiveInputStream.close(); + } finally { + tarArchiveInputStream.close(); + } + } catch (Exception e) { + throw ServiceExceptionUtil.exception(new ErrorCode( + 11001, "请正确上传tar.gz格式的数据!!!")); + } finally { + connection.disconnect(); + } + } + }); + parseFile(res); + } + // 暂时先不用 public void readTarFile (List tarFiles, List datasetFiles) { DatasetFilesSaveReqVO datasetFilesSaveReqVO = datasetFiles.get(0); List res = new ArrayList<>(); @@ -444,7 +496,7 @@ public class DatasetServiceImpl implements DatasetService { * * @param xlsxFiles */ - public void readXlsxFile(List xlsxFiles) { + public void readXlsxFile(List xlsxFiles) { xlsxFiles.forEach(datasetFilesDO -> { Workbook sheets = DataSetReadFileUtils.readXlsxFromUrl(datasetFilesDO.getDatasetFileUrl()); if (sheets != null){ @@ -506,10 +558,10 @@ public class DatasetServiceImpl implements DatasetService { // tar文件 List tarFiles = insertDatasetFiles.stream() - .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".tar")) + .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".tar.gz")) .collect(Collectors.toList()); if (CollectionUtils.isNotEmpty(tarFiles)) { - readTarFile(tarFiles,datasetFiles); + readTarGzFile(tarFiles,datasetFiles); } // 提取文件