RAG 标准流程:
- 索引:外挂知识库
- 检索
- 生成
Advanced RAG
针对上述 3 个阶段做了优化。例如检索阶段,新增了 检索前处理 以及 检索后处理。
检索前处理:
- 查询转换
- 查询扩充
- ......
查询扩充(Query Expansion)
在不改变用户意图的前提下,添加相关词语或同义表达,让检索系统能够匹配到更多语义相关的文档。例如用户输入:
项目合同
扩充后的 Query 变成了:
["项目合同", "合作协议", "法律文件", "合同模板"]
检索后处理:
常见的处理有:
- 重排序
- 过滤无关内容
- 合并去重
- 精简摘要
- 格式优化
MultiQueryRetriever
MultiQueryRetriever 是 LangChain 中的一个工具类,作用是 用一个大模型 把用户问题改写成多种表述(多视角查询),对每个改写分别检索,然后合并去重结果,缓解“单一措辞导致召回不足”的问题。
召回率(Recall)
指的是所有相关内容中被成功检索出来的比例。公式为:
召回率 = 检索到的相关内容数量 / 所有实际相关内容的总数
例如:实际相关文档有 10 个,系统只检索到其中 6 个,那么召回率 = 6 / 10 = 60%
改写原理
非常简单,就是调用大模型,给予大模型如下的提示词:
// 演示 MultiQueryRetriever 背后原理
import { PromptTemplate } from "@langchain/core/prompts";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { ChatOllama } from "@langchain/ollama";const pt = PromptTemplate.fromTemplate(`
You are an AI language model assistant. Your task is to generate {queryCount}
different versions of the given user question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question, your goal is to help the user overcome some of
the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.
Original question: {question}`);const model = new ChatOllama({model: "llama3",temperature: 0.7,
});const parser = new StringOutputParser();const chain = pt.pipe(model).pipe(parser);
const res = await chain.invoke({queryCount: 3,question: "奥特曼的技能有哪些?",
});console.log(res);
快速上手
// 演示 MultiQueryRetriever 基本使用
import { MemoryVectorStore } from "@langchain/classic/vectorstores/memory";
import { NomicEmbeddings } from "../utils/embed2.js";
import { TextLoader } from "@langchain/classic/document_loaders/fs/text";
import { RecursiveCharacterTextSplitter } from "@langchain/classic/text_splitter";
import { MultiQueryRetriever } from "@langchain/classic/retrievers/multi_query";
import { ChatOllama } from "@langchain/ollama";const loader = new TextLoader("../data/kong.txt");
const docs = await loader.load();const splittedDocs = await new RecursiveCharacterTextSplitter({chunkSize: 64,chunkOverlap: 0,
}).splitDocuments(docs);const store = new MemoryVectorStore(new NomicEmbeddings(4));await store.addDocuments(splittedDocs);const retriever = MultiQueryRetriever.fromLLM({llm: new ChatOllama({model: "llama3",temperature: 0.7,}), // 使用什么模型来改写提示词retriever: store.asRetriever(2), // 基础检索器queryCount: 3, // 提示词要扩充的数量,verbose: true, // 打开调试日志
});const res = await retriever.invoke("茴香豆是做什么用的?");console.log(res);
- llm:用于改写多版本提示词的模型
- retriever:底层的单查询检索器
- queryCount:要生成多少条不同的查询改写。
- verbose:控制台输出调试日志
课堂练习:使用 MultiQueryRetriever 改写用户原始提示词
🤔思考:创建检索器的时候 k 指定的是 2,为什么有 4 条结果?
k 指定的是 2,理论上应该是 2 条结果
理论上是 2 * 3 = 6 条结果,不过经过去重之后,得到了 4 条。
ContextualCompressionRetriever
该工具类用于在检索出原始相关文档之后,进一步“压缩”内容或 过滤掉不相关部分,返回更简洁、更聚焦于问题的文档片段,属于检索后处理的一种。
基本语法:
new ContextualCompressionRetriever({baseRetriever: r, // 指定基础的检索器baseCompressor: LLMChainExtractor.fromLLM(llm), // 配置压缩器
});
LLMChainExtractor.fromLLM(llm) 相当于配置了一个压缩器,调用 LLM,从候选文档中提取出与问题相关的段落/句子。
课堂演示:使用 LLMChainExtractor 对文档进行压缩
// 演示 LLMChainExtractor 对文档进行压缩
import { ChatOllama } from "@langchain/ollama";
import { Document } from "@langchain/classic/document";
import { LLMChainExtractor } from "@langchain/classic/retrievers/document_compressors/chain_extract";const llm = new ChatOllama({model: "llama3",temperature: 0.7,
});// 创建了一个压缩器
const extractor = LLMChainExtractor.fromLLM(llm);// 构建一段文档
const docs = [new Document({pageContent: `孔乙己走到酒店,说“要一碟茴香豆”。茴香豆常作为下酒小菜,价格便宜,常见于短衣帮和穿长衫的顾客点酒时。`,}),new Document({pageContent: `另一个段落:今天的天气很好,阳光明媚。`,}),new Document({pageContent: `游泳池边上,游过一群小黄鸭`,}),new Document({pageContent: `茴香豆可以拿来泡茶,上次我就看到有一个大婶儿拿茴香豆来泡茶`,}),
];const query = "茴香豆的作用";const res = await extractor.compressDocuments(docs, query);console.log(res);
课堂练习:使用 ContextualCompressionRetriever 过滤检索结果
// 演示 ContextualCompressionRetriever 的使用import { MemoryVectorStore } from "@langchain/classic/vectorstores/memory";
import { NomicEmbeddings } from "../utils/embed2.js";
import { TextLoader } from "@langchain/classic/document_loaders/fs/text";
import { RecursiveCharacterTextSplitter } from "@langchain/classic/text_splitter";
import { ChatOllama } from "@langchain/ollama";
import { ContextualCompressionRetriever } from "@langchain/classic/retrievers/contextual_compression";
import { LLMChainExtractor } from "@langchain/classic/retrievers/document_compressors/chain_extract";const loader = new TextLoader("../data/kong.txt");const docs = await loader.load();const splitter = new RecursiveCharacterTextSplitter({chunkSize: 64,chunkOverlap: 0,
});const splittedDocs = await splitter.splitDocuments(docs);const embeddings = new NomicEmbeddings(3);const store = new MemoryVectorStore(embeddings);await store.addDocuments(splittedDocs);const retriever = store.asRetriever(2); // 创建一个基础的检索器const res = await retriever.invoke("茴香豆是做什么用的?");console.log(`压缩前的检索结果:`, res);const llm = new ChatOllama({model: "llama3",temperature: 0.7,
});const r = new ContextualCompressionRetriever({baseRetriever: retriever,baseCompressor: LLMChainExtractor.fromLLM(llm), // 使用什么模型来压缩文档
});const res2 = await r.invoke("茴香豆是做什么用的?");
console.log(`压缩后的检索结果:`, res2);
ScoreThresholdRetriever
一个基于相似度“阈值过滤”的检索器:它在向量库中检索时,只返回 相似度分数不低于 你设定阈值的文档。该类继承自 VectorStoreRetriever,提供 invoke() 等可运行接口方法。
使用示例:
const r = ScoreThresholdRetriever.fromVectorStore(vectorstore, {minSimilarityScore: 0.7,
});const res = await r.invoke("茴香豆是做什么用的?");console.log(res);
- minSimilarityScore:分数阈值,返回的文档的相似度分数不能小于该值。
除了该配置项以外,还支持下面的配置项:
-
kIncrement: number(默认 10):每次追加抓取的候选数量增量,用于在需要时扩大候选集,再按阈值过滤。 -
maxK: number(默认 100):抓取候选的上限,避免无限扩大搜索范围。 -
searchType: "similarity" | "mmr":可选设为"mmr"做多样性与相关性的权衡;默认"similarity"。 -
其余通用项:
filter?、verbose?、callbacks?、tags?、metadata?。
import { MemoryVectorStore } from "@langchain/classic/vectorstores/memory";
import { NomicEmbeddings } from "../utils/embed2.js";
import { TextLoader } from "@langchain/classic/document_loaders/fs/text";
import { RecursiveCharacterTextSplitter } from "@langchain/classic/text_splitter";
import { ScoreThresholdRetriever } from "@langchain/classic/retrievers/score_threshold";const loader = new TextLoader("../data/kong.txt");const docs = await loader.load();const splittedDocs = await new RecursiveCharacterTextSplitter({chunkSize: 64,chunkOverlap: 0,
}).splitDocuments(docs);const store = new MemoryVectorStore(new NomicEmbeddings(3));await store.addDocuments(splittedDocs);const r = ScoreThresholdRetriever.fromVectorStore(store, {minSimilarityScore: 0.72,
});
const res = await r.invoke("茴香豆是做什么用的?");
console.log(res);