RAG原理
大模型没有本地私有知识,所以用户在向大模型提问的时候,大模型只能在它学习过的知识范围内进行回答,而RAG就是在用户提问的时候,将本地与问题相关的私有知识连同问题一块发送给大模型,进而大模型从用户提供的私有知识范围内进行更精确的回答。
核心技术栈
- SpringAI
- MybatisPlus
- Chroma
- Elasticsearch
- MySQL
核心步骤
文本分块向量化
将文本切分成多个文本块,作者使用markdown来存储文本内容,markdown格式的文本相对来说是比较容易切分的,将文本切分之后 请求向量化接口进行文本向量化,最后将向量的结果写入到原本的数据块中 存储到向量数据库
向量数据库
- Elasticsearch 混合检索使用,知识召回准确度比较高
- Chroma 本地测试 或者小数据集使用 也能混合检索 但是无法像es那样可以模糊混合检索
向量检索
将用户的问题进行向量化,然后调用向量数据库的检索
实现
文本分块存储到向量数据库
@Service("docMarkdownFileParseService") public class DocMarkdownFileParseServiceImpl implements DocFileParseService { @Override public List<Document> parse(MultipartFile file,Integer kdId) { // 初始化markdown配置 MarkdownDocumentReaderConfig config = MarkdownDocumentReaderConfig.builder() .withHorizontalRuleCreateDocument(true) .withIncludeCodeBlock(true) .withIncludeBlockquote(true) .withAdditionalMetadata("knowledgeDocId", kdId) .build(); MarkdownDocumentReader reader = new MarkdownDocumentReader(file.getResource(), config); // 文档切分读取 return reader.get(); } }分块的时候会涉及一些metadata,metadata用来存储数据块的元数据,也可以存储一些自定义字段,可以更好的为混合检索提供支持! 这里我存储了知识文本的ID
MarkdownDocumentReader
我在SpringAI的基础上扩展了MarkdownDocumentReader,主要是将markdown各级标题提取出来组合成titleExpander,最终形成 一级标题-二级标题-三级标题-当前标题 这样的格式,进而为后续的混合检索提供支持
SpringAI默认提供的类没有对表格解析做支持,所以我也支持了表格的解析,所有源码都粘贴到下面
package cn.dataling.rag.application.reader; import org.commonmark.ext.gfm.tables.*; import org.commonmark.ext.gfm.tables.TableBlock; import org.commonmark.ext.gfm.tables.TablesExtension; import org.commonmark.node.*; import org.commonmark.parser.Parser; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; import java.io.IOException; import java.io.InputStreamReader; import java.util.*; /** * Reads the given Markdown resource and groups headers, paragraphs, or text divided by * horizontal lines (depending on the * {@link MarkdownDocumentReaderConfig#horizontalRuleCreateDocument} configuration) into * {@link Document}s. * * @author Piotr Olaszewski */ public class MarkdownDocumentReader implements DocumentReader { /** * The resource points to the Markdown document. */ private final Resource markdownResource; /** * Configuration to a parsing process. */ private final MarkdownDocumentReaderConfig config; /** * Markdown parser. */ private final Parser parser; /** * Create a new {@link MarkdownDocumentReader} instance. * * @param markdownResource the resource to read */ public MarkdownDocumentReader(String markdownResource) { this(new DefaultResourceLoader().getResource(markdownResource), MarkdownDocumentReaderConfig.defaultConfig()); } /** * Create a new {@link MarkdownDocumentReader} instance. * * @param markdownResource the resource to read * @param config the configuration to use */ public MarkdownDocumentReader(String markdownResource, MarkdownDocumentReaderConfig config) { this(new DefaultResourceLoader().getResource(markdownResource), config); } /** * Create a new {@link MarkdownDocumentReader} instance. 
* * @param markdownResource the resource to read */ public MarkdownDocumentReader(Resource markdownResource, MarkdownDocumentReaderConfig config) { this.markdownResource = markdownResource; this.config = config; this.parser = Parser.builder() .extensions(Collections.singletonList(TablesExtension.create())) .build(); } /** * Extracts and returns a list of documents from the resource. * * @return List of extracted {@link Document} */ @Override public List<Document> get() { try (var input = this.markdownResource.getInputStream()) { Node node = this.parser.parseReader(new InputStreamReader(input)); DocumentVisitor documentVisitor = new DocumentVisitor(this.config); node.accept(documentVisitor); return documentVisitor.getDocuments(); } catch (IOException e) { throw new RuntimeException(e); } } /** * A convenient class for visiting handled nodes in the Markdown document. */ static class DocumentVisitor extends AbstractVisitor { private final List<Document> documents = new ArrayList<>(); private final List<String> currentParagraphs = new ArrayList<>(); private final MarkdownDocumentReaderConfig config; private Document.Builder currentDocumentBuilder; /** * 存储各级标题的文本内容,用于构建层级title * 数组索引对应标题级别(1-6) */ private final String[] headingLevels = new String[7]; /** * 用于构建表格内容的构建器 */ private final StringBuilder tableBuilder = new StringBuilder(); /** * 是否正在处理表格 */ private boolean inTable = false; /** * 当前表格的列数,用于生成分隔行 */ private int tableColumns = 0; /** * 是否正在处理表头 */ private boolean inTableHeader = false; DocumentVisitor(MarkdownDocumentReaderConfig config) { this.config = config; } /** * Visits the document node and initializes the current document builder. 
*/ @Override public void visit(org.commonmark.node.Document document) { this.currentDocumentBuilder = Document.builder(); super.visit(document); } @Override public void visit(Heading heading) { buildAndFlush(); // 更新当前级别的标题文本(在visit(Text)中设置) // 这里先设置当前级别及更高级别保持不变,清除更低级别的标题 int level = heading.getLevel(); for (int i = level; i < headingLevels.length; i++) { headingLevels[i] = null; } super.visit(heading); } @Override public void visit(ThematicBreak thematicBreak) { if (this.config.horizontalRuleCreateDocument) { buildAndFlush(); } super.visit(thematicBreak); } @Override public void visit(SoftLineBreak softLineBreak) { translateLineBreakToSpace(); super.visit(softLineBreak); } @Override public void visit(HardLineBreak hardLineBreak) { translateLineBreakToSpace(); super.visit(hardLineBreak); } @Override public void visit(ListItem listItem) { translateLineBreakToSpace(); super.visit(listItem); } @Override public void visit(Image image) { String alt = image.getDestination(); // 注意:这里应为getTitle()或getFirstChild()获取alt文本 String url = image.getDestination(); String title = image.getTitle(); // 将图片信息格式化后添加到当前段落中 String imageInfo = String.format("", alt, url, title); this.currentParagraphs.add(imageInfo); super.visit(image); } @Override public void visit(BlockQuote blockQuote) { if (!this.config.includeBlockquote) { return; } translateLineBreakToSpace(); this.currentDocumentBuilder.metadata("category", "blockquote"); super.visit(blockQuote); } @Override public void visit(Code code) { this.currentParagraphs.add(code.getLiteral()); this.currentDocumentBuilder.metadata("category", "code_inline"); super.visit(code); } @Override public void visit(FencedCodeBlock fencedCodeBlock) { if (!this.config.includeCodeBlock) { return; } translateLineBreakToSpace(); String literal = fencedCodeBlock.getLiteral(); Integer openingFenceLength = fencedCodeBlock.getOpeningFenceLength(); Integer closingFenceLength = fencedCodeBlock.getClosingFenceLength(); StringJoiner literalJoiner = new 
StringJoiner(""); literalJoiner.add("\n"); // 构建开头的代码块标记,包含语言标识 for (int i = 0; i < openingFenceLength; i++) { literalJoiner.add(fencedCodeBlock.getFenceCharacter()); } // 添加语言标识(如果有) String language = fencedCodeBlock.getInfo(); if (language != null && !language.trim().isEmpty()) { literalJoiner.add(language); } literalJoiner.add("\n"); literalJoiner.add(literal); // 构建结尾的代码块标记 for (int i = 0; i < closingFenceLength; i++) { literalJoiner.add(fencedCodeBlock.getFenceCharacter()); } literalJoiner.add("\n"); this.currentParagraphs.add(literalJoiner.toString()); this.currentDocumentBuilder.metadata("category", "code_block"); this.currentDocumentBuilder.metadata("lang", language); // 同时保存在元数据中 super.visit(fencedCodeBlock); } @Override public void visit(CustomBlock customBlock) { if (customBlock instanceof TableBlock tableBlock){ inTable = true; inTableHeader = false; tableBuilder.setLength(0); // 清空表格构建器 tableColumns = 0; // 设置元数据 this.currentDocumentBuilder.metadata("category", "table"); super.visit(tableBlock); // 继续访问表格子节点 // 表格处理完成 if (tableBuilder.length() > 0) { this.currentParagraphs.add(tableBuilder.toString()); } inTable = false; inTableHeader = false; } else { super.visit(customBlock); } } @Override public void visit(CustomNode customNode) { if (customNode instanceof TableBody tableBody){ inTableHeader = false; super.visit(tableBody); } else if (customNode instanceof TableRow tableRow){ if (inTable) { // 处理表格行 int columnCount = 0; StringBuilder rowBuilder = new StringBuilder("|"); // 遍历行中的所有单元格 Node child = tableRow.getFirstChild(); while (child != null) { if (child instanceof TableCell) { columnCount++; String cellContent = extractCellContent((TableCell) child); rowBuilder.append(cellContent).append("|"); } child = child.getNext(); } // 如果是表头行,记录列数并添加分隔行 if (inTableHeader && tableColumns == 0) { tableColumns = columnCount; tableBuilder.append(rowBuilder).append("\n"); // 添加分隔行 tableBuilder.append("|"); tableBuilder.append("---|".repeat(Math.max(0, 
tableColumns))); tableBuilder.append("\n"); } else { tableBuilder.append(rowBuilder).append("\n"); } } super.visit(tableRow); } else if (customNode instanceof TableCell tableCell){ // 单元格内容在visit(Text)中处理,这里直接继续访问 super.visit(tableCell); } else if (customNode instanceof TableHead tableHead){ inTableHeader = true; super.visit(tableHead); } else { super.visit(customNode); } } @Override public void visit(Text text) { if (text.getParent() instanceof Heading heading) { int level = heading.getLevel(); String currentTitle = text.getLiteral(); // 存储当前级别的标题 headingLevels[level] = currentTitle; // 构建层级title String hierarchicalTitle = buildHierarchicalTitle(level); this.currentDocumentBuilder.metadata("category", "header_%d".formatted(level)) .metadata("title", currentTitle) .metadata("titleExpander", hierarchicalTitle); } else if (!inTable) { // 如果不是在表格中,才添加到当前段落 this.currentParagraphs.add(text.getLiteral()); } // 表格中的文本在extractCellContent方法中处理 super.visit(text); } /** * 构建层级标题 * @param currentLevel 当前标题级别 * @return 层级标题字符串,如 "一级标题 - 二级标题 - 三级标题" */ private String buildHierarchicalTitle(int currentLevel) { List<String> titleParts = new ArrayList<>(); // 从1级标题开始,收集到当前级别为止的所有标题 for (int i = 1; i <= currentLevel; i++) { if (headingLevels[i] != null && !headingLevels[i].trim().isEmpty()) { titleParts.add(headingLevels[i].trim()); } } // 用 " - " 连接所有标题部分 return String.join(" - ", titleParts); } /** * 提取表格单元格内容 */ private String extractCellContent(TableCell tableCell) { StringBuilder cellBuilder = new StringBuilder(); Node child = tableCell.getFirstChild(); while (child != null) { cellBuilder.append(extractNodeText(child)); child = child.getNext(); } // 清理内容:移除首尾空格,将内部多个空格/换行替换为单个空格 String content = cellBuilder.toString().trim(); content = content.replaceAll("\\s+", " "); // 如果单元格内容为空,添加一个空格 if (content.isEmpty()) { content = " "; } return content; } /** * 递归提取节点文本 */ private String extractNodeText(Node node) { if (node instanceof Text) { return ((Text) node).getLiteral(); } else 
if (node instanceof Code) { return ((Code) node).getLiteral(); } else if (node instanceof StrongEmphasis) { // 加粗文本 return extractChildrenText(node); } else if (node instanceof Emphasis) { // 斜体文本 return extractChildrenText(node); } else if (node instanceof Link) { // 链接 - 提取链接文本 return extractChildrenText(node); } else { // 其他节点类型,递归提取子节点文本 return extractChildrenText(node); } } /** * 提取所有子节点的文本 */ private String extractChildrenText(Node node) { StringBuilder result = new StringBuilder(); Node child = node.getFirstChild(); while (child != null) { result.append(extractNodeText(child)); child = child.getNext(); } return result.toString(); } public List<Document> getDocuments() { buildAndFlush(); return this.documents; } private void buildAndFlush() { if (!this.currentParagraphs.isEmpty() || (inTable && tableBuilder.length() > 0)) { String content; if (inTable && tableBuilder.length() > 0) { // 如果正在处理表格,使用表格内容 content = tableBuilder.toString(); } else { // 否则使用段落内容 content = String.join("\n", this.currentParagraphs); } Document.Builder builder = this.currentDocumentBuilder.text(content); this.config.additionalMetadata.forEach(builder::metadata); Document document = builder.build(); this.documents.add(document); this.currentParagraphs.clear(); tableBuilder.setLength(0); } this.currentDocumentBuilder = Document.builder(); } private void translateLineBreakToSpace() { if (!this.currentParagraphs.isEmpty() && !inTable) { this.currentParagraphs.add(" "); } } } }表格支持还需要添加一下依赖
<dependency> <groupId>org.commonmark</groupId> <artifactId>commonmark-ext-gfm-tables</artifactId> <version>0.22.0</version> </dependency>下面是接收前端上传的markdown文件,以及所选择的知识库ID,然后做文本切块 向量化存储
/**
 * Splits an uploaded file into chunks, embeds them and writes them to the vector store.
 *
 * @param kdId id of the knowledge document the chunks belong to
 * @param file uploaded file (md / pdf / docx / doc)
 * @return always an empty list; the parsed chunks are persisted to the vector store
 */
public List<Document> embeddingDocumentsForMarkdown(Integer kdId, MultipartFile file) {
    // Pick the parser bean by file extension.
    String extension = getFileExtension(file);
    String parserBeanName = switch (extension) {
        case "md" -> "docMarkdownFileParseService";
        case "pdf" -> "docPdfFileParseService";
        case "docx", "doc" -> "docWordFileParseService";
        default -> throw new ExceptionCore("不支持的文件类型");
    };
    // Split the file into document chunks.
    List<Document> chunks = docFileParseServiceMap.get(parserBeanName).parse(file, kdId);
    if (CollectionUtils.isEmpty(chunks)) {
        return Collections.emptyList();
    }
    // add() both embeds the chunks and stores them in the vector database.
    vectorStoreComponent.getVectorStore().add(chunks);
    // NOTE(review): the stored chunks are intentionally not returned to the caller.
    return Collections.emptyList();
}向量数据库
存储文本向量 为向量检索提供支持
package cn.dataling.rag.application.provider;

import cn.dataling.rag.application.properties.ChromaProperties;
import cn.dataling.rag.application.properties.ElasticsearchProperties;
import cn.dataling.rag.application.util.JsonUtils;
import cn.dataling.rag.application.vectorstore.ChromaVectorStore;
import cn.dataling.rag.application.vectorstore.ElasticsearchAiSearchFilterExpressionConverter;
import cn.dataling.rag.application.vectorstore.ElasticsearchVectorStore;
import cn.dataling.rag.application.vectorstore.SimpleVectorStore;
import com.google.common.collect.Lists;
import org.springframework.ai.chroma.vectorstore.ChromaApi;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;

/**
 * Static factory for {@link VectorStore} implementations (Elasticsearch, in-memory, Chroma).
 */
public final class VectorStoreProvider {

    private VectorStoreProvider() {
        // utility class — no instances
    }

    /**
     * Resolves a vector store by type name.
     *
     * @param vectorStoreType enum name of the store type ("ELASTICSEARCH", "SIMPLE", "CHROMA")
     * @param embeddingModel embedding model
     * @param jsonConfig store-specific configuration as JSON
     * @return the configured vector store
     * @throws IllegalArgumentException if {@code vectorStoreType} is not a known enum name
     */
    public static VectorStore getVectorStore(String vectorStoreType, EmbeddingModel embeddingModel, String jsonConfig) {
        // valueOf already throws IllegalArgumentException for unknown names,
        // and the switch expression is exhaustive over the enum — no default needed.
        VectorStoreProviderEnum vectorStoreProviderEnum = VectorStoreProviderEnum.valueOf(vectorStoreType);
        return switch (vectorStoreProviderEnum) {
            case ELASTICSEARCH -> {
                ElasticsearchProperties elasticsearchProperties = JsonUtils.toObject(jsonConfig, ElasticsearchProperties.class);
                // Force cosine similarity regardless of the supplied config.
                elasticsearchProperties.setSimilarity(ElasticsearchVectorStore.SimilarityFunction.cosine);
                yield elasticsearchVectorStore(embeddingModel, elasticsearchProperties);
            }
            case SIMPLE -> simpleVectorStore(embeddingModel);
            case CHROMA -> chromaVectorStore(embeddingModel, JsonUtils.toObject(jsonConfig, ChromaProperties.class));
        };
    }

    /**
     * Builds an Elasticsearch-backed vector store.
     *
     * @param embeddingModel embedding model
     * @param elasticsearchProperties Elasticsearch configuration
     */
    public static VectorStore elasticsearchVectorStore(EmbeddingModel embeddingModel, ElasticsearchProperties elasticsearchProperties) {
        return ElasticsearchVectorStore.builder(elasticsearchProperties, embeddingModel)
                .withFilterExpressionConverter(new ElasticsearchAiSearchFilterExpressionConverter())
                // Embed/insert documents in configurable-size batches.
                .batchingStrategy(docs -> Lists.partition(docs, elasticsearchProperties.getBatchSize()))
                .build();
    }

    /**
     * Builds an in-memory vector store (testing / small data sets).
     *
     * @param embeddingModel embedding model
     */
    public static VectorStore simpleVectorStore(EmbeddingModel embeddingModel) {
        return SimpleVectorStore.builder(embeddingModel)
                .batchingStrategy(docs -> Lists.partition(docs, 100))
                .build();
    }

    /**
     * Builds a Chroma-backed vector store.
     *
     * @param embeddingModel embedding model
     * @param chromaProperties Chroma configuration
     */
    public static VectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaProperties chromaProperties) {
        ChromaApi chromaApi = ChromaApi.builder()
                .baseUrl(chromaProperties.getBaseUrl())
                .build();
        return ChromaVectorStore.builder(chromaApi, embeddingModel)
                .collectionName(chromaProperties.getCollectionName())
                .tenantName(chromaProperties.getTenantName())
                .batchingStrategy(docs -> Lists.partition(docs, chromaProperties.getBatchSize()))
                .databaseName(chromaProperties.getDatabaseName())
                .initializeSchema(true)
                .initializeImmediately(true)
                .build();
    }

    /**
     * Supported vector store providers.
     */
    public enum VectorStoreProviderEnum {

        ELASTICSEARCH("ES"),
        SIMPLE("内存"),
        CHROMA("Chroma"),
        ;

        private final String value;

        VectorStoreProviderEnum(String value) {
            this.value = value;
        }

        public String getValue() {
            return value;
        }

    }

}RAG检索增强
public Flux<AssistantMessage> chatWithRag(ChatWithRagDTO data) { // 查询知识文档 KnowledgeDoc knowledgeDoc = knowledgeDocService.getKnowledgeDocById(data.getKnowledgeDocId()); if (ObjectUtils.isEmpty(knowledgeDoc)) { return Flux.just(new AssistantMessage("知识库不存在")); } // 获取知识文档的提示词 Integer promptId = knowledgeDoc.getPromptId(); PromptInfo promptInfo = promptInfoMapper.selectById(promptId); // 查询模型信息 Model model = modelMapper.selectById(data.getChatModelId()); // 获取对话客户端 ChatClient chatClient = chatClientProvider.getChatClient(model.getProvider(), model.getName(), model.getApiUrl(), model.getApiKey()); String delimiterToken = ObjectUtils.isEmpty(promptInfo) ? "{}" : promptInfo.getDelimiterToken(); StTemplateRenderer stTemplateRenderer = ObjectUtils.isEmpty(delimiterToken) ? StTemplateRenderer.builder().startDelimiterToken('{').endDelimiterToken('}').build() : StTemplateRenderer.builder().startDelimiterToken(delimiterToken.charAt(0)).endDelimiterToken(delimiterToken.charAt(1)).build(); // 构建提示词 同时将工具信息添加到提示词模板中 PromptTemplate promptTemplate = ObjectUtils.isEmpty(promptId) ? defaultPromptTemplate : PromptTemplate.builder() .template(promptInfoService.getPromptInfoById(promptId).getContent()) // 自定义模板分隔符(避免与 JSON 冲突 ) 默认分隔符 {} 可能与 JSON 语法冲突,可修改为 <> .renderer(stTemplateRenderer) .variables(Map.of("tools", getMcpToolsDefinition())) .build(); VectorStore vectorStore = vectorStoreComponent.getVectorStore(); RetrievalAugmentationAdvisor augmentationAdvisor = RetrievalAugmentationAdvisor.builder() // 阶段一:优化用户问题 将单个查询扩展为多个相关查询 .queryExpander(query -> data.getQueryExpander() ? 
queryExpander(chatClient, query.text()) : List.of(query)) // 阶段二: 根据查询检索相关文档 根据扩展后的查询进行检索 默认会使用线程池并行查询 .documentRetriever(query -> similaritySearch(data.getTopK(), data.getSimilarityThreshold(), query.text(), data.getKnowledgeDocId(), vectorStore)) // 阶段三:合并来自多个查询结果 合并多查询/多数据源的检索结果,去重 .documentJoiner(new ConcatenationDocumentJoiner()) // 阶段四:对检索到的文档进行后置处理 对检索到的文档进行后处理,如重排序 .documentPostProcessors((query, documents) -> data.getRerank() ? documentRerank(documents, query.text()) : documents) // 阶段五:查询增强阶段 将检索到的文档上下文融入原始查询 生成最终的prompt prompt中要包含 context 和 query 分别代表上下文和查询 .queryAugmenter(ContextualQueryAugmenter.builder() .documentFormatter(documents -> documents.stream() .map(e -> { String temp = """ 标题: %s 内容: %s """; Map<String, Object> metadata = e.getMetadata(); String titleExpander = CollectionUtils.isEmpty(metadata) ? "无标题" : (metadata.containsKey("titleExpander") ? metadata.get("titleExpander").toString() : "无标题"); return String.format(temp, titleExpander, e.getText()); }) .reduce((a, b) -> a + "\n\n" + b) .orElse("未检测到相关知识")) // 允许空上下文 如果为true的话 当上下文为空 模型会跳过上下文 使用自己的知识进行回答 .allowEmptyContext(false) .emptyContextPromptTemplate(emptyContextPrompt) .promptTemplate(promptTemplate) .build()) .build(); return chatClient.prompt() .user(data.getText()) .toolCallbacks(toolCallbackProvider) .advisors(MessageChatMemoryAdvisor.builder(jdbcChatMemory).build(), augmentationAdvisor) .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, data.getConversationId())) .stream() .chatResponse() .map(e -> e.getResult().getOutput()) .takeWhile(assistantMessage -> IS_STREAM.getOrDefault(data.getConversationId(), true)) .onErrorResume(throwable -> Flux.just(AssistantMessage.builder().content(String.format("模型调用异常 %s", throwable.getCause().getMessage())).build())) .doFinally(d -> IS_STREAM.remove(data.getConversationId())); }