feat(llm): 添加知识库向量嵌入功能

- 在应用配置文件中添加知识库向量嵌入的 URL
- 新增 AsyncKnowledgeBase 类中的 knowledgeEmbed 方法,用于异步处理知识库向量嵌入
- 在 KnowledgeBaseServiceImpl 中集成知识库向量嵌入的逻辑
- 新增 KnowledgeRagEmbedReqVO 类作为知识库向量嵌入的请求参数
- 在 LLMBackendProperties 中添加 embed 属性用于配置向量嵌入的 URL
- 在 RagHttpService 中实现 knowledgeEmbed 方法,用于调用向量嵌入的 API
This commit is contained in:
Liuyang 2025-02-08 11:51:43 +08:00
parent dd4756de63
commit 58128ba243
8 changed files with 235 additions and 64 deletions

View File

@ -102,4 +102,9 @@ public class LLMBackendProperties {
private String autoEvaluation;
private String textToImage;
/**
* 知识库向量嵌入
*/
private String embed;
}

View File

@ -1,9 +1,12 @@
package cn.iocoder.yudao.module.llm.service.async;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.llm.dal.dataobject.knowledgedocuments.KnowledgeDocumentsDO;
import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties;
import cn.iocoder.yudao.module.llm.service.http.RagHttpService;
import cn.iocoder.yudao.module.llm.service.http.vo.KnowledgeRagEmbedReqVO;
import cn.iocoder.yudao.module.llm.service.http.vo.RegUploadReqVO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -11,7 +14,15 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@Service
public class AsyncKnowledgeBase {
@ -25,12 +36,12 @@ public class AsyncKnowledgeBase {
// 向向量知识库创建文件
@Async
public void createKnowledgeBase(List<KnowledgeDocumentsDO> knowledgeList,List<Long> ids) {
if (!CollectionUtils.isAnyEmpty(ids)){
public void createKnowledgeBase (List<KnowledgeDocumentsDO> knowledgeList, List<Long> ids) {
if (!CollectionUtils.isAnyEmpty(ids)) {
String mes = ragHttpService.ragDocumentsDel(llmBackendProperties.getRagDocumentsDel(), ids);
log.info("delete knowledge base info {}",mes);
log.info("delete knowledge base info {}", mes);
}
if (!CollectionUtils.isAnyEmpty(knowledgeList)){
if (!CollectionUtils.isAnyEmpty(knowledgeList)) {
knowledgeList.stream().forEach(knowledge -> {
try {
RegUploadReqVO regUploadReqVO = new RegUploadReqVO()
@ -39,11 +50,86 @@ public class AsyncKnowledgeBase {
.setFileName(knowledge.getDocumentName())
.setFileUrl(knowledge.getFileUrl());
ragHttpService.embedUploadFile(regUploadReqVO);
}catch (Exception e){
log.error("the creation of the knowledge base error {}",e.getMessage());
} catch (Exception e) {
log.error("the creation of the knowledge base error {}", e.getMessage());
}
});
}
}
/**
 * Submits each knowledge document to the vector-embedding service.
 *
 * <p>The method is annotated with {@code @Async}. For every document the file
 * content is downloaded from its file URL and posted through
 * {@code ragHttpService.knowledgeEmbed}. A failure on one document is logged
 * and does not prevent the remaining documents from being processed.
 *
 * @param knowledgeList documents to embed; a null or empty list is a no-op
 */
@Async
public void knowledgeEmbed (List<KnowledgeDocumentsDO> knowledgeList) {
    if (CollectionUtils.isAnyEmpty(knowledgeList)) {
        return;
    }
    // The entity id is the same for every document, so resolve it once.
    Long entityId = getEntityId();
    if (entityId == null) {
        // NOTE(review): this runs on an async thread; confirm the login context is propagated here.
        // Behavior is kept as before: String.valueOf(null) sends the literal "null" as entity_id.
        log.warn("knowledge embed: login user id is null, entity_id will be sent as \"null\"");
    }
    String entityIdText = String.valueOf(entityId);
    for (KnowledgeDocumentsDO knowledge : knowledgeList) {
        try {
            // getFileByte either returns the content or throws, so no null check is needed.
            byte[] content = getFileByte(knowledge.getFileUrl());
            KnowledgeRagEmbedReqVO ragEmbedReqVo = new KnowledgeRagEmbedReqVO()
                    .setFileId(String.valueOf(knowledge.getId()))
                    .setFileName(knowledge.getDocumentName())
                    .setFileInputStream(new ByteArrayInputStream(content))
                    .setEntityId(entityIdText);
            ragHttpService.knowledgeEmbed(ragEmbedReqVo);
        } catch (Exception e) {
            // Best effort per document: log with enough context to retry, then continue.
            log.error("knowledge embed failed, documentId: {}, documentName: {}",
                    knowledge.getId(), knowledge.getDocumentName(), e);
        }
    }
}
/**
 * Looks up the ID of the currently logged-in user.
 *
 * @return the value returned by {@code SecurityFrameworkUtils.getLoginUserId()}
 */
public Long getEntityId () {
    Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
    return loginUserId;
}
/**
 * Downloads the content at the given URL into a byte array.
 *
 * <p>Connect and read timeouts are applied so that an unresponsive server
 * cannot block the calling thread indefinitely.
 *
 * @param fileUrl address of the file to download
 * @return the complete file content; never {@code null}
 * @throws RuntimeException the service exception built from error code
 *         {@code 10001_001} when the URL is malformed or the download fails
 */
public static byte[] getFileByte (String fileUrl) {
    log.info("knowledge url: {}", fileUrl);
    // Timeouts in milliseconds; without them URL#openStream may wait forever.
    final int connectTimeoutMillis = 10_000;
    final int readTimeoutMillis = 60_000;
    try {
        java.net.URLConnection connection = new URL(fileUrl).openConnection();
        connection.setConnectTimeout(connectTimeoutMillis);
        connection.setReadTimeout(readTimeoutMillis);
        try (InputStream inputStream = connection.getInputStream();
             ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[8192];
            int bytesRead;
            // Copy the remote stream into memory until end of stream.
            while ((bytesRead = inputStream.read(buffer)) != -1) {
                outputStream.write(buffer, 0, bytesRead);
            }
            return outputStream.toByteArray();
        }
    } catch (IOException e) {
        // Pass the exception itself so the stack trace and cause are not lost.
        log.error("Failed to read remote file: {}", fileUrl, e);
        throw exception(new ErrorCode(10001_001, "文件读取错误"));
    }
}
}

View File

@ -6,7 +6,6 @@ import cn.iocoder.yudao.module.llm.framework.backend.config.LLMBackendProperties
import cn.iocoder.yudao.module.llm.service.http.vo.*;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.toolkit.BeanUtils;
import com.google.gson.JsonArray;
import kong.unirest.HttpResponse;
@ -40,7 +39,7 @@ public class RagHttpService {
/**
* RAG健康检查API
*/
public RagHealthRespVO ragHealth(Map<String, String> headers){
public RagHealthRespVO ragHealth (Map<String, String> headers) {
String ragHealth = llmBackendProperties.getRagHealth();
String res = HttpUtils.get(ragHealth, headers);
log.info(" ragHealth:{}", res);
@ -52,10 +51,10 @@ public class RagHttpService {
/**
* 上传并向量化
*/
public RagEmbedRespVO ragEmbed(Map<String, String> headers, RagEmbedReqVo ragEmbedReqVo){
public RagEmbedRespVO ragEmbed (Map<String, String> headers, RagEmbedReqVo ragEmbedReqVo) {
String ragEmbed = llmBackendProperties.getRagEmbed();
Map<String, Object> map = BeanUtils.beanToMap(ragEmbedReqVo);
String res = HttpUtils.postForm(ragEmbed, headers,map);
String res = HttpUtils.postForm(ragEmbed, headers, map);
log.info(" ragEmbedRespVO:{}", res);
RagEmbedRespVO ragEmbedRespVO = JSON.parseObject(res, RagEmbedRespVO.class);
log.info(" ragEmbedRespVO:{}", ragEmbedRespVO);
@ -64,12 +63,13 @@ public class RagHttpService {
/**
* 向量知识库文档上传
*
* @param ragUploadReqVO
* @return
* @throws UnirestException
* @throws IOException
*/
public RagEmbedRespVO embedUploadFile(RegUploadReqVO ragUploadReqVO) throws UnirestException, IOException {
public RagEmbedRespVO embedUploadFile (RegUploadReqVO ragUploadReqVO) throws UnirestException, IOException {
CloseableHttpClient httpClient = HttpClients.createDefault();
RagEmbedRespVO ragEmbedRespVO = new RagEmbedRespVO();
HttpGet request = new HttpGet(ragUploadReqVO.getFileUrl());
@ -108,7 +108,7 @@ public class RagHttpService {
return ragEmbedRespVO;
}
private static String detectCharset(InputStream inputStream) throws IOException {
private static String detectCharset (InputStream inputStream) throws IOException {
byte[] buffer = new byte[4096];
int nread;
UniversalDetector detector = new UniversalDetector(null);
@ -122,24 +122,25 @@ public class RagHttpService {
return charset;
}
public String ragDocumentsDel(String url, List<Long> documentIds){
// 创建JSON数组
JsonArray jsonArray = new JsonArray();
for (Long id : documentIds) {
jsonArray.add(String.valueOf(id));
}
// 发送DELETE请求
HttpResponse<String> response = Unirest.delete(url)
.header("Content-Type", "application/json")
.body(jsonArray.toString())
.asString();
// 返回响应体
return response.getBody();
public String ragDocumentsDel (String url, List<Long> documentIds) {
// 创建JSON数组
JsonArray jsonArray = new JsonArray();
for (Long id : documentIds) {
jsonArray.add(String.valueOf(id));
}
// 发送DELETE请求
HttpResponse<String> response = Unirest.delete(url)
.header("Content-Type", "application/json")
.body(jsonArray.toString())
.asString();
// 返回响应体
return response.getBody();
}
/**
* 获取所有向量id
*/
public List<String> ragIds(Map<String, String> headers){
public List<String> ragIds (Map<String, String> headers) {
String ragIds = llmBackendProperties.getRagIds();
String res = HttpUtils.get(ragIds, headers);
log.info(" ragIds:{}", res);
@ -151,7 +152,7 @@ public class RagHttpService {
/**
* 根据id获取文档
*/
public RagDocumentsRespVO ragDocuments(Map<String, String> headers, RagDocumentsReqVO ragDocumentsReqVO){
public RagDocumentsRespVO ragDocuments (Map<String, String> headers, RagDocumentsReqVO ragDocumentsReqVO) {
String ragDocuments = llmBackendProperties.getRagDocuments();
String res = HttpUtils.get(ragDocuments, headers);
log.info(" ragDocuments:{}", res);
@ -165,7 +166,7 @@ public class RagHttpService {
/**
* 根据id删除文档
*/
public String ragDocumentsDel(Map<String,String> headers){
public String ragDocumentsDel (Map<String, String> headers) {
String ragDocumentsDel = llmBackendProperties.getRagDocumentsDel();
String res = HttpUtils.del(ragDocumentsDel, headers);
log.info(" ragDocumentsDel:{}", res);
@ -175,7 +176,7 @@ public class RagHttpService {
/**
* 根据file_id检索向量
*/
public RagQueryRespVO ragQuery(Map<String, String> headers, RagQueryReqVo ragQueryReqVo){
public RagQueryRespVO ragQuery (Map<String, String> headers, RagQueryReqVo ragQueryReqVo) {
String ragQuery = llmBackendProperties.getRagQuery();
String res = HttpUtils.post(ragQuery, headers, JSON.toJSONString(ragQueryReqVo));
log.info(" ragQuery:{}", res);
@ -195,10 +196,31 @@ public class RagHttpService {
/**
* 支持多个文件id查询向量
*/
public String ragQueryMultiple(Map<String, String> headers, RagQueryMultipleReqVo ragQueryReqVo){
public String ragQueryMultiple (Map<String, String> headers, RagQueryMultipleReqVo ragQueryReqVo) {
String ragQueryMultiple = llmBackendProperties.getRagQueryMultiple();
String res = HttpUtils.getBody(ragQueryMultiple, headers, JSON.toJSONString(ragQueryReqVo));
return res;
}
/**
 * Posts a knowledge document to the vector-embedding endpoint as a multipart form.
 *
 * <p>The form carries {@code file_id}, {@code entity_id} and the file content
 * under {@code file}. A non-2xx response is logged as an error; the method does
 * not throw for it, so the signature and caller contract are unchanged.
 *
 * @param reqVO request parameters (file id, entity id, file stream and file name)
 */
public void knowledgeEmbed (KnowledgeRagEmbedReqVO reqVO) {
    // Endpoint configured through the "embed" backend property.
    String ragEmbed = llmBackendProperties.getEmbed();
    log.info("url : {}", ragEmbed);
    // Build and send the multipart request.
    HttpResponse<String> response = Unirest.post(ragEmbed)
            .field("file_id", reqVO.getFileId())
            .field("entity_id", reqVO.getEntityId())
            .field("file", reqVO.getFileInputStream(), reqVO.getFileName())
            .asString();
    int status = response.getStatus();
    if (status < 200 || status >= 300) {
        // Surface failures explicitly instead of logging them like a success.
        log.error("knowledge embed request failed, fileId: {}, status: {}, body: {}",
                reqVO.getFileId(), status, response.getBody());
        return;
    }
    log.info(" ========= Response Body Result: {}", response.getBody());
}
}

View File

@ -0,0 +1,39 @@
package cn.iocoder.yudao.module.llm.service.http.vo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.io.ByteArrayInputStream;
/**
 * Request parameters for uploading a knowledge-base file for vectorization.
 *
 * <p>NOTE(review): callers chain the setters ({@code new KnowledgeRagEmbedReqVO().setFileId(..).setFileName(..)});
 * presumably fluent setters are enabled by a project-level Lombok setting — confirm.
 * NOTE(review): the class declares no superclass, so {@code callSuper = true} on
 * {@code @ToString} looks unnecessary — confirm.
 *
 * @Author Liu Yang
 * @Date 2025/2/8 11:21
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
public class KnowledgeRagEmbedReqVO {
/**
 * File ID
 */
private String fileId;
/**
 * File name
 */
private String fileName;
/**
 * File content stream
 */
private ByteArrayInputStream fileInputStream;
/**
 * User ID
 */
private String entityId;
}

View File

@ -25,6 +25,7 @@ import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -61,46 +62,59 @@ public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
@Override
public void updateKnowledgeBase (KnowledgeBaseSaveReqVO updateReqVO) {
// 校验存在
// 1. 校验知识库是否存在
validateKnowledgeBaseExists(updateReqVO.getId());
// 2. 校验知识库名称是否重复
validateKnowledgeBaseNameExists(updateReqVO);
// 更新
// 3. 更新知识库主表
KnowledgeBaseDO updateObj = BeanUtils.toBean(updateReqVO, KnowledgeBaseDO.class);
knowledgeBaseMapper.updateById(updateObj);
List<KnowledgeDocumentsDO> knowledgeDocumentsList = new ArrayList<>();
// 附表增加数据
// 4. 处理附表知识文档数据
if (!CollectionUtils.isAnyEmpty(updateReqVO.getKnowledgeDocuments())) {
List<Long> ids = updateReqVO.getKnowledgeDocuments().stream().filter(
knowledgeDocuments -> knowledgeDocuments.getId() != null
).map(KnowledgeDocumentsSaveReqVO::getId).collect(Collectors.toList());
if (!CollectionUtils.isAnyEmpty(ids)) {
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId())
.notIn(KnowledgeDocumentsDO::getId, ids));
} else {
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId()));
// 4.1 获取需要保留的文档 ID
List<Long> retainedIds = updateReqVO.getKnowledgeDocuments().stream()
.map(KnowledgeDocumentsSaveReqVO::getId)
.filter(Objects::nonNull)
.collect(Collectors.toList());
// 4.2 删除不需要保留的文档
LambdaQueryWrapperX<KnowledgeDocumentsDO> deleteWrapper = new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId());
if (!CollectionUtils.isAnyEmpty(retainedIds)) {
deleteWrapper.notIn(KnowledgeDocumentsDO::getId, retainedIds);
}
updateReqVO.getKnowledgeDocuments().forEach(
knowledgeDocuments -> {
KnowledgeDocumentsDO knowledgeDocumentsDO = BeanUtils.toBean(knowledgeDocuments, KnowledgeDocumentsDO.class);
knowledgeDocumentsDO.setKnowledgeBaseId(updateReqVO.getId());
if (knowledgeDocuments.getId() == null) {
knowledgeDocumentsList.add(knowledgeDocumentsDO);
}
knowledgeDocumentsMapper.insertOrUpdate(knowledgeDocumentsDO);
}
);
knowledgeDocumentsMapper.delete(deleteWrapper);
// 4.3 更新或插入文档数据
List<KnowledgeDocumentsDO> newDocuments = new ArrayList<>();
updateReqVO.getKnowledgeDocuments().forEach(doc -> {
KnowledgeDocumentsDO docDO = BeanUtils.toBean(doc, KnowledgeDocumentsDO.class);
docDO.setKnowledgeBaseId(updateReqVO.getId());
if (doc.getId() == null) {
newDocuments.add(docDO); // 收集新增文档
}
knowledgeDocumentsMapper.insertOrUpdate(docDO); // 更新或插入文档
});
// 4.4 异步处理新增文档和删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
asyncKnowledgeBase.createKnowledgeBase(knowledgeDocumentsList, deleteIds);
asyncKnowledgeBase.createKnowledgeBase(newDocuments, deleteIds);
// 4.5 异步处理知识库外挂
asyncKnowledgeBase.knowledgeEmbed(newDocuments);
} else {
// 5. 如果传入的文档列表为空则删除所有关联文档
knowledgeDocumentsMapper.delete(new LambdaQueryWrapperX<KnowledgeDocumentsDO>()
.eq(KnowledgeDocumentsDO::getKnowledgeBaseId, updateReqVO.getId()));
// 5.1 异步处理删除的文档
List<Long> deleteIds = knowledgeDocumentsMapper.selectDeleteIds(updateReqVO.getId());
if (!CollectionUtils.isAnyEmpty(deleteIds)) {
asyncKnowledgeBase.createKnowledgeBase(null, deleteIds);
}
}
}
@Override

View File

@ -282,7 +282,8 @@ llm:
# 文生图
text_to_image: http://36.103.199.104:5123/generate-image
# 知识库向量嵌入
embed: http://36.103.199.104:8123/embed
--- #################### iot相关配置 TODO 芋艿:再瞅瞅 ####################
iot:
emq:

View File

@ -170,12 +170,12 @@ debug: false
--- #################### 微信公众号、小程序相关配置 ####################
wx:
mp: # 公众号配置(必填),参见 https://github.com/Wechat-Group/WxJava/blob/develop/spring-boot-starters/wx-java-mp-spring-boot-starter/README.md 文档
# app-id: wx041349c6f39b268b # 测试号(牛希尧提供的)
# secret: 5abee519483bc9f8cb37ce280e814bd0
# app-id: wx041349c6f39b268b # 测试号(牛希尧提供的)
# secret: 5abee519483bc9f8cb37ce280e814bd0
app-id: wx5b23ba7a5589ecbb # 测试号(自己的)
secret: 2a7b3b20c537e52e74afd395eb85f61f
# app-id: wxa69ab825b163be19 # 测试号Kongdy 提供的)
# secret: bd4f9fab889591b62aeac0d7b8d8b4a0
# app-id: wxa69ab825b163be19 # 测试号(Kongdy 提供的)
# secret: bd4f9fab889591b62aeac0d7b8d8b4a0
# 存储配置,解决 AccessToken 的跨节点的共享
config-storage:
type: RedisTemplate # 采用 RedisTemplate 操作 Redis会自动从 Spring 中获取
@ -184,10 +184,10 @@ wx:
miniapp: # 小程序配置(必填),参见 https://github.com/Wechat-Group/WxJava/blob/develop/spring-boot-starters/wx-java-miniapp-spring-boot-starter/README.md 文档
# appid: wx62056c0d5e8db250 # 测试号(牛希尧提供的)
# secret: 333ae72f41552af1e998fe1f54e1584a
# appid: wx63c280fe3248a3e7 # wenhualian的接口测试号
# secret: 6f270509224a7ae1296bbf1c8cb97aed
# appid: wxc4598c446f8a9cb3 # 测试号Kongdy 提供的)
# secret: 4a1a04e07f6a4a0751b39c3064a92c8b
# appid: wx63c280fe3248a3e7 # wenhualian的接口测试号
# secret: 6f270509224a7ae1296bbf1c8cb97aed
# appid: wxc4598c446f8a9cb3 # 测试号(Kongdy 提供的)
# secret: 4a1a04e07f6a4a0751b39c3064a92c8b
appid: wx66186af0759f47c9 # 测试号(puhui 提供的)
secret: 3218bcbd112cbc614c7264ceb20144ac
config-storage:
@ -325,6 +325,8 @@ llm:
# 文生图
text_to_image: http://36.103.199.104:5123/generate-image
# 知识库向量嵌入
embed: http://36.103.199.104:8123/embed
--- #################### iot相关配置 TODO 芋艿:再瞅瞅 ####################
iot:
emq:

View File

@ -325,6 +325,8 @@ llm:
# 文生图
text_to_image: http://36.133.1.230:5123/generate-image
# 知识库向量嵌入
embed: http://36.103.199.104:8123/embed
--- #################### iot相关配置 TODO 芋艿:再瞅瞅 ####################
iot:
emq: