自动评估增加停用词

This commit is contained in:
ire 2025-01-20 14:51:23 +08:00
parent 9f181c21f3
commit fb3c83bd36
3 changed files with 77 additions and 57 deletions

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.llm.controller.admin.dataset.vo.DatasetAnswerRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.dataset.vo.DatasetQuestionRespVO;
import cn.iocoder.yudao.module.llm.controller.admin.modelassesstaskauto.vo.ModelAssessTaskStoplistSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.modelassesstaskmanual.vo.ModelAssessTaskManualSaveReqVO;
import cn.iocoder.yudao.module.llm.controller.admin.modelassesstaskmanualbackup.vo.ModelAssessTaskManualBackupSaveReqVO;
import cn.iocoder.yudao.module.llm.dal.dataobject.basemodel.BaseModelDO;
@ -33,6 +34,8 @@ import cn.iocoder.yudao.module.llm.service.http.TrainHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.ModelCompletionsRespVO;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.google.common.base.Joiner;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.validator.internal.engine.constraintvalidation.PredefinedScopeConstraintValidatorManagerImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -74,21 +77,32 @@ public class AsyncModelAccessManualService {
@Async
public void auto(ModelAssessTaskAutoDO modelAssessTaskAuto) {
public void auto(ModelAssessTaskAutoDO modelAssessTaskAuto,List<ModelAssessTaskStoplistSaveReqVO> stoplists) {
try {
List<DatasetQuestionRespVO> datasetQuestionList = datasetQuestionService.getDatasetQuestionList(modelAssessTaskAuto.getDataset());
for (DatasetQuestionRespVO datasetQuestionRespVO : datasetQuestionList) {
String question = datasetQuestionRespVO.getQuestion();
DatasetAnswerRespVO datasetAnswerRespVO = datasetQuestionRespVO.getDatasetAnswerRespVO().get(0);
String datasetPrompt = datasetAnswerRespVO.getAnswer();
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
if(stoplists != null && stoplists.size() > 0){
String stoplistsStr = Joiner.on(",").join(stoplists);
ModelCompletionsReqVO.ModelCompletionsMessage system = new ModelCompletionsReqVO.ModelCompletionsMessage();
system.setRole("system");
system.setContent("不要出现以下停用词:" + stoplistsStr);
messages.add(system);
}
ModelCompletionsReqVO.ModelCompletionsMessage message = new ModelCompletionsReqVO.ModelCompletionsMessage();
if (question != null){
message.setContent(question);
}
message.setRole("user");
List<ModelCompletionsReqVO.ModelCompletionsMessage> messages = new ArrayList<>();
messages.add(message);
ModelCompletionsReqVO modelCompletionsReqVO = new ModelCompletionsReqVO();
modelCompletionsReqVO.setMessages(messages);
ModelCompletionsRespVO modelCompletionsRespVO = modelService.modelCompletions(modelCompletionsReqVO);
String prompt = modelCompletionsRespVO.getAnswer();

View File

@ -24,7 +24,7 @@ public class BaseModelTaskService {
BaseModelService baseModelService;
@Scheduled(cron ="0 0 0/1 * * ?")
@Scheduled(cron ="0 0/1 * * * ?")
public void updateBaseModel() {
Log.info("定时任务启动");
String resStr = trainHttpService.modelsList();

View File

@ -76,6 +76,10 @@ public class ModelAssessTaskAutoServiceImpl implements ModelAssessTaskAutoServic
private ModelAssessTaskAutoInfoService modelAssessTaskAutoInfoService;
@Resource
private ModelAssessTaskAutoInfoBackupService modelAssessTaskAutoInfoBackupService;
@Resource
private ModelAssessStoplistMapper modelAssessStoplistMapper;
@Override
@ -91,16 +95,16 @@ public class ModelAssessTaskAutoServiceImpl implements ModelAssessTaskAutoServic
throw exception(new ErrorCode(11000,"任务名称重复"));
}
modelAssessTaskAutoMapper.insert(modelAssessTaskAuto);
// List<ModelAssessTaskStoplistSaveReqVO> stoplists = createReqVO.getStoplists();
// if (!CollectionUtils.isEmpty(stoplists)){
// stoplists.stream().forEach(stoplist -> {
// stoplist.setTaskId(modelAssessTaskAuto.getId());
// ModelAssessTaskStoplistDO modelAssessTaskStoplistDO = BeanUtils.toBean(stoplist,ModelAssessTaskStoplistDO.class);
// modelingTaskStoplistMapper.insert(modelAssessTaskStoplistDO);
// });
// }
List<ModelAssessTaskStoplistSaveReqVO> stoplists = createReqVO.getStoplists();
if (!CollectionUtils.isEmpty(stoplists)){
stoplists.stream().forEach(stoplist -> {
stoplist.setTaskId(modelAssessTaskAuto.getId());
ModelAssessTaskStoplistDO modelAssessTaskStoplistDO = BeanUtils.toBean(stoplist,ModelAssessTaskStoplistDO.class);
modelingTaskStoplistMapper.insert(modelAssessTaskStoplistDO);
});
}
asyncModelAccessManualService.auto(modelAssessTaskAuto);
asyncModelAccessManualService.auto(modelAssessTaskAuto,stoplists);
// 返回
return modelAssessTaskAuto.getId();
}
@ -113,22 +117,22 @@ public class ModelAssessTaskAutoServiceImpl implements ModelAssessTaskAutoServic
validateModelAssessTaskAutoExists(updateReqVO.getId());
// 更新
ModelAssessTaskAutoDO updateObj = BeanUtils.toBean(updateReqVO, ModelAssessTaskAutoDO.class);
// List<ModelAssessTaskStoplistSaveReqVO> stoplists = updateReqVO.getStoplists();
// if (!CollectionUtils.isEmpty(stoplists)){
// List<Long> stopIds = stoplists.stream().filter(stoplist -> stoplist.getId() != null)
// .map(stoplist -> stoplist.getId()).collect(Collectors.toList());
// LambdaQueryWrapper<ModelAssessTaskStoplistDO> wrapper = new LambdaQueryWrapper<ModelAssessTaskStoplistDO>();
// wrapper.eq(ModelAssessTaskStoplistDO::getTaskId,updateReqVO.getId());
// if (!CollectionUtils.isEmpty(stopIds)){
// wrapper.notIn(ModelAssessTaskStoplistDO::getId,stopIds);
// }
// modelingTaskStoplistMapper.delete(wrapper);
// stoplists.stream().forEach(stoplist -> {
//// stoplist.setTaskId(updateObj.getId());
// ModelAssessTaskStoplistDO modelAssessTaskStoplistDO = BeanUtils.toBean(stoplist,ModelAssessTaskStoplistDO.class);
// modelingTaskStoplistMapper.insertOrUpdate(modelAssessTaskStoplistDO);
// });
// }
List<ModelAssessTaskStoplistSaveReqVO> stoplists = updateReqVO.getStoplists();
if (!CollectionUtils.isEmpty(stoplists)){
List<Long> stopIds = stoplists.stream().filter(stoplist -> stoplist.getId() != null)
.map(stoplist -> stoplist.getId()).collect(Collectors.toList());
LambdaQueryWrapper<ModelAssessTaskStoplistDO> wrapper = new LambdaQueryWrapper<ModelAssessTaskStoplistDO>();
wrapper.eq(ModelAssessTaskStoplistDO::getTaskId,updateReqVO.getId());
if (!CollectionUtils.isEmpty(stopIds)){
wrapper.notIn(ModelAssessTaskStoplistDO::getId,stopIds);
}
modelingTaskStoplistMapper.delete(wrapper);
stoplists.stream().forEach(stoplist -> {
// stoplist.setTaskId(updateObj.getId());
ModelAssessTaskStoplistDO modelAssessTaskStoplistDO = BeanUtils.toBean(stoplist,ModelAssessTaskStoplistDO.class);
modelingTaskStoplistMapper.insertOrUpdate(modelAssessTaskStoplistDO);
});
}
modelAssessTaskAutoMapper.updateById(updateObj);
}
@ -156,28 +160,28 @@ public class ModelAssessTaskAutoServiceImpl implements ModelAssessTaskAutoServic
String dimension = modelAssessTaskAutoDO.getDimension();
List<String> list = Arrays.asList(dimension.split(","));
result.setDimension(list);
// // 标注查询
// List<ModelAssessTaskStoplistDO> modelAssessTaskStoplistDOS = modelingTaskStoplistMapper.selectList(new LambdaQueryWrapper<ModelAssessTaskStoplistDO>()
// .eq(ModelAssessTaskStoplistDO::getTaskId, id)
// .eq(ModelAssessTaskStoplistDO::getDeleted, false));
// List<ModelAssessTaskStoplistRespVO> stoplistRespVOS = BeanUtils.toBean(modelAssessTaskStoplistDOS, ModelAssessTaskStoplistRespVO.class);
// // 提取id
// List<Long> stoplistIds = modelAssessTaskStoplistDOS.stream().map(ModelAssessTaskStoplistDO::getStoplistId).collect(Collectors.toList());
// if (!CollectionUtils.isEmpty(stoplistIds)){
// // 查询停用词表 将词表word返回
// LambdaQueryWrapper<ModelAssessStoplistDO> wrapper = new LambdaQueryWrapper<>();
// wrapper.in(ModelAssessStoplistDO::getId,stoplistIds);
// List<ModelAssessStoplistDO> modelAssessStoplistDOs = modelAssessStoplistMapper.selectList(wrapper);
// Map<Long, ModelAssessStoplistDO> longModelServiceDOMap = cn.iocoder.yudao.framework.common.util.collection.
// CollectionUtils.convertMap(modelAssessStoplistDOs, ModelAssessStoplistDO::getId);
// stoplistRespVOS.stream().forEach(stoplistRespVO -> {
// ModelAssessStoplistDO modelAssessStoplistDO = longModelServiceDOMap.get(stoplistRespVO.getStoplistId());
// if(modelAssessStoplistDO != null){
// stoplistRespVO.setStoplistName(modelAssessStoplistDO.getWord());
// }
// });
// }
// result.setStoplists(stoplistRespVOS);
// 标注查询
List<ModelAssessTaskStoplistDO> modelAssessTaskStoplistDOS = modelingTaskStoplistMapper.selectList(new LambdaQueryWrapper<ModelAssessTaskStoplistDO>()
.eq(ModelAssessTaskStoplistDO::getTaskId, id)
.eq(ModelAssessTaskStoplistDO::getDeleted, false));
List<ModelAssessTaskStoplistRespVO> stoplistRespVOS = BeanUtils.toBean(modelAssessTaskStoplistDOS, ModelAssessTaskStoplistRespVO.class);
// 提取id
List<Long> stoplistIds = modelAssessTaskStoplistDOS.stream().map(ModelAssessTaskStoplistDO::getStoplistId).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(stoplistIds)){
// 查询停用词表 将词表word返回
LambdaQueryWrapper<ModelAssessStoplistDO> wrapper = new LambdaQueryWrapper<>();
wrapper.in(ModelAssessStoplistDO::getId,stoplistIds);
List<ModelAssessStoplistDO> modelAssessStoplistDOs = modelAssessStoplistMapper.selectList(wrapper);
Map<Long, ModelAssessStoplistDO> longModelServiceDOMap = cn.iocoder.yudao.framework.common.util.collection.
CollectionUtils.convertMap(modelAssessStoplistDOs, ModelAssessStoplistDO::getId);
stoplistRespVOS.stream().forEach(stoplistRespVO -> {
ModelAssessStoplistDO modelAssessStoplistDO = longModelServiceDOMap.get(stoplistRespVO.getStoplistId());
if(modelAssessStoplistDO != null){
stoplistRespVO.setStoplistName(modelAssessStoplistDO.getWord());
}
});
}
result.setStoplists(stoplistRespVOS);
if(modelAssessTaskAutoDO.getModelType() == 0){
ModelServiceDO modelServiceDO = modelServiceMapper.selectById(modelAssessTaskAutoDO.getModelService());
@ -258,12 +262,14 @@ public class ModelAssessTaskAutoServiceImpl implements ModelAssessTaskAutoServic
modelAssessTaskAutoBackupMapper.insert(bean);
modelAssessTaskAutoDO.setBackupId(bean.getBackupId());
modelAssessTaskAutoMapper.updateById(modelAssessTaskAutoDO);
// List<ModelAssessTaskStoplistDO> modelAssessTaskStoplistDOS = modelingTaskStoplistMapper.selectList(new LambdaQueryWrapper<ModelAssessTaskStoplistDO>().eq(ModelAssessTaskStoplistDO::getTaskId, id));
// if (CollectionUtils.isEmpty(modelAssessTaskStoplistDOS)){
// throw exception(MODEL_ASSESS_TASK_STOPLIST_NOT_EXISTS);
// }
// List<ModelAssessTaskStoplistBackupDO> bean1 = BeanUtils.toBean(modelAssessTaskStoplistDOS, ModelAssessTaskStoplistBackupDO.class);
// modelAssessTaskStoplistBackupMapper.insertBatch(bean1);
List<ModelAssessTaskStoplistDO> modelAssessTaskStoplistDOS = modelingTaskStoplistMapper.selectList(new LambdaQueryWrapper<ModelAssessTaskStoplistDO>().eq(ModelAssessTaskStoplistDO::getTaskId, id));
if (CollectionUtils.isEmpty(modelAssessTaskStoplistDOS)){
throw exception(MODEL_ASSESS_TASK_STOPLIST_NOT_EXISTS);
}
List<ModelAssessTaskStoplistBackupDO> bean1 = BeanUtils.toBean(modelAssessTaskStoplistDOS, ModelAssessTaskStoplistBackupDO.class);
modelAssessTaskStoplistBackupMapper.insertBatch(bean1);
ModelAssessTaskAutoInfoPageReqVO modelAssessTaskAutoInfoPageReqVO = new ModelAssessTaskAutoInfoPageReqVO();
modelAssessTaskAutoInfoPageReqVO.setTaskId(id);
List<ModelAssessTaskAutoInfoDO> listByTaskId = modelAssessTaskAutoInfoService.getListByTaskId(modelAssessTaskAutoInfoPageReqVO);