From 7e3ca47a0fcc7ecb67568f1f08cb59394011734e Mon Sep 17 00:00:00 2001 From: Liuyang <2746366019@qq.com> Date: Mon, 13 Jan 2025 16:35:00 +0800 Subject: [PATCH] =?UTF-8?q?Csv=20=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module/llm/enums/ErrorCodeConstants.java | 1 + yudao-module-llm/yudao-module-llm-biz/pom.xml | 7 + .../service/dataset/DatasetServiceImpl.java | 152 ++++++++++++------ .../llm/utils/DataSetReadFileUtils.java | 79 ++++++++- .../module/llm/utils/vo/CsvDataSetVO.java | 19 +++ 5 files changed, 207 insertions(+), 51 deletions(-) create mode 100644 yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java diff --git a/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java index 1c803f206..250b7a186 100644 --- a/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java +++ b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java @@ -106,4 +106,5 @@ public interface ErrorCodeConstants { ErrorCode OPTIMIZE_PROMPT_NOT_EXISTS = new ErrorCode(10044, "优化后信息不存在"); + ErrorCode PARSE_CSV_ERROR = new ErrorCode(10034, "请正确上传csv格式得数据!!!"); } diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml index 604e3f010..ef61452cb 100644 --- a/yudao-module-llm/yudao-module-llm-biz/pom.xml +++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml @@ -71,6 +71,13 @@ 1.0.3 + + + com.opencsv + opencsv + 5.9 + + diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java index 583094c79..e89b35fd2 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java @@ -16,34 +16,32 @@ import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetFilesMapper; import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetMapper; import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetQuestionMapper; import cn.iocoder.yudao.module.llm.utils.DataSetReadFileUtils; +import cn.iocoder.yudao.module.llm.utils.vo.CsvDataSetVO; +import com.alibaba.fastjson.JSON; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; import javax.annotation.Resource; - import java.io.*; import java.net.HttpURLConnection; -import java.net.URL; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; -import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.DATASET_NAME_EXISTS; -import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.DATASET_NOT_EXISTS; +import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*; /** * 数据集 Service 实现类 * * @author 华大大模型 */ +@Slf4j @Service @Validated public class DatasetServiceImpl implements DatasetService { @@ -58,17 +56,17 @@ public class DatasetServiceImpl implements DatasetService { private DatasetAnswerMapper datasetAnswerMapper; @Override - public Long createDataset(DatasetSaveReqVO createReqVO) { + public Long createDataset (DatasetSaveReqVO createReqVO) { // 校验 validateDatasetNameExists(createReqVO); - if(createReqVO.getType() == null){ + if (createReqVO.getType() == null) { createReqVO.setType(0); } // 插入 DatasetDO dataset = BeanUtils.toBean(createReqVO, DatasetDO.class); datasetMapper.insert(dataset); List datasetFiles = createReqVO.getDatasetFiles(); - if (CollectionUtils.isNotEmpty(datasetFiles)){ + if (CollectionUtils.isNotEmpty(datasetFiles)) { datasetFiles.stream().forEach( datasetFilesSaveReqVO -> { datasetFilesSaveReqVO.setDatasetId(dataset.getId()); @@ -80,15 +78,15 @@ public class DatasetServiceImpl implements DatasetService { dataset.setDataLength(count); Long annoCount = datasetQuestionMapper.selectCount(new LambdaQueryWrapper() .eq(DatasetQuestionDO::getDatasetId, dataset.getId()) - .eq(DatasetQuestionDO::getStatus,2)); - double ratio = count == 0 ? 0 : ((double) annoCount / count) *100; + .eq(DatasetQuestionDO::getStatus, 2)); + double ratio = count == 0 ? 0 : ((double) annoCount / count) * 100; Integer formattedRatio = ratio == 0 ? 0 : (int) ratio; Integer status = formattedRatio == 100 ? 2 : 1; - if (formattedRatio != null){ + if (formattedRatio != null) { dataset.setAnnotateProgress(formattedRatio); } dataset.setStatus(status); - if(annoCount == 0){ + if (annoCount == 0) { dataset.setStatus(0); } datasetMapper.updateById(dataset); @@ -97,11 +95,11 @@ public class DatasetServiceImpl implements DatasetService { } - private static long getFileContentLength(File file)throws IOException { + private static long getFileContentLength (File file) throws IOException { FileInputStream fis = new FileInputStream(file); byte[] buffer = new byte[1024]; - long charCount =0; - while(fis.read(buffer)!=-1) { + long charCount = 0; + while (fis.read(buffer) != -1) { charCount += new String(buffer).length(); } fis.close(); @@ -109,14 +107,14 @@ public class DatasetServiceImpl implements DatasetService { } @Override - public void updateDataset(DatasetSaveReqVO updateReqVO) { + public void updateDataset (DatasetSaveReqVO updateReqVO) { // 校验存在 validateDatasetExists(updateReqVO.getId()); validateDatasetNameExists(updateReqVO); // 更新 DatasetDO updateObj = BeanUtils.toBean(updateReqVO, DatasetDO.class); List datasetFiles = updateReqVO.getDatasetFiles(); - if (CollectionUtils.isNotEmpty(datasetFiles)){ + if (CollectionUtils.isNotEmpty(datasetFiles)) { datasetFiles.stream().forEach( datasetFilesSaveReqVO -> { datasetFilesSaveReqVO.setDatasetId(updateObj.getId()); @@ -128,15 +126,15 @@ public class DatasetServiceImpl implements DatasetService { updateObj.setDataLength(count); Long annoCount = datasetQuestionMapper.selectCount(new LambdaQueryWrapper() .eq(DatasetQuestionDO::getDatasetId, updateObj.getId()) - .eq(DatasetQuestionDO::getStatus,2)); - double ratio = count == 0 ? 0 : ((double) annoCount / count) *100; + .eq(DatasetQuestionDO::getStatus, 2)); + double ratio = count == 0 ? 0 : ((double) annoCount / count) * 100; Integer formattedRatio = ratio == 0 ? 0 : (int) ratio; Integer status = formattedRatio == 100 ? 2 : 1; - if (formattedRatio != null){ + if (formattedRatio != null) { updateObj.setAnnotateProgress(formattedRatio); } updateObj.setStatus(status); - if(annoCount == 0){ + if (annoCount == 0) { updateObj.setStatus(0); } } @@ -144,33 +142,34 @@ public class DatasetServiceImpl implements DatasetService { } @Override - public void deleteDataset(Long id) { + public void deleteDataset (Long id) { // 校验存在 validateDatasetExists(id); // 删除 datasetMapper.deleteById(id); } - private void validateDatasetExists(Long id) { + private void validateDatasetExists (Long id) { if (datasetMapper.selectById(id) == null) { throw exception(DATASET_NOT_EXISTS); } } - private void validateDatasetNameExists(DatasetSaveReqVO dateReqVO){ + + private void validateDatasetNameExists (DatasetSaveReqVO dateReqVO) { LambdaQueryWrapper wrapper = new LambdaQueryWrapper() .eq(DatasetDO::getDatasetName, dateReqVO.getDatasetName()); - if (dateReqVO.getId() != null){ + if (dateReqVO.getId() != null) { wrapper.ne(DatasetDO::getId, dateReqVO.getId()); } List datasetDOS = datasetMapper.selectList(wrapper); - if (CollectionUtils.isNotEmpty(datasetDOS)){ + if (CollectionUtils.isNotEmpty(datasetDOS)) { throw exception(DATASET_NAME_EXISTS); } } @Override - public DatasetRespVO getDataset(Long id) { + public DatasetRespVO getDataset (Long id) { DatasetDO datasetDO = datasetMapper.selectById(id); DatasetRespVO datasetRespVO = BeanUtils.toBean(datasetDO, DatasetRespVO.class); /*List datasetQuestionDO = datasetQuestionMapper.selectList(new LambdaQueryWrapper().eq(DatasetQuestionDO::getDatasetId, id)); @@ -180,12 +179,12 @@ public class DatasetServiceImpl implements DatasetService { } @Override - public PageResult getDatasetPage(DatasetPageReqVO pageReqVO) { + public PageResult getDatasetPage (DatasetPageReqVO pageReqVO) { return datasetMapper.selectPage(pageReqVO); } @Override - public List queryAll() { + public List queryAll () { /*List datasetDOS0 = datasetMapper.selectList(new LambdaQueryWrapper().eq(DatasetDO::getType, DataConstants.dataTypePrivate)); List datasetRespVOS0 = BeanUtils.toBean(datasetDOS0, DatasetRespVO.class); List datasetDOS1 = datasetMapper.selectList(new LambdaQueryWrapper().eq(DatasetDO::getType, DataConstants.dataTypePublic)); @@ -194,7 +193,7 @@ public class DatasetServiceImpl implements DatasetService { result.add(datasetRespVOS0); result.add(datasetRespVOS1);*/ List datasetDOS = datasetMapper.selectList(new LambdaQueryWrapper() - .eq(DatasetDO::getStatus,2)); // 获取所有数据集 + .eq(DatasetDO::getStatus, 2)); // 获取所有数据集 // 创建两个根节点,分别代表两种 type DatasetTreeNode privateRoot = new DatasetTreeNode(DataConstants.dataTypePrivate); @@ -203,9 +202,9 @@ public class DatasetServiceImpl implements DatasetService { for (DatasetDO datasetDO : datasetDOS) { DatasetRespVO datasetRespVO = BeanUtils.toBean(datasetDO, DatasetRespVO.class); // 根据 type 字段决定节点的位置 - if (datasetRespVO.getType()==DataConstants.dataTypePrivate) { + if (datasetRespVO.getType() == DataConstants.dataTypePrivate) { privateRoot.getChildren().add(datasetRespVO); - } else if (datasetRespVO.getType()==DataConstants.dataTypePublic) { + } else if (datasetRespVO.getType() == DataConstants.dataTypePublic) { publicRoot.getChildren().add(datasetRespVO); } } @@ -216,10 +215,10 @@ public class DatasetServiceImpl implements DatasetService { } - public void readJsonFile(List jsonFiles){ + public void readJsonFile (List jsonFiles) { jsonFiles.forEach(datasetFilesDO -> { HttpURLConnection connection = DataSetReadFileUtils.readFile(datasetFilesDO.getDatasetFileUrl()); - if (connection != null){ + if (connection != null) { try (BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream(), "UTF-8"))) { StringBuilder content = new StringBuilder(); String line; @@ -230,16 +229,17 @@ public class DatasetServiceImpl implements DatasetService { // 使用Jackson解析 Json 字符串为List对象 ObjectMapper mapper = new ObjectMapper(); // 使用 TypeReference 解析 JSON 字符串为 List - List jsonList = mapper.readValue(content.toString(), new TypeReference>() {}); + List jsonList = mapper.readValue(content.toString(), new TypeReference>() { + }); jsonList.forEach( dataJsonTemplate -> { List answers = dataJsonTemplate.getAnswers(); DatasetQuestionDO datasetQuestionDO = BeanUtils.toBean(dataJsonTemplate, DatasetQuestionDO.class); datasetQuestionDO.setDatasetId(datasetFilesDO.getDatasetId()); datasetQuestionDO.setDatasetFilesId(datasetFilesDO.getId()); - datasetQuestionDO.setStatus(CollectionUtils.isNotEmpty(answers) ? 2:0); + datasetQuestionDO.setStatus(CollectionUtils.isNotEmpty(answers) ? 2 : 0); datasetQuestionMapper.insert(datasetQuestionDO); - if (CollectionUtils.isNotEmpty(answers)){ + if (CollectionUtils.isNotEmpty(answers)) { for (String answer : answers) { DatasetAnswerDO datasetAnswerDO = new DatasetAnswerDO(); datasetAnswerDO.setDatasetId(datasetFilesDO.getDatasetId()); @@ -251,24 +251,26 @@ public class DatasetServiceImpl implements DatasetService { } } ); - }catch (Exception e){ - throw exception(new ErrorCode(11000,"请正确上传json格式得数据!!!")); - }finally { + } catch (Exception e) { + throw exception(new ErrorCode(11000, "请正确上传json格式得数据!!!")); + } finally { connection.disconnect(); } } }); } + /** * txt文本数据 + * * @param txtFiles */ - public void readTxtFile(List txtFiles){ + public void readTxtFile (List txtFiles) { txtFiles.forEach(datasetFilesDO -> { List newContent = new ArrayList<>(); HttpURLConnection connection = DataSetReadFileUtils.readFile(datasetFilesDO.getDatasetFileUrl()); - if (connection != null){ + if (connection != null) { try (BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream(), "UTF-8"))) { String inputLine; while ((inputLine = in.readLine()) != null) { @@ -281,31 +283,83 @@ public class DatasetServiceImpl implements DatasetService { datasetQuestionMapper.insert(datasetQuestionDO); } } - }catch (Exception e){ + } catch (Exception e) { e.printStackTrace(); - }finally { + } finally { connection.disconnect(); } } }); } - public void parseFile(List datasetFiles) { + public void parseFile (List datasetFiles) { List insertDatasetFiles = BeanUtils.toBean(datasetFiles, DatasetFilesDO.class); datasetFilesMapper.insertBatch(insertDatasetFiles, 100); // 提取文件 List jsonFiles = insertDatasetFiles.stream() .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".json")) .collect(Collectors.toList()); - if (CollectionUtils.isNotEmpty(jsonFiles)){ + if (CollectionUtils.isNotEmpty(jsonFiles)) { readJsonFile(jsonFiles); } List txtFiles = insertDatasetFiles.stream() .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".txt")) .collect(Collectors.toList()); - if (CollectionUtils.isNotEmpty(txtFiles)){ + if (CollectionUtils.isNotEmpty(txtFiles)) { readTxtFile(txtFiles); } + + // csv 文件处理 + List csvFiles = insertDatasetFiles.stream() + .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".csv")) + .collect(Collectors.toList()); + if (CollectionUtils.isNotEmpty(csvFiles)) { + readCsvFile(csvFiles); + } + } + + /** + * csv文件处理 + * + * @param csvFiles csv文件 + */ + private void readCsvFile (List csvFiles) { + csvFiles.forEach(datasetFilesDO -> { + + try { + // 读取并解析CSV文件 + List dataSetVos = DataSetReadFileUtils.readParseCsv(datasetFilesDO.getDatasetFileUrl()); + + if (CollectionUtils.isNotEmpty(dataSetVos)) { + // 获取数据集ID + Long datasetId = datasetFilesDO.getDatasetId(); + // 数据集文件ID + Long fileId = datasetFilesDO.getId(); + + dataSetVos.forEach(dataSetVO -> { + // 保存到 数据集数据问题 + DatasetQuestionDO datasetQuestionDO = new DatasetQuestionDO();// 检查是否为空行 + datasetQuestionDO.setDatasetId(datasetId); + datasetQuestionDO.setDatasetFilesId(fileId); + datasetQuestionDO.setSystem(dataSetVO.getSystem()); + datasetQuestionDO.setQuestion(dataSetVO.getQuestion()); + datasetQuestionMapper.insert(datasetQuestionDO); + + // 保存到 数据集数据问题标注 + DatasetAnswerDO datasetAnswerDO = new DatasetAnswerDO(); + datasetAnswerDO.setDatasetId(datasetId); + datasetAnswerDO.setDatasetFilesId(fileId); + datasetAnswerDO.setQuestionId(datasetQuestionDO.getId()); + datasetAnswerDO.setAnswer(dataSetVO.getAnswer()); + datasetAnswerMapper.insert(datasetAnswerDO); + }); + } + + } catch (IOException e) { + throw exception(PARSE_CSV_ERROR); + } + + }); } } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java index bed347aeb..fd61c6141 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java @@ -1,14 +1,27 @@ package cn.iocoder.yudao.module.llm.utils; +import cn.iocoder.yudao.module.llm.utils.vo.CsvDataSetVO; +import com.alibaba.fastjson.JSON; +import com.opencsv.CSVParser; +import com.opencsv.CSVParserBuilder; +import com.opencsv.CSVReader; +import com.opencsv.CSVReaderBuilder; +import com.opencsv.exceptions.CsvValidationException; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; import java.net.HttpURLConnection; import java.net.URL; -import java.sql.Connection; +import java.util.ArrayList; +import java.util.List; +@Slf4j @Component public class DataSetReadFileUtils { - public static HttpURLConnection readFile(String filePath) { + public static HttpURLConnection readFile (String filePath) { try { URL url = new URL(filePath); HttpURLConnection connection = (HttpURLConnection) url.openConnection(); @@ -25,4 +38,66 @@ public class DataSetReadFileUtils { } return null; } + + /** + * 读取并解析 CSV 文件 + * + * @param csvUrl CSV文件的URL + * @return CSV 文件解析对象 + */ + public static List readParseCsv (String csvUrl) throws IOException { + // 根据传入的URL创建URL对象 + URL url = new URL(csvUrl); + BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream())); + + // 创建CSV解析器,指定分隔符为 , + CSVParser parser = new CSVParserBuilder().withSeparator(',').build(); + // 创建CSV读取器 + CSVReader csvReader = new CSVReaderBuilder(reader).withCSVParser(parser).build(); + + List dataSetVos = new ArrayList<>(); + String[] line; + + try { + // 跳过标题行 + csvReader.readNext(); + } catch (CsvValidationException e) { + // 跳过标题行异常 + throw new RuntimeException("跳过标题行时发生错误", e); + } + + while (true) { + try { + // 读取CSV文件的下一行数据 + line = csvReader.readNext(); + if (line == null) { + // 如果读取到的行为null,表示数据集读取完成 + log.info("数据集读取完成"); + break; + } + } catch (CsvValidationException e) { + // 读取行时异常 + throw new RuntimeException("读取CSV行时发生错误", e); + } + + // 读取行 + if (line.length >= 3) { + // 使用读取到的数据创建对象 + CsvDataSetVO dataSetVO = new CsvDataSetVO(line[0], line[1], line[2]); + dataSetVos.add(dataSetVO); + } + + } + + try { + // 关闭CSV读取器,释放资源 + csvReader.close(); + } catch (IOException e) { + // 关闭CSV读取器异常 + log.error("关闭CSV读取器时发生错误", e); + } + + // 返回解析后的对象 + return dataSetVos; + } } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java new file mode 100644 index 000000000..9dadac17e --- /dev/null +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java @@ -0,0 +1,19 @@ +package cn.iocoder.yudao.module.llm.utils.vo; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.ToString; + +/** + * @Description Csv数据集 + */ +@AllArgsConstructor +@NoArgsConstructor +@Data +@ToString +public class CsvDataSetVO { + private String system; + private String question; + private String answer; +}