https://arxiv.org/pdf/2503.20757v1这篇论文提出了MCTS-RAG框架,用于解决小型语言模型在知识密集型任务上的推理能力不足问题。具体来说,
- MCTS-RAG框架:MCTS-RAG通过迭代地精炼检索和推理过程来工作。给定一个查询,它探索多个推理路径,并在关键决策点动态地结合检索操作。检索到的知识用于评估中间状态,并通过反向传播强化有益的检索路径。
- 动作空间定义:在每个MCTS决策点,设计了以下离散动作:
- A1:直接回答:基于现有推理或预先已知的上下文提供即时响应。
- A2:快速推理:基于当前上下文执行快速的增量推理步骤。
- A3:分解问题:将复杂查询分解为更小、更易管理的子问题。
- A4:检索推理:在进行下一步推理之前,从内部或外部源主动检索相关知识。
- A5:检索分解:首先分解复杂问题,然后获取相关知识以解决各个子问题。
- A6:总结回答:生成简洁且结构化的总结,综合之前的推理和检索信息。
-
检索过程:动态地在演变的MCTS推理环境中检索信息,确保及时且相关地整合外部知识。模型自主确定何时需要检索,生成有针对性的查询,并将外部知识关键地整合到推理中以提高准确性。
-
确定最终答案:在MCTS探索结束时,通过投票机制和一致性分析选择最佳答案。每个推理轨迹产生一个候选答案,这些候选答案根据语义一致性分组为唯一答案集。最终分数计算如下:
其中, 是沿着其对应推理轨迹的所有节点的奖励的乘积。最终答案通过以下公式确定:
实验设计
当然可以!我们来深入解读 MCTS-RAG 项目的核心代码逻辑。该项目结合了 蒙特卡洛树搜索(MCTS) 和 检索增强生成(RAG),通过动态迭代优化推理路径和知识检索,提升小型语言模型在复杂任务中的表现。
🧠 项目核心思想
https://github.com/yale-nlp/MCTS-RAG
📁 核心文件结构(简化)
从 GitHub 项目的目录结构来看,关键模块包括:
mcts_rag/
├── mcts.py # MCTS 主要逻辑
├── rag_model.py # RAG 模型封装
├── generator.py # 生成器模块
├── discriminator.py # 判别器模块
├── utils.py # 工具函数
└── scripts/ # 运行脚本
接下来我们逐一解析这几个模块的关键代码逻辑。
🌲 一、mcts.py
—— MCTS 核心逻辑
这是整个项目的灵魂部分,实现了基于 MCTS 的推理框架。
1. 节点类 Node
class Node:def __init__(self, state, parent=None):self.state = state # 当前状态(比如问题 + 推理步骤)self.parent = parent # 父节点self.children = [] # 子节点self.visits = 0 # 被访问次数self.value = 0 # 节点价值self.untried_actions = [...] # 可选动作(如生成下一步 or 检索新信息)
每个节点代表一个“推理状态”,包含当前的问题、已有的推理步骤以及可能的动作。
2. MCTS 四步流程
标准的 MCTS 包括四个步骤:
- 选择(Selection):从根节点开始,按照某种策略(如 UCB)选择最有潜力的子节点。
- 扩展(Expansion):如果当前节点未完全展开,则创建新的子节点。
- 模拟(Simulation):从当前节点随机 rollout 直到结束(即生成完整回答)。
- 反向传播(Backpropagation):根据 rollout 结果更新路径上所有节点的价值。
在 mcts.py
中,这四个步骤被封装为函数:
def tree_policy(node):# 实现 UCT 或其他选择策略...def expand(node):# 创建新的子节点...def simulate(node):# 使用 generator 和 retriever 生成一个完整回答...def backpropagate(node, reward):# 更新路径上的节点价值...
3. Rollout 示例(模拟)
def rollout(node):state = node.state.copy()for _ in range(max_depth):if is_terminal(state):breakaction = random.choice(get_possible_actions(state))state = apply_action(state, action)return evaluate(state) # 返回最终得分
其中:
apply_action
可能调用generator
或retriever
evaluate
使用discriminator
对当前推理路径进行评分
🧠 二、rag_model.py
—— RAG 模型封装
该模块负责封装语言模型和检索器,提供统一接口。
class RAGModel:def __init__(self, model_name, retriever):self.model = AutoModelForCausalLM.from_pretrained(model_name)self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.retriever = retrieverdef generate(self, prompt):inputs = self.tokenizer(prompt, return_tensors="pt").to(device)outputs = self.model.generate(**inputs, max_new_tokens=200)return self.tokenizer.decode(outputs[0], skip_special_tokens=True)def retrieve(self, query):return self.retriever.retrieve(query)
这里你可以看到典型的 RAG 架构:
- 输入 prompt 会先经过
retrieve
获取相关信息 - 再拼接到 prompt 中送入
generate
生成答案
✍️ 三、generator.py
—— 生成器模块
这个模块定义了如何使用语言模型进行推理生成。
class Generator:def __init__(self, model):self.model = modeldef propose_reasoning_step(self, current_state):prompt = self._build_prompt(current_state)reasoning_step = self.model.generate(prompt)return reasoning_step
每次 MCTS 扩展节点时,都可能调用 propose_reasoning_step
来生成一个新的推理步骤。
🔍 四、discriminator.py
—— 判别器模块
判别器用于评估某个推理路径的可信度,是 MCTS 中 reward 的来源。
class Discriminator:def __init__(self, model):self.model = modeldef score_path(self, path):prompt = self._build_score_prompt(path)score = self.model.generate(prompt) # 输出一个分数return float(score)
判别器可以是一个独立的小型模型,也可以是同一个模型的不同模式。
🛠️ 五、utils.py
—— 工具函数
提供一些辅助函数,例如:
- 构建 prompt 模板
- 处理 JSON 数据
- 日志记录
- 缓存机制(避免重复生成或检索)
示例:
def build_prompt(question, retrieved_docs, steps):prompt = f"Question: {question}\n"prompt += "Retrieved Docs:\n" + "\n".join(retrieved_docs) + "\n"prompt += "Reasoning Steps So Far:\n" + "\n".join(steps) + "\n"prompt += "Next Step:"return prompt
🚀 总结:MCTS-RAG 的运作流程
以下是整个系统的执行流程图:
开始
│
├─ 初始化 MCTS 根节点(初始问题)
│
└─ 循环以下步骤直到达到最大迭代次数:│├─ Selection: 选择最有潜力的节点├─ Expansion: 如果节点未完全展开,则生成新推理步骤或检索新内容├─ Simulation: 随机 rollout 完成推理├─ Backpropagation: 使用判别器评分并更新节点价值│
└─ 最终选择评分最高的路径作为答案
📚 建议阅读源码顺序
如果你打算深入研究这个项目,建议按如下顺序阅读源码:
mcts.py
→ 理解 MCTS 整体架构utils.py
→ 看懂 prompt 构造与数据处理方式generator.py
/discriminator.py
→ 了解生成和评估机制rag_model.py
→ 理解 RAG 模型集成方式scripts/run_*.sh
→ 查看运行参数配置
📌 小贴士
- 你可以在本地运行这个项目,但需要安装 PyTorch、Transformers、vLLM 等依赖。
- 训练自己的判别器 可以进一步提高系统效果。
- 可尝试更换不同 SLM 模型(如 Llama-3-8B、Phi-3 等)测试性能差异。
结果与分析
- 主要发现:MCTS-RAG在不同数据集上一致优于基线方法,展示了强大的多步推理和检索能力。在CWQA上,Llama 3.1-8B提高了超过20%,Qwen2.5-7B提高了约6%。在GPQA上,分别提高了约15%和10%。在FMT上,Llama 3.1-8B提高了超过10%,Qwen2.5-7B提高了4%。与基线方法相比,MCTS-RAG的平均性能提高了约14%。
- 不同动作的影响:检索动作(特别是A4和A5)对于多步推理至关重要。启用所有检索提高了GPQA(+20.88%)和FMT(+16.10%)。禁用A5提高了GPQA(+7.84%)和FMT(+6.43%),表明A4的作用更强。CWQA的影响最小(+1.25%)。
- 不同展开策略的影响:更多的展开提高了性能,特别是对GPQA。从4增加到8略微提高了CWQA(+3%),从8到12提高了GPQA(+11%)。扩展到16进一步提高了GPQA(+9%)和FMT(+5%),强化了迭代推理的价值。
总体结论
这篇论文提出了MCTS-RAG框架,通过结合MCTS的推理和检索能力以及自适应检索机制,提高了多步推理的准确性和可靠性。MCTS-RAG在处理需要深入外部知识的跨域任务时表现出色,不仅能够灵活地制定高质量的检索查询,还能通过迭代树探索细化推理路径,从而减少由浅层检索或简单推理引起的幻觉。实验结果表明,MCTS-RAG在复杂推理任务、知识增强的科学问答任务和具有挑战性的事实核查任务中取得了良好的效果。未来的工作将集中在优化搜索效率、开发自适应动作选择策略、基于置信度的检索过滤和错误感知剪枝机制上,以进一步提高MCTS探索的效率。
思考与见解
与传统RAG的静态管道不同,MCTS-RAG引入了一种“推理驱动的知识增强”的新范式。它提供了比CoT更多的灵活性和控制,可能是迈向下一代智能问答系统的关键一步。
话虽如此,我有几点想法。
虽然它比端到端使用LLMs更高效,但MCTS中的推出阶段仍然引入了不可忽视的计算成本。
此外,奖励被定义为推理路径上行动分数的乘积,但值得质疑这是否完全捕捉了推理轨迹的“正确性”。
总的来说,MCTS-RAG提供了一种有意义的方法论转变。在我看来,它试图弥合结构化算法与非结构化语言推理之间的差距——本质上通过离散控制流引导语言生成。这就是它的潜力所在,但也是其最大风险所在。
关键问题及回答
问题1:MCTS-RAG框架中的动作空间是如何定义的?这些动作各自的作用是什么?
在MCTS-RAG框架中,动作空间定义了在每个MCTS决策点可以执行的操作。具体动作包括:
- A1:直接回答:基于现有推理或已知上下文提供即时响应,适用于简单查询或不需要额外分析的情况。
- A2:快速推理:在当前上下文基础上执行快速的增量推理步骤,适用于探索性路径或初步判断,以高效指导搜索。
- A3:分解问题:将复杂查询分解为更小、更易管理的子问题,有助于清晰的问题解决路径和提高推理效率,特别适用于多部分或复杂问题。
- A4:检索推理:在下一步推理之前,从内部或外部源主动检索相关信息,对于需要补充信息的查询或现有上下文不完整的情况至关重要。
- A5:检索分解:结合分解和检索,首先分解复杂问题,然后获取相关知识以解决各个子问题,特别适用于涉及详细上下文依赖子问题的查询。
- A6:总结回答:生成简洁、结构化的总结,综合之前推理和检索的信息,提供连贯且全面的响应,特别适用于需要总结或整合多方面信息的查询。
这些动作旨在解决推理-检索相互作用的特定方面,确保模型能够在导航问题空间时动态调整其策略。
问题2:MCTS-RAG框架中的检索过程是如何设计的?它在推理过程中起到了什么作用?
MCTS-RAG框架中的检索过程是动态地在MCTS推理环境中进行的,确保及时和相关地整合外部知识。具体步骤包括:
- R1:查询生成:如果检测到知识缺口,模型生成搜索查询。
- R2:查询执行:使用外部检索工具获取最相关的信息。
- R3:知识反思:评估检索数据的相关性和一致性,以确定其是否包含在推理过程中。
- R4:总结回答:整合精炼后的信息,使模型能够回答子问题或推进推理。
检索过程在推理过程中的作用是确保模型的推理不断更新并通过外部数据进行验证,从而减少错误并增强最终输出的稳健性。通过将检索与推理交织在一起,信息流得以简化,产生既简洁又富有信息量的输出。如果先前检索的数据足以回答当前推理步骤确定的问题,模型会跳过额外的检索,避免冗余。
问题3:MCTS-RAG在实验中表现如何?与其他基线方法相比有哪些优势?
MCTS-RAG在实验中表现出色,展示了强大的多步推理和检索能力。具体结果如下:
- ComplexWebQA (CWQA):在Llama 3.1-8B和Qwen2.5-7B上分别提高了20%和6%。
- Graduate-Level Google-Proof QA (GPQA):在Llama 3.1-8B和Qwen2.5-7B上分别提高了约15%和10%。
- FoolMeTwice (FMT):在Llama 3.1-8B和Qwen2.5-7B上分别提高了10%和4%。
与其他基线方法相比,MCTS-RAG的优势包括:
- 提高推理准确性:通过新的检索动作,小型语言模型能够获取外部知识,增强了问答的质量。
- 优化查询构建:精炼过程确保每个查询专注于特定的信息需求,提高了检索查询生成的有效性。
- 增强检索质量:反思和总结检索信息有助于减少语义差异,确保与核心问题的对齐。
- 减少幻觉:通过详细的显式推理步骤,MCTS-RAG减少了由于浅层检索或简单推理引起的幻觉。
总体而言,MCTS-RAG在复杂推理任务、知识增强的科学问答任务和具有挑战性的事实核查任务中展示了卓越的性能和效率。