修改模型评估人工标注

This commit is contained in:
limin 2025-01-03 12:42:56 +08:00
parent eea8bd3fc6
commit a41dbe5559
6 changed files with 105 additions and 54 deletions

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.llm.controller.admin.modelassesstaskmanual.manua
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import java.util.HashMap;
import java.util.List;
@Schema(description = "管理后台 - 模型评估人工标注 模型回答新增/修改 Request VO")
@ -26,8 +27,7 @@ public class ManualModelAnswerSaveReqVO {
@Schema(description = "模型回答内容")
private String modelAnswer;
@Schema(description = "模型评估人工标注列表")
private List<ManualModelAnnoSaveReqVO> annoReqRespVo;
private List<HashMap<String, Object>> reqRespVos;
}

View File

@ -26,19 +26,6 @@ public class ManualAssessmentPageRespVO {
@ExcelProperty("评估任务ID")
private Long id;
@Schema(description = "任务名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "张三")
@ExcelProperty("任务名称")
private String taskName;
@Schema(description = "任务描述")
@ExcelProperty("任务描述")
private String taskInfro;
@Schema(description = "CPU类型使用字典llm_cpu_type", example = "2")
@ExcelProperty(value = "CPU类型使用字典llm_cpu_type", converter = DictConvert.class)
@DictFormat("llm_cpu_type") // TODO 代码优化建议设置到对应的 DictTypeConstants 枚举类中
private Integer cpuType;
@Schema(description = "模型服务")
@ExcelProperty("模型服务")
private Long modelService;
@ -47,14 +34,10 @@ public class ManualAssessmentPageRespVO {
@ExcelProperty("数据集")
private Long dataset;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
@ExcelProperty("创建时间")
private LocalDateTime createTime;
@Schema(description = "人工评估列表")
private List<ModelAssessTaskDimensionRespVO> dimensions;
@Schema(description = "模型评估任务状态,使用字典llm_model_assess_task_status")
@Schema(description = "模型评估任务状态人工标注状态0未标注 2标注完成")
private Integer status;
@Schema(description = "任务进度")

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.llm.dal.dataobject.modelassesstaskmanual;
import cn.iocoder.yudao.module.llm.handler.ListHashMapTypeHandler;
import lombok.*;
import java.util.*;
import java.time.LocalDateTime;
@ -48,4 +49,7 @@ public class ManualModelAnswerDO extends BaseDO {
*/
private String modelAnswer;
@TableField(typeHandler = ListHashMapTypeHandler.class)
private List<HashMap<String, Object>> reqRespVos;
}

View File

@ -2,14 +2,18 @@ package cn.iocoder.yudao.module.llm.framework.backend.config;
import cn.hutool.json.JSONObject;
import cn.iocoder.yudao.module.llm.handler.JSONObjectTypeHandler;
import cn.iocoder.yudao.module.llm.handler.ListHashMapTypeHandler;
import com.baomidou.mybatisplus.autoconfigure.ConfigurationCustomizer;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.HashMap;
import java.util.List;
@Configuration
@MapperScan("cn.iocoder.yudao.module.llm.dal.mysql.dataprocesstask")
@MapperScan("cn.iocoder.yudao.module.llm.dal.mysql.*")
public class MyBatisConfig {
@Bean
@ -17,6 +21,11 @@ public class MyBatisConfig {
return new JSONObjectTypeHandler();
}
@Bean
public ListHashMapTypeHandler listHashMapTypeHandler() {
return new ListHashMapTypeHandler();
}
@Bean
public ConfigurationCustomizer configurationCustomizer() {
return configuration -> {
@ -24,6 +33,12 @@ public class MyBatisConfig {
if (!typeHandlerRegistry.hasTypeHandler(JSONObject.class)) {
typeHandlerRegistry.register(JSONObject.class, jsonObjectTypeHandler());
}
// 注册 List<HashMap<String, Object>> 类型的处理器
if (!typeHandlerRegistry.hasTypeHandler(List.class)){
typeHandlerRegistry.register(List.class, listHashMapTypeHandler());
}
};
}
}

View File

@ -0,0 +1,56 @@
package cn.iocoder.yudao.module.llm.handler;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
public class ListHashMapTypeHandler extends BaseTypeHandler<List<HashMap<String, Object>>> {
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
public void setNonNullParameter(PreparedStatement ps, int i, List<HashMap<String, Object>> parameter, JdbcType jdbcType) throws SQLException {
try {
String jsonString = objectMapper.writeValueAsString(parameter);
ps.setString(i, jsonString);
} catch (Exception e) {
throw new SQLException("Failed to convert List<HashMap<String, Object>> to JSON string", e);
}
}
@Override
public List<HashMap<String, Object>> getNullableResult(ResultSet rs, String columnName) throws SQLException {
String jsonStr = rs.getString(columnName);
return getListFromJson(jsonStr);
}
@Override
public List<HashMap<String, Object>> getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
String jsonStr = rs.getString(columnIndex);
return getListFromJson(jsonStr);
}
@Override
public List<HashMap<String, Object>> getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
String jsonStr = cs.getString(columnIndex);
return getListFromJson(jsonStr);
}
private List<HashMap<String, Object>> getListFromJson(String jsonStr) {
if (jsonStr == null) {
return null;
}
try {
return objectMapper.readValue(jsonStr, objectMapper.getTypeFactory().constructCollectionType(List.class, HashMap.class));
} catch (Exception e) {
return null;
}
}
}

View File

@ -83,26 +83,21 @@ public class ManualModelAnswerServiceImpl implements ManualModelAnswerService {
resp.setQuestion(datasetQuestionDO.getQuestion());
resp.setAnswers(BeanUtils.toBean(datasetAnswerDOS, DatasetAnswerRespVO.class));
resp.setModelAnswer(modelAnswerDO.getModelAnswer());
List<ManualModelAnnoDO> manualModelAnnoDOS = manualModelAnnoMapper.selectList(new LambdaQueryWrapper<>(ManualModelAnnoDO.class)
.eq(ManualModelAnnoDO::getModelAnswerId, modelAnswerDO.getId()));
List<ModelAssessTaskDimensionRespVO> dimensions = resp.getDimensions();
Map<Long, ManualModelAnnoDO> longModelServiceDOMap = cn.iocoder.yudao.framework.common.util.collection.
CollectionUtils.convertMap(manualModelAnnoDOS, ManualModelAnnoDO::getDimensionId);
if (!CollectionUtils.isAnyEmpty(dimensions)){
List<HashMap<String, Object>> annoList= new ArrayList<>();
dimensions.forEach(dimension -> {
HashMap<String, Object> map = new HashMap<>();
if (longModelServiceDOMap.containsKey(dimension.getDimensionId())){
map.put("label",dimension.getDimensionName());
map.put("score",longModelServiceDOMap.get(dimension.getDimensionId()).getScore());
annoList.add(map);
}else {
map.put("label",dimension.getDimensionName());
map.put("score",0);
annoList.add(map);
}
});
resp.setReqRespVos(annoList);
resp.setId(modelAnswerDO.getId());
resp.setStatus(modelAnswerDO.getStatus());
if (!CollectionUtils.isAnyEmpty(modelAnswerDO.getReqRespVos())){
resp.setReqRespVos(modelAnswerDO.getReqRespVos());
}else {
if (!CollectionUtils.isAnyEmpty(modelAssessTaskManual.getDimensions())){
List<HashMap<String, Object>> map = new ArrayList<>();
modelAssessTaskManual.getDimensions().forEach(dimension -> {
HashMap<String, Object> map1 = new HashMap<>();
map1.put("dimension", dimension.getDimensionName());
map1.put("score", 0);
map.add(map1);
});
resp.setReqRespVos(map);
}
}
res.add(resp);
});
@ -119,25 +114,23 @@ public class ManualModelAnswerServiceImpl implements ManualModelAnswerService {
@Override
public void annoManualModelAnswer(List<ManualModelAnswerSaveReqVO> reqRespVo) {
List<ManualModelAnswerDO> modelAnswerDOS = BeanUtils.toBean(reqRespVo, ManualModelAnswerDO.class);
if (reqRespVo.size() > 0){
Long modelAnswerId = modelAnswerDOS.get(0).getManalTaskId();
reqRespVo.forEach(modelAnswerDO -> {
if (!CollectionUtils.isAnyEmpty(modelAnswerDO.getAnnoReqRespVo())){
List<ManualModelAnnoDO> modelAnnoDOS = BeanUtils.toBean(modelAnswerDO.getAnnoReqRespVo(), ManualModelAnnoDO.class);
manualModelAnnoMapper.insertOrUpdate(modelAnnoDOS);
manualModelAnswerMapper.updateStatus(modelAnswerDO.getId(),2);
}
if (!CollectionUtils.isAnyEmpty(modelAnswerDOS)) {
modelAnswerDOS.forEach(modelAnswerDO -> {
modelAnswerDO.setStatus(2);
manualModelAnswerMapper.updateById(modelAnswerDO);
});
// 标注进度修改
LambdaQueryWrapper<ManualModelAnswerDO> wrapper = new LambdaQueryWrapper<ManualModelAnswerDO>()
.eq(ManualModelAnswerDO::getManalTaskId, modelAnswerId);
LambdaQueryWrapper<ManualModelAnswerDO> wrapper = new LambdaQueryWrapper<ManualModelAnswerDO>()
.eq(ManualModelAnswerDO::getManalTaskId, reqRespVo.get(0).getManalTaskId());
Long sumCount = manualModelAnswerMapper.selectCount(wrapper);
wrapper.eq(ManualModelAnswerDO::getStatus,2);
wrapper.eq(ManualModelAnswerDO::getStatus, 2);
Long annoCount = manualModelAnswerMapper.selectCount(wrapper);
double ratio = sumCount == 0 ? 0 : ((double) annoCount / sumCount) *100;
double ratio = sumCount == 0 ? 0 : ((double) annoCount / sumCount) * 100;
Integer formattedRatio = ratio == 0 ? 0 : (int) ratio;
Integer status = formattedRatio == 100 ? 4 : 2;
modelAssessTaskManualMapper.updateStatus(modelAnswerId,formattedRatio,status);
modelAssessTaskManualMapper.updateStatus(reqRespVo.get(0).getManalTaskId(), formattedRatio, status);
}
}