diff --git a/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java
index 1c803f206..250b7a186 100644
--- a/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java
+++ b/yudao-module-llm/yudao-module-llm-api/src/main/java/cn/iocoder/yudao/module/llm/enums/ErrorCodeConstants.java
@@ -106,4 +106,5 @@ public interface ErrorCodeConstants {
ErrorCode OPTIMIZE_PROMPT_NOT_EXISTS = new ErrorCode(10044, "优化后信息不存在");
+ ErrorCode PARSE_CSV_ERROR = new ErrorCode(10034, "请正确上传csv格式得数据!!!"); // NOTE(review): raised when an uploaded CSV dataset cannot be parsed. 10034 is lower than the preceding 10044 — confirm it does not collide with an existing error code.
}
diff --git a/yudao-module-llm/yudao-module-llm-biz/pom.xml b/yudao-module-llm/yudao-module-llm-biz/pom.xml
index 604e3f010..ef61452cb 100644
--- a/yudao-module-llm/yudao-module-llm-biz/pom.xml
+++ b/yudao-module-llm/yudao-module-llm-biz/pom.xml
@@ -71,6 +71,13 @@
1.0.3
+
+
+ com.opencsv
+ opencsv
+ 5.9
+
+
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
index 583094c79..e89b35fd2 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/service/dataset/DatasetServiceImpl.java
@@ -16,34 +16,32 @@ import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetFilesMapper;
import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetMapper;
import cn.iocoder.yudao.module.llm.dal.mysql.dataset.DatasetQuestionMapper;
import cn.iocoder.yudao.module.llm.utils.DataSetReadFileUtils;
+import cn.iocoder.yudao.module.llm.utils.vo.CsvDataSetVO;
+import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
-import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import javax.annotation.Resource;
-
import java.io.*;
import java.net.HttpURLConnection;
-import java.net.URL;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.DATASET_NAME_EXISTS;
-import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.DATASET_NOT_EXISTS;
+import static cn.iocoder.yudao.module.llm.enums.ErrorCodeConstants.*;
/**
* 数据集 Service 实现类
*
* @author 华大大模型
*/
+@Slf4j
@Service
@Validated
public class DatasetServiceImpl implements DatasetService {
@@ -58,17 +56,17 @@ public class DatasetServiceImpl implements DatasetService {
private DatasetAnswerMapper datasetAnswerMapper;
@Override
- public Long createDataset(DatasetSaveReqVO createReqVO) {
+ public Long createDataset (DatasetSaveReqVO createReqVO) {
// 校验
validateDatasetNameExists(createReqVO);
- if(createReqVO.getType() == null){
+ if (createReqVO.getType() == null) {
createReqVO.setType(0);
}
// 插入
DatasetDO dataset = BeanUtils.toBean(createReqVO, DatasetDO.class);
datasetMapper.insert(dataset);
List datasetFiles = createReqVO.getDatasetFiles();
- if (CollectionUtils.isNotEmpty(datasetFiles)){
+ if (CollectionUtils.isNotEmpty(datasetFiles)) {
datasetFiles.stream().forEach(
datasetFilesSaveReqVO -> {
datasetFilesSaveReqVO.setDatasetId(dataset.getId());
@@ -80,15 +78,15 @@ public class DatasetServiceImpl implements DatasetService {
dataset.setDataLength(count);
Long annoCount = datasetQuestionMapper.selectCount(new LambdaQueryWrapper()
.eq(DatasetQuestionDO::getDatasetId, dataset.getId())
- .eq(DatasetQuestionDO::getStatus,2));
- double ratio = count == 0 ? 0 : ((double) annoCount / count) *100;
+ .eq(DatasetQuestionDO::getStatus, 2));
+ double ratio = count == 0 ? 0 : ((double) annoCount / count) * 100;
Integer formattedRatio = ratio == 0 ? 0 : (int) ratio;
Integer status = formattedRatio == 100 ? 2 : 1;
- if (formattedRatio != null){
+ if (formattedRatio != null) {
dataset.setAnnotateProgress(formattedRatio);
}
dataset.setStatus(status);
- if(annoCount == 0){
+ if (annoCount == 0) {
dataset.setStatus(0);
}
datasetMapper.updateById(dataset);
@@ -97,11 +95,11 @@ public class DatasetServiceImpl implements DatasetService {
}
- private static long getFileContentLength(File file)throws IOException {
+ private static long getFileContentLength (File file) throws IOException {
FileInputStream fis = new FileInputStream(file);
byte[] buffer = new byte[1024];
- long charCount =0;
- while(fis.read(buffer)!=-1) {
+ long charCount = 0;
+ while (fis.read(buffer) != -1) {
charCount += new String(buffer).length();
}
fis.close();
@@ -109,14 +107,14 @@ public class DatasetServiceImpl implements DatasetService {
}
@Override
- public void updateDataset(DatasetSaveReqVO updateReqVO) {
+ public void updateDataset (DatasetSaveReqVO updateReqVO) {
// 校验存在
validateDatasetExists(updateReqVO.getId());
validateDatasetNameExists(updateReqVO);
// 更新
DatasetDO updateObj = BeanUtils.toBean(updateReqVO, DatasetDO.class);
List datasetFiles = updateReqVO.getDatasetFiles();
- if (CollectionUtils.isNotEmpty(datasetFiles)){
+ if (CollectionUtils.isNotEmpty(datasetFiles)) {
datasetFiles.stream().forEach(
datasetFilesSaveReqVO -> {
datasetFilesSaveReqVO.setDatasetId(updateObj.getId());
@@ -128,15 +126,15 @@ public class DatasetServiceImpl implements DatasetService {
updateObj.setDataLength(count);
Long annoCount = datasetQuestionMapper.selectCount(new LambdaQueryWrapper()
.eq(DatasetQuestionDO::getDatasetId, updateObj.getId())
- .eq(DatasetQuestionDO::getStatus,2));
- double ratio = count == 0 ? 0 : ((double) annoCount / count) *100;
+ .eq(DatasetQuestionDO::getStatus, 2));
+ double ratio = count == 0 ? 0 : ((double) annoCount / count) * 100;
Integer formattedRatio = ratio == 0 ? 0 : (int) ratio;
Integer status = formattedRatio == 100 ? 2 : 1;
- if (formattedRatio != null){
+ if (formattedRatio != null) {
updateObj.setAnnotateProgress(formattedRatio);
}
updateObj.setStatus(status);
- if(annoCount == 0){
+ if (annoCount == 0) {
updateObj.setStatus(0);
}
}
@@ -144,33 +142,34 @@ public class DatasetServiceImpl implements DatasetService {
}
@Override
- public void deleteDataset(Long id) {
+ public void deleteDataset (Long id) {
// 校验存在
validateDatasetExists(id);
// 删除
datasetMapper.deleteById(id);
}
- private void validateDatasetExists(Long id) {
+ private void validateDatasetExists (Long id) {
if (datasetMapper.selectById(id) == null) {
throw exception(DATASET_NOT_EXISTS);
}
}
- private void validateDatasetNameExists(DatasetSaveReqVO dateReqVO){
+
+ private void validateDatasetNameExists (DatasetSaveReqVO dateReqVO) {
LambdaQueryWrapper wrapper = new LambdaQueryWrapper()
.eq(DatasetDO::getDatasetName, dateReqVO.getDatasetName());
- if (dateReqVO.getId() != null){
+ if (dateReqVO.getId() != null) {
wrapper.ne(DatasetDO::getId, dateReqVO.getId());
}
List datasetDOS = datasetMapper.selectList(wrapper);
- if (CollectionUtils.isNotEmpty(datasetDOS)){
+ if (CollectionUtils.isNotEmpty(datasetDOS)) {
throw exception(DATASET_NAME_EXISTS);
}
}
@Override
- public DatasetRespVO getDataset(Long id) {
+ public DatasetRespVO getDataset (Long id) {
DatasetDO datasetDO = datasetMapper.selectById(id);
DatasetRespVO datasetRespVO = BeanUtils.toBean(datasetDO, DatasetRespVO.class);
/*List datasetQuestionDO = datasetQuestionMapper.selectList(new LambdaQueryWrapper().eq(DatasetQuestionDO::getDatasetId, id));
@@ -180,12 +179,12 @@ public class DatasetServiceImpl implements DatasetService {
}
@Override
- public PageResult getDatasetPage(DatasetPageReqVO pageReqVO) {
+ public PageResult getDatasetPage (DatasetPageReqVO pageReqVO) {
return datasetMapper.selectPage(pageReqVO);
}
@Override
- public List queryAll() {
+ public List queryAll () {
/*List datasetDOS0 = datasetMapper.selectList(new LambdaQueryWrapper().eq(DatasetDO::getType, DataConstants.dataTypePrivate));
List datasetRespVOS0 = BeanUtils.toBean(datasetDOS0, DatasetRespVO.class);
List datasetDOS1 = datasetMapper.selectList(new LambdaQueryWrapper().eq(DatasetDO::getType, DataConstants.dataTypePublic));
@@ -194,7 +193,7 @@ public class DatasetServiceImpl implements DatasetService {
result.add(datasetRespVOS0);
result.add(datasetRespVOS1);*/
List datasetDOS = datasetMapper.selectList(new LambdaQueryWrapper()
- .eq(DatasetDO::getStatus,2)); // 获取所有数据集
+ .eq(DatasetDO::getStatus, 2)); // 获取所有数据集
// 创建两个根节点,分别代表两种 type
DatasetTreeNode privateRoot = new DatasetTreeNode(DataConstants.dataTypePrivate);
@@ -203,9 +202,9 @@ public class DatasetServiceImpl implements DatasetService {
for (DatasetDO datasetDO : datasetDOS) {
DatasetRespVO datasetRespVO = BeanUtils.toBean(datasetDO, DatasetRespVO.class);
// 根据 type 字段决定节点的位置
- if (datasetRespVO.getType()==DataConstants.dataTypePrivate) {
+ if (datasetRespVO.getType() == DataConstants.dataTypePrivate) {
privateRoot.getChildren().add(datasetRespVO);
- } else if (datasetRespVO.getType()==DataConstants.dataTypePublic) {
+ } else if (datasetRespVO.getType() == DataConstants.dataTypePublic) {
publicRoot.getChildren().add(datasetRespVO);
}
}
@@ -216,10 +215,10 @@ public class DatasetServiceImpl implements DatasetService {
}
- public void readJsonFile(List jsonFiles){
+ public void readJsonFile (List jsonFiles) {
jsonFiles.forEach(datasetFilesDO -> {
HttpURLConnection connection = DataSetReadFileUtils.readFile(datasetFilesDO.getDatasetFileUrl());
- if (connection != null){
+ if (connection != null) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream(), "UTF-8"))) {
StringBuilder content = new StringBuilder();
String line;
@@ -230,16 +229,17 @@ public class DatasetServiceImpl implements DatasetService {
// 使用Jackson解析 Json 字符串为List对象
ObjectMapper mapper = new ObjectMapper();
// 使用 TypeReference 解析 JSON 字符串为 List
- List jsonList = mapper.readValue(content.toString(), new TypeReference>() {});
+ List jsonList = mapper.readValue(content.toString(), new TypeReference>() {
+ });
jsonList.forEach(
dataJsonTemplate -> {
List answers = dataJsonTemplate.getAnswers();
DatasetQuestionDO datasetQuestionDO = BeanUtils.toBean(dataJsonTemplate, DatasetQuestionDO.class);
datasetQuestionDO.setDatasetId(datasetFilesDO.getDatasetId());
datasetQuestionDO.setDatasetFilesId(datasetFilesDO.getId());
- datasetQuestionDO.setStatus(CollectionUtils.isNotEmpty(answers) ? 2:0);
+ datasetQuestionDO.setStatus(CollectionUtils.isNotEmpty(answers) ? 2 : 0);
datasetQuestionMapper.insert(datasetQuestionDO);
- if (CollectionUtils.isNotEmpty(answers)){
+ if (CollectionUtils.isNotEmpty(answers)) {
for (String answer : answers) {
DatasetAnswerDO datasetAnswerDO = new DatasetAnswerDO();
datasetAnswerDO.setDatasetId(datasetFilesDO.getDatasetId());
@@ -251,24 +251,26 @@ public class DatasetServiceImpl implements DatasetService {
}
}
);
- }catch (Exception e){
- throw exception(new ErrorCode(11000,"请正确上传json格式得数据!!!"));
- }finally {
+ } catch (Exception e) {
+ throw exception(new ErrorCode(11000, "请正确上传json格式得数据!!!"));
+ } finally {
connection.disconnect();
}
}
});
}
+
/**
* txt文本数据
+ *
* @param txtFiles
*/
- public void readTxtFile(List txtFiles){
+ public void readTxtFile (List txtFiles) {
txtFiles.forEach(datasetFilesDO -> {
List newContent = new ArrayList<>();
HttpURLConnection connection = DataSetReadFileUtils.readFile(datasetFilesDO.getDatasetFileUrl());
- if (connection != null){
+ if (connection != null) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream(), "UTF-8"))) {
String inputLine;
while ((inputLine = in.readLine()) != null) {
@@ -281,31 +283,83 @@ public class DatasetServiceImpl implements DatasetService {
datasetQuestionMapper.insert(datasetQuestionDO);
}
}
- }catch (Exception e){
+ } catch (Exception e) {
e.printStackTrace();
- }finally {
+ } finally {
connection.disconnect();
}
}
});
}
- public void parseFile(List datasetFiles) {
+ public void parseFile (List datasetFiles) {
List insertDatasetFiles = BeanUtils.toBean(datasetFiles, DatasetFilesDO.class);
datasetFilesMapper.insertBatch(insertDatasetFiles, 100);
// 提取文件
List jsonFiles = insertDatasetFiles.stream()
.filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".json"))
.collect(Collectors.toList());
- if (CollectionUtils.isNotEmpty(jsonFiles)){
+ if (CollectionUtils.isNotEmpty(jsonFiles)) {
readJsonFile(jsonFiles);
}
List txtFiles = insertDatasetFiles.stream()
.filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".txt"))
.collect(Collectors.toList());
- if (CollectionUtils.isNotEmpty(txtFiles)){
+ if (CollectionUtils.isNotEmpty(txtFiles)) {
readTxtFile(txtFiles);
}
+
+ // csv 文件处理
+ List csvFiles = insertDatasetFiles.stream()
+ .filter(datasetFilesDO -> datasetFilesDO.getDatasetFileUrl().toLowerCase().endsWith(".csv"))
+ .collect(Collectors.toList());
+ if (CollectionUtils.isNotEmpty(csvFiles)) {
+ readCsvFile(csvFiles);
+ }
+ }
+
+    /**
+     * Imports CSV dataset files: every parsed row is stored as a question and, when the
+     * row carries a non-blank answer, as an answer linked to that question.
+     * <p>
+     * The question status follows the convention used for JSON files in this class:
+     * 2 when an answer is present, 0 otherwise. Without it, rows imported from CSV are
+     * not counted by the annotation-progress query, which filters on status 2.
+     *
+     * @param csvFiles CSV dataset files to import
+     */
+    private void readCsvFile (List<DatasetFilesDO> csvFiles) {
+        for (DatasetFilesDO datasetFilesDO : csvFiles) {
+            List<CsvDataSetVO> dataSetVos;
+            try {
+                // Read and parse the CSV file. Only this call is guarded, so database
+                // failures further down are not misreported as CSV format errors.
+                dataSetVos = DataSetReadFileUtils.readParseCsv(datasetFilesDO.getDatasetFileUrl());
+            } catch (IOException | RuntimeException e) {
+                // The parser may wrap CSV validation errors in a RuntimeException; log the
+                // cause before replacing it with the user-facing error code.
+                log.error("Failed to parse csv file {}", datasetFilesDO.getDatasetFileUrl(), e);
+                throw exception(PARSE_CSV_ERROR);
+            }
+            if (CollectionUtils.isEmpty(dataSetVos)) {
+                continue;
+            }
+
+            // Dataset id and dataset-file id shared by every row of this file.
+            Long datasetId = datasetFilesDO.getDatasetId();
+            Long fileId = datasetFilesDO.getId();
+
+            for (CsvDataSetVO dataSetVO : dataSetVos) {
+                String answer = dataSetVO.getAnswer();
+                boolean answered = answer != null && !answer.trim().isEmpty();
+
+                // Question row.
+                DatasetQuestionDO datasetQuestionDO = new DatasetQuestionDO();
+                datasetQuestionDO.setDatasetId(datasetId);
+                datasetQuestionDO.setDatasetFilesId(fileId);
+                datasetQuestionDO.setSystem(dataSetVO.getSystem());
+                datasetQuestionDO.setQuestion(dataSetVO.getQuestion());
+                datasetQuestionDO.setStatus(answered ? 2 : 0);
+                datasetQuestionMapper.insert(datasetQuestionDO);
+
+                if (!answered) {
+                    // Nothing to annotate: do not store an empty answer row.
+                    continue;
+                }
+
+                // Answer (annotation) row; the question id is read back after the insert above.
+                DatasetAnswerDO datasetAnswerDO = new DatasetAnswerDO();
+                datasetAnswerDO.setDatasetId(datasetId);
+                datasetAnswerDO.setDatasetFilesId(fileId);
+                datasetAnswerDO.setQuestionId(datasetQuestionDO.getId());
+                datasetAnswerDO.setAnswer(answer);
+                datasetAnswerMapper.insert(datasetAnswerDO);
+            }
+        }
 }
}
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java
index bed347aeb..fd61c6141 100644
--- a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/DataSetReadFileUtils.java
@@ -1,14 +1,27 @@
package cn.iocoder.yudao.module.llm.utils;
+import cn.iocoder.yudao.module.llm.utils.vo.CsvDataSetVO;
+import com.alibaba.fastjson.JSON;
+import com.opencsv.CSVParser;
+import com.opencsv.CSVParserBuilder;
+import com.opencsv.CSVReader;
+import com.opencsv.CSVReaderBuilder;
+import com.opencsv.exceptions.CsvValidationException;
+import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
-import java.sql.Connection;
+import java.util.ArrayList;
+import java.util.List;
+@Slf4j
@Component
public class DataSetReadFileUtils {
- public static HttpURLConnection readFile(String filePath) {
+ public static HttpURLConnection readFile (String filePath) {
try {
URL url = new URL(filePath);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
@@ -25,4 +38,66 @@ public class DataSetReadFileUtils {
}
return null;
}
+
+    /**
+     * Downloads a CSV file and converts its data rows into {@link CsvDataSetVO} objects.
+     * <p>
+     * The first row is treated as a header and skipped. Every following row is mapped by
+     * position: column 0 is the system value, column 1 the question, column 2 the answer.
+     * Rows with fewer than three columns are ignored; columns beyond the third are dropped.
+     * <p>
+     * The content is decoded as UTF-8, and the underlying stream is closed on every exit
+     * path, including failures while reading.
+     *
+     * @param csvUrl URL of the CSV file
+     * @return the parsed rows in file order; empty when there are no usable data rows
+     * @throws IOException if the URL is malformed, the content cannot be read, or a row
+     *                     fails CSV validation (the validation error is kept as the cause)
+     */
+    public static List<CsvDataSetVO> readParseCsv (String csvUrl) throws IOException {
+        // Build the URL object from the given address.
+        URL url = new URL(csvUrl);
+
+        // CSV parser that uses ',' as the separator.
+        CSVParser parser = new CSVParserBuilder().withSeparator(',').build();
+
+        // Collected rows, in file order.
+        List<CsvDataSetVO> dataSetVos = new ArrayList<>();
+
+        // try-with-resources closes both readers even when parsing fails part-way.
+        // The charset is passed by name ("UTF-8") so that no additional import is needed;
+        // without it the platform default charset would be used.
+        try (BufferedReader reader = new BufferedReader(
+                new InputStreamReader(url.openStream(), "UTF-8"));
+             CSVReader csvReader = new CSVReaderBuilder(reader).withCSVParser(parser).build()) {
+
+            try {
+                // Skip the header row.
+                csvReader.readNext();
+            } catch (CsvValidationException e) {
+                // Reported as IOException so that callers handling the declared exception
+                // also handle validation failures.
+                throw new IOException("跳过标题行时发生错误", e);
+            }
+
+            String[] line;
+            try {
+                // readNext() returns null once the end of the file is reached.
+                while ((line = csvReader.readNext()) != null) {
+                    // Only rows that carry all three columns can be mapped.
+                    if (line.length >= 3) {
+                        dataSetVos.add(new CsvDataSetVO(line[0], line[1], line[2]));
+                    }
+                }
+            } catch (CsvValidationException e) {
+                // Same translation as for the header row.
+                throw new IOException("读取CSV行时发生错误", e);
+            }
+
+            log.info("数据集读取完成");
+        }
+
+        // Return the parsed objects.
+        return dataSetVos;
+    }
}
diff --git a/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java
new file mode 100644
index 000000000..9dadac17e
--- /dev/null
+++ b/yudao-module-llm/yudao-module-llm-biz/src/main/java/cn/iocoder/yudao/module/llm/utils/vo/CsvDataSetVO.java
@@ -0,0 +1,19 @@
+package cn.iocoder.yudao.module.llm.utils.vo;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
+
+/**
+ * One data row of a CSV dataset file.
+ * <p>
+ * Lombok generates the accessors, equals/hashCode/toString and both constructors. The
+ * all-args constructor takes its parameters in field declaration order (system, question,
+ * answer), so the fields must not be reordered.
+ */
+@Data
+@ToString
+@NoArgsConstructor
+@AllArgsConstructor
+public class CsvDataSetVO {
+    /** System value of the row — presumably the system prompt; TODO confirm against the CSV template. */
+    private String system;
+    /** Question text of the row. */
+    private String question;
+    /** Answer text of the row. */
+    private String answer;
+}