diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java index 882fb9bdf..a3c7ad425 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java @@ -23,6 +23,7 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.baomidou.mybatisplus.core.toolkit.StringUtils; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.opencsv.exceptions.CsvValidationException; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -335,7 +336,7 @@ public class DatasetServiceImpl implements DatasetService { }); } - } catch (IOException e) { + } catch (IOException | CsvValidationException e) { throw exception(PARSE_CSV_ERROR); } diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java index fd61c6141..d160ab896 100644 --- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java +++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java @@ -1,7 +1,6 @@ package cn.iocoder.yudao.module.llm.utils; import cn.iocoder.yudao.module.llm.utils.vo.CsvDataSetVO; -import com.alibaba.fastjson.JSON; import com.opencsv.CSVParser; import com.opencsv.CSVParserBuilder; import com.opencsv.CSVReader; @@ -45,59 +44,75 @@ public class DataSetReadFileUtils { * @param csvUrl CSV文件的URL * @return CSV 文件解析对象 */ - public static List readParseCsv (String csvUrl) throws IOException { - // 根据传入的URL创建URL对象 - URL url = new URL(csvUrl); - BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream())); - - // 创建CSV解析器,指定分隔符为 , - CSVParser parser = new CSVParserBuilder().withSeparator(',').build(); - // 创建CSV读取器 - CSVReader csvReader = new CSVReaderBuilder(reader).withCSVParser(parser).build(); + public static List readParseCsv (String csvUrl) throws IOException, CsvValidationException { List dataSetVos = new ArrayList<>(); - String[] line; + // 创建CSV读取器 + CSVReader csvReader = null; try { - // 跳过标题行 - csvReader.readNext(); - } catch (CsvValidationException e) { - // 跳过标题行异常 - throw new RuntimeException("跳过标题行时发生错误", e); - } + URL url = new URL(csvUrl); + BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream())); - while (true) { - try { - // 读取CSV文件的下一行数据 - line = csvReader.readNext(); - if (line == null) { - // 如果读取到的行为null,表示数据集读取完成 - log.info("数据集读取完成"); - break; + // 创建 CSV 解析器,分隔符为逗号 + CSVParser parser = new CSVParserBuilder().withSeparator(',').build(); + // 构建 CSV 读取器 + csvReader = new CSVReaderBuilder(reader).withCSVParser(parser).build(); + + // 读取标题行 + String[] headers = csvReader.readNext(); + String[] line; + + while (true) { + try { + // 读取下一行数据 + line = csvReader.readNext(); + if (line == null) { + // 数据集读取完成 + break; + } + } catch (com.opencsv.exceptions.CsvValidationException e) { + // 处理读取行时的异常 + throw new IOException("读取 CSV 行时发生错误", e); + } + + // 动态读取,当行长度与标题行长度相等时 + if (line.length == headers.length) { + // 根据标题行找到相应列的索引创建 CsvDataSetVO 对象 + CsvDataSetVO dataSetVO = new CsvDataSetVO(line[getIndex(headers, "system")], line[getIndex(headers, "question")], line[getIndex(headers, "answer")]); + // 将对象添加到列表中 + dataSetVos.add(dataSetVO); } - } catch (CsvValidationException e) { - // 读取行时异常 - throw new RuntimeException("读取CSV行时发生错误", e); } - - // 读取行 - if (line.length >= 3) { - // 使用读取到的数据创建对象 - CsvDataSetVO dataSetVO = new CsvDataSetVO(line[0], line[1], line[2]); - dataSetVos.add(dataSetVO); + } finally { + if (csvReader != null) { + try { + // 关闭 CSV 读取器 + csvReader.close(); + } catch (IOException e) { + // 关闭CSV读取器异常 + log.error("关闭CSV读取器时发生错误", e); + } } - - } - - try { - // 关闭CSV读取器,释放资源 - csvReader.close(); - } catch (IOException e) { - // 关闭CSV读取器异常 - log.error("关闭CSV读取器时发生错误", e); } // 返回解析后的对象 return dataSetVos; } + + /** + * 查找列名在标题行中的索引 + * + * @param headers 标题行 + * @param columnName 列名 + * @return 索引 + */ + private static int getIndex (String[] headers, String columnName) { + for (int i = 0; i < headers.length; i++) { + if (headers[i].equals(columnName)) { + return i; + } + } + return -1; + } }