Paper Reading:Neural Prototype Trees for Interpretable Fine-grained Image Recognition

news/2025/11/8 22:53:39/文章来源:https://www.cnblogs.com/linfangnan/p/19202528

目录
  • 研究动机
  • 文章贡献
  • 预备知识
  • 本文方法
    • ProtoTree 结构
    • 训练 ProtoTree
    • 原型可视化
    • 确定性推理
  • 实验结果
    • 数据集和实验设置
    • 对比实验
    • 树的高度影响
    • 剪枝与原型替换的影响
      • 确定性推理策略评估
    • 可视化分析
  • 优点和创新点

Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。

论文概况 详细
标题 《Neural Prototype Trees for Interpretable Fine-grained Image Recognition》
作者 Meike Nauta, Ron van Bree, Christin Seifert
发表会议 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)
发表年份 2021
会议等级 CCF-A
论文代码 https://github.com/M-Nauta/ProtoTree

作者单位:

  1. University of Twente, the Netherlands
  2. University of Duisburg-Essen, Germany

研究动机

深度学习模型(如 DNNs)虽然在计算机视觉等任务上表现出色,但其复杂的内部机制和高维特征空间使其如同“黑箱”,缺乏透明度。这在医疗、金融等高风险决策领域尤为关键,引发了业界对模型可解释性(Interpretability)透明度(Transparency) 的迫切需求。机器学习中长期存在一个公认的困境,即模型的预测能力与其可解释性往往相互制约。简单的模型(如决策树)易于理解但性能有限,而复杂的模型(如 DNNs)性能强大却难以解释。探索如何将深度学习的强大表示能力与决策树的直观可解释性结合起来,打造一个既强大又透明的模型具有重要的意义。
当前许多解释方法属于“事后解释”,即在一个训练好的黑箱模型之上,通过近似手段(如显著性图、代表点)来推测其决策原因。这类方法可能不稳定、不忠实于原模型。本文主张从模型设计之初就内置可解释性,构建一个“本质可解释”的模型。这样,模型提供的解释是其真实决策过程的反映,是全局且可靠的,而非局部的、可能具有误导性的近似。

文章贡献

本文提出了一种面向细粒度图像识别的本质可解释深度学习模型——神经原型树(ProtoTree),将原型学习与决策树结构相结合。首先使用卷积神经网络(CNN)将输入图像映射为潜在特征表示;然后通过一个二叉决策树进行层次化推理,其中每个内部节点包含一个可学习的原型,通过计算图像特征与原型之间的相似度来决定路由方向(向左或向右);最终,样本以一定概率到达各个叶子节点,节点的类别分布加权汇总产生预测结果。ProtoTree 的决策过程类似于人类玩游戏时的渐进式问答(如“这只鸟有红喉咙吗?有细长喙吗?那么它是蜂鸟!”),从而实现了全局可解释性(整个树结构可被理解)和局部可解释性(单样本预测路径可追溯)。ProtoTree 在性能上显著优于同类可解释模型,在 CUB-200-2011 和 Stanford Cars 数据集上达到更高准确率的同时,将所需原型数量减少约 90%,并通过剪枝、确定性推理等技术进一步优化了解释效率。
image

预备知识

原型学习是一种模仿人类认知过程中概念形成方式的机器学习方法。其核心思想是:对于每个类别,学习一个或多个最具代表性的典型样本(即"原型"),然后将新样本与这些原型进行比较来做出分类决策。模型通过展示"这个看起来像那个"来进行推理,提供了直观的决策依据。在技术实现上,原型学习通常包含三个关键步骤:

  1. 原型发现:从训练数据中自动学习每个类别的代表性特征模式。这些原型不是简单的训练样本复制,而是通过优化过程学习到的"典型特征表示"。
  2. 相似度计算:对于新的输入样本,计算其与所有原型的相似度或距离。这通常在模型的潜在特征空间中进行,使用如欧氏距离或余弦相似度等度量方法。
  3. 决策制定:基于相似度得分进行最终分类。最简单的方式是将样本分配给最相似原型所属的类别,也可以采用更复杂的加权组合方式。

与使用整张图像作为原型的 ProtoAttend 等方法不同,本文与 ProtoPNet 一致关注原型部分,将决策过程分解为基于局部特征的小步骤,这更符合细粒度识别的需求。原型部分网络(ProtoPNet)是该领域的重要基线,它为每个类别学习一定数量的原型,通过计算输入图像与所有原型的相似度并进行加权组合来完成分类,如下图所示。这种方法存在解释性瓶颈:对于 CUB 数据集(200 类),需要同时考虑 2000个 原型,决策过程缺乏层次性。
image
软决策树(SDTs) 允许样本以概率形式遍历所有路径,比传统硬决策树更易于用梯度下降训练。近年来,深度 SDTs(如 DNDF, ANTs)将神经网络作为特征学习器融入树中,但这些方法往往牺牲了决策树的可解释性。一些研究尝试恢复 SDTs 的可解释性,例如通过可视化显著性图或感知机权重来解释节点决策,但这些方法要么解释不直观,要么因模型表示能力有限而影响精度。另一项工作在线性分裂参数中引入空间正则化,但主要应用于简单数据集。
image

本文方法

ProtoTree 结构

ProtoTree 旨在解决监督学习分类问题,给定一个包含 N 个标注图像的训练集 \(\{(x^{(1)},y^{(1)}),...,(x^{(N)},y^{(N)})\} \in \mathcal{X} \times \mathcal{Y}\),模型的目标是:对于输入图像 \(x\),预测其类别概率分布 \(\hat{y}\)(共 K 个类)。训练过程通过最小化预测分布 \(\hat{y}\) 与真实标签 \(y\)(one-hot 编码)之间的交叉熵损失来实现。一个 ProtoTree 模型 \(T\) 由两部分串联而成:

  1. 卷积神经网络(CNN):函数 \(f\),参数为 \(\omega\)。输入图像 \(x\) 经过该网络后,得到一个潜空间表示 \(z = f(x; \omega)\),其形状为 \(H \times W \times D\)(即 D 个 \(H \times W\) 的特征图)。
  2. 软神经二叉决策树:以 CNN 的输出 \(z\) 作为输入进行决策。

决策树包含三种元素,如下表所示:

树结构组件 说明
内部节点集合 \(\mathcal{N}\) 每个内部节点 \(n \in \mathcal{N}\) 都关联一个可训练的原型(prototype) \(p_n \in P\)。原型是一个形状为 \(H_1 \times W_1 \times D\) 的张量(在实现中 \(H_1 = W_1 = 1\))。
叶子节点集合 \(\mathcal{L}\) 每个叶子节点 \(\ell \in \mathcal{L}\) 关联一个可训练的类别分布参数 \(c_{\ell}\),经过softmax 函数 \(\sigma(c_{\ell})\) 后得到该叶子节点上的类别概率分布。
边集合 \(\mathcal{E}\) 每个内部节点 \(n\) 有两条出边,分别连接左子节点 \(n\).left 和右子节点 \(n\).right。

对于每个内部节点 \(n\),将其原型 \(p_n\) 作为卷积核,在 CNN 输出 \(z\) 上滑动,计算与每个图像块 \(\tilde{z}\) 的欧氏距离。找到最相似的图像块:

\[\tilde{z}^{*} = \underset{\tilde{z} \in \text{patches}(z)}{\operatorname{argmin}} ||\tilde{z} - p_{n}|| \]

然后,根据该最小距离计算图像 \(x\) 在节点 \(n\) 处被路由到右子节点的概率(即右边 \(e(n, n.\text{right})\) 的激活概率)如下公式所示,路由到左子节点的概率为:\(p_{e(n, n.\text{left})} = 1 - p_{e(n, n.\text{right})}\)

\[p_{e(n, n.\text{right})}(z) = \exp\left(-||\tilde{z}^{*} - p_{n}||\right) \]

由于是软决策树,一个样本会以一定概率到达所有叶子节点。从根节点到叶子节点 \(\ell\) 的路径记为 \(\mathcal{P}_{\ell}\),样本到达该叶子的概率 \(\pi_{\ell}\) 是路径上所有边概率的乘积:

\[\pi_{\ell}(z) = \prod_{e \in \mathcal{P}_{\ell}} p_{e}(z) \]

最终的类别概率分布 \(\hat{y}\) 是所有叶子节点类别分布的加权平均,权重即为对应的路径概率:

\[\hat{y}(x) = \sum_{\ell \in \mathcal{L}} \sigma(c_{\ell}) \cdot \pi_{\ell}(f(x;\omega)) \]

一个 Prototree 对实例进行预测的图例如下所示:
image

训练 ProtoTree

训练一个 ProtoTree 需要学习以下三组参数,通过最小化预测类别分布 \(\hat{y}\) 与真实标签 \(y\) 之间的标准交叉熵损失进行优化。

ProtoTree 参数 说明
CNN参数(\(\omega\) 卷积神经网络 \(f\) 的参数,用于从输入图像中提取有意义的特征表示
原型参数(\(P\) 所有内部节点中可训练的原型集合,用于决策路由
叶子分布参数(\(c\) 所有叶子节点上的类别分布对数几率(logits),用于最终预测

模型中的不同组件的初始化操作如下:

模型组件 初始化操作
树结构 设定一个最大高度 \(h\) 来初始化一个完整的二叉树。这会产生 \(2^{h}-1\) 个内部节点(即原型数量 \(|P|\))和 \(2^{h}\) 个叶子节点。计算复杂度随树高 \(h\) 呈指数增长
CNN骨干网络 需要一个预训练的CNN(如在ImageNet或特定任务上预训练)
原型初始化 原型 \(P\) 中的张量通过从分布 \(\mathcal{N}(0.5, 0.1)\) 中采样进行初始化
叶子分布初始化 所有叶子节点的参数 \(c_{\ell}^{(1)}\) 被初始化为零,使得初始的类别分布 \(\sigma(c_{\ell}^{(1)})\) 是均匀的

将叶子参数 \(c\)\(\omega\)\(P\) 一同通过反向传播优化会使得优化问题过于复杂,导致较差的分类结果。本文没有采用计算开销巨大的传统方法,提出了一种将小批量梯度下降与无导数更新相结合的高效算法,如下伪代码所示:
image

该算法在每个训练周期(epoch)的小批量迭代中交织进行。首先更新 \(\omega\)\(P\),对于一个小批量数据 \((x_b, y_b)\),计算预测 \(\hat{y}_b\) 和损失,然后通过梯度下降更新 CNN 参数 \(\omega\) 和原型参数 \(P\)。接着增量式更新 \(c\),对于每个叶子节点 \(\ell\) 基于当前小批量数据计算出的信息,来增量式地更新叶子分布参数 \(c_{\ell}\)。公式如下,其中 \(t\) 是周期索引,\(B\) 是批次大小,\(\odot\)\(\oslash\) 分别表示逐元素乘法和除法。
$$c_{\ell}^{(t+1)} = c_{\ell}^{(t)} - \frac{1}{B} \cdot c_{\ell}^{(t)} + \frac{1}{B} \cdot \left[ \sum_{(x_b, y_b) \in \text{batch}} \left( \sigma(c_{\ell}^{(t)}) \odot y_b \odot \pi_{\ell} \right) \oslash \hat{y}_b \right]$$

为了提升模型的全局可解释性,一个关键策略是减少其解释规模(explanation size),即模型做出决策时所依据的因素数量。在ProtoTree 中,这直接对应于原型(prototype)的数量。训练过程中 ProtoTree 可以使用以下可选的剪枝步骤:

  • 剪枝准则:设定一个阈值 \(\tau\)(略大于 \(1/K\), \(K\)为类别数)。如果一个叶子节点 \(\ell\) 的最大类别概率 \(\max(\sigma(c_{\ell})) \leq \tau\),则认为该叶子缺乏判别力,将其剪除。
  • 子树剪枝:如果某个子树 \(T' \subset T\) 中的所有叶子都被剪除,那么整个子树 \(T'\) 及其原型都可以被移除,并进一步简化树结构(移除变得多余的父节点)。该步骤效果如下图所示,它可以显著减少模型中的原型数量,使其更紧凑、更易于理解。
    image

原型可视化

为了使学习到的原型能够被人理解,必须将潜在空间中的原型张量 \(p_n\) 映射回可理解的图像块。类似于 ProtoPNet,在训练结束后执行原型替换操作,将每个潜在原型 \(p_n\) 替换为整个训练集中与其最相似的潜在图像块 \(\tilde{z}_n^*\)

\[p_n \leftarrow \tilde{z}_n^*, \quad \tilde{z}_n^* = \underset{z \in \{f(x), \forall x \in \mathcal{T}\}}{\operatorname{argmin}} ||\tilde{z}^* - p_n|| \]

与 ProtoPNet 在训练期间每 10 个周期进行替换不同,因为 ProtoTree 的路由机制已隐式地将原型优化到接近某个训练图像块,因此 ProtoTree 在训练后替换即可。可视化过程如下:

  1. 找到产生最近潜在块 \(\tilde{z}_n^*\) 的训练图像 \(x_n^*\)
  2. \(x_n^*\) 再次输入网络,计算其潜在表示 \(z = f(x_n^*)\)
  3. 计算原型 \(p_n\)\(z\)每一个图像块的相似度,生成一个二维的相似度图 \(S_n\):$$S_{n}^{(i,j)} = \exp(-||\tilde{z}^{(i,j)} - p_n||)$$ 其中 \((i, j)\) 表示块在特征图上的位置。
  4. 将该相似度图通过双三次插值上采样到输入图像 \(x_n^*\) 的尺寸。
  5. \(x_n^*\) 上,于最近块 \(\tilde{z}_n^*\) 对应的位置,用一个矩形框标出原型所代表的图像区域。

image

确定性推理

在训练时,ProtoTree 是一个软决策树,所有节点和路径都对最终预测有贡献。因为人类更习惯清晰的、序列化的决策路径,ProtoTree 软决策有利于梯度传播,但不利于解释。因此在测试阶段,ProtoTree 可以转换为硬决策树,本文提供两种策略:

  1. 最大路径概率(Max \(\pi\):选择路径概率 \(\pi_{\ell}\) 最高的叶子节点的类别作为预测结果。
  2. 贪婪路径(Greedy):在树的每个内部节点 \(n\) 上,如果向右的概率 \(p_{e(n, n.right)} > 0.5\) 则向右走,否则向左走,从而形成一条唯一的决策路径。

实验结果

数据集和实验设置

实验在两个细粒度图像识别基准数据集上进行,分别是 CUB-200-2011 (CUB,包含 200 种鸟类和 11,788 张图像)和 Stanford Cars (CARS,包含 196 种汽车类型和 16,185 张图像)。
模型配置方面,骨干网络使用 ResNet50 的卷积层作为特征提取器 \(f\),CUB 数据集使用在 iNaturalist2017 上预训练的权重,CARS 使用 ImageNet 预训练权重。通过交叉验证选择原型深度,CUB 为 \(D=256\),CARS 为 \(D=128\)。原型尺寸为 \(H_1 = W_1 = 1\),与 ProtoPNet 保持一致。图像尺寸统一调整为 \(224\times 224\),保证公平比较。

对比实验

比较结果如下表所示,在 CUB 数据集上 ProtoTree 达到 82.2% 的准确率,显著优于 ProtoPNet 的 79.2%,同时原型数量减少约 90%。5 个ProtoTree 的集成在 CUB 上达到 87.2% 准确率,在 CARS 上达到 91.5%,接近甚至超越非可解释的 state-of-the-art 方法。
image

树的高度影响

树高度影响分析表明,存在一个最优树高度范围。当树高度设置使得叶子节点数不少于类别数时(\(2^h \geq K\)),模型能够获得最佳性能。下图显示了随着高度增加 Prototree 的准确性,它证实了设置初始高度 h 是明智的,这样叶的数量至少与类的数量 K 一样大。对于 CUB,准确性增加到一定高度(h = 9)之后准确性趋于稳定。
image

剪枝与原型替换的影响

剪枝效果如下表所示,可见剪枝后模型的预测准确率几乎无变化(CUB: 82.206% → 82.199%),CUB 数据集剪枝率达 60.5%,CARS 达 90.5%,模型复杂度大幅降低。
image

确定性推理策略评估

软决策与硬决策对比发现,最大路径概率策略的准确率为 82.19%,与软决策(82.20%)几乎相同,保真度达0.999。贪婪策略的准确率 82.07%,保真度0.987,性能轻微下降但仍在可接受范围。平均路径长度 8.3 步,最大仅需追踪 9 个原型即可完成解释。这表明 ProtoTree 可以安全转换为确定性树,极大提升单样本预测的可解释性,同时保持高准确性。
image

可视化分析

下图显示了在 CUB 上训练的 Prototree 的一个片段和一个本地解释。分析表明,大多数学习到的原型具有明确的感知意义,能够成功聚类外观相似的类别,模型自动发现了有意义的视觉特征组合用于区分细粒度类别。ProtoTree 能够同时暴露数据集中存在的偏差,例如某些鸟类与特定背景(如树叶、天空)的虚假相关性,这为模型审计和偏差修正提供了宝贵洞察。
image

优点和创新点

个人认为,本文有如下一些优点和创新点可供参考学习:

  1. 提出的神经原型树 ProtoTree 将卷积神经网络的特征提取能力、原型的可解释性以及决策树的层次化推理结构结合,实现了全局和局部均透明的模型设计;
  2. 通过树结构的路径共享机制,ProtoTree 将所需原型数量减少了约90%,降低了模型的解释复杂度,并取得了优于原型模型 ProtoPNet 的准确率;
  3. 针对叶子节点分布参数优化问题,创新性地将小批量梯度下降与无导数更新策略交织进行,在保证模型收敛的前提下,将训练效率提升近一倍;
  4. 通过引入基于判别力的剪枝、训练后原型可视化替换以及软决策到硬决策的转换策略,确保了最终模型既紧凑易于理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/960043.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025.11 NOIP 集训模拟赛选记

希望能学到点儿东西吧。 Public NOIP Round #8 【NOIP Round #8】位集 【NOIP Round #8】偷塔 【NOIP Round #8】降雨 【NOIP Round #8】矩阵 Public NOIP Round #7 【NOIP Round #7】填写数字 【NOIP Round #7】排列计…

从指令遵循到价值对齐:医疗大语言模型的进阶优化、对齐与工具集成综合技能白皮书

从指令遵循到价值对齐:医疗大语言模型的进阶优化、对齐与工具集成综合技能白皮书pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font…

Mrakdown - YQR

Mrakdown标题# **粗体 *斜体 ***粗斜 ~~删除引用 ---或***分割 ----图 超链接列表 1,代码

20232322 2025-2026-1 《网络与系统攻防技术》实验四实验报告

一.实验内容恶意代码的文件类型识别,脱壳与字符串提取。 使用IDA Pro静态或动态分析所给exe文件,找到输出成功信息的方法。 分析自制恶意代码样本并撰写报告。 取证分析实践。二.实验目的掌握恶意代码的分析技术,像…

高级语言程序设计第四节个人作业

这个作业属于哪个课程:https://edu.cnblogs.com/campus/fzu/gjyycx 这个作业要求在哪里: https://edu.cnblogs.com/campus/fzu/gjyycx/homework/14577 学号:102500426 姓名:康凯帆 Fan.: 11-08 22:34:19Fan.: 11-0…

Vue3 项目首屏加载性能优化全攻略 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025.11.8 测试

2025.11.8 测试T1 刚开场想到贪心方向,然后一直在证明 然后证了1个多小时,证明没证出来,但思路理清了 就是有一种情况,就是最小值和某个非最大值搭配,不会写,但举了几个例子都不会出现这种情况成为最优解 然后就…

C# 变量详解:从基础概念到高级应用 - 实践

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

CF285G AGC003D

给定 \(n, k\),问有多少个长度为 \(n\) 的排列 \(p\),满足恰好有 \(k\) 个 \(i\) 使得 \(|p_i - i| = 1\)(称这个 \(i\) 为好的)。 \(k \le n \le 1000\)令 \(g(k)\) 表示恰好有 \(k\) 个好的 \(i\) 的排列数。 这…

用 Kubernetes 原生机制取代 Nacos 注册中心:可行性、代价与边界

我在使用k8s部署 Java 分布式应用时发现,k8s自带服务发现功能,而且K8s提供的Service、DNS、ConfigMap 等级制似乎能完全替代Nacos的注册中心和配置中心。 Kubernetes 的 Service + Endpoints + CoreDNS 机制,本质上…

获取设置开发授权激活统信uos

获取设置开发授权激活统信uos申请开发授权

AtCoder Beginner Contest 431 ABCDEF 题目解析

A - Robot Balance 题意 一个机器人由一个头部零件与一个身体零件组成。如果头部零件的重量大于身体零件的重量,机器人就会摔倒。 目前,高桥有一个头部零件和一个身体零件,头部零件的重量为 \(H\) 克,身体零件的重…

基于单片机的智能洗碗机设计 - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

实用指南:AI学习日记——深度学习

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

赫尔曼黑塞《德米安》—生活之难,难在直面内心的自己

《德米安》开篇的第一句话: 我所渴望的, 无非是试着依我内心自发的本性去生活。为何如此之难?生活的难,似乎是刻在人生里的底色。生老病死的必然,悲欢起落的无常,得到时的辗转,失去时的拉扯。我们总轻易遗忘快乐…

安装openjdk21

安装openjdk211、打开应用商店,搜索openjdk,搜索结果列出多个版本的openjdk,如openjdk8、openjdk19、openjdk21等。 2、可以点击对应图标,进入详细信息查看版本,并进行安装。 3、安装后打开,如打开openjdk(长期维护…

中科麒麟passwd弱密码授权

中科麒麟桌面版默认拒绝“123456”这类弱密码,报错 “无效的密码:没有足够的字符种类”。 下面把亲测可行的修改步骤贴出来,复制-粘贴即可。1. 打开密码策略文件 sudo nano /etc/pam.d/common-password2. 定位到 pa…

暴字迹

都是平常笔记一类的字迹所以写的很潦草( 宣:CSP 2025 游记:https://www.luogu.com.cn/article/fz1ol19h CSP 2025 GD 迷惑行为大赏:https://www.luogu.com.cn/article/dihhq10t

体验CodeBuddy免费领取轻量云服务器

近期 AI 编程热潮席卷行业,各大科技厂商纷纷布局 AI IDE 赛道,推出专属开发平台。 腾讯也顺势入局,正式发布自研 AI IDE 工具 CodeBuddy。依托腾讯完善的产品生态,CodeBuddy 带来了一大核心亮点功能 ——“一句话落…

Git 命令完全手册

Git 命令完全手册 目录Git 基础配置 仓库操作 核心常用命令 分支操作 远程协作 查看信息 撤销与回退 标签管理 高级操作 故障排查1. Git 基础配置 # 查看配置 git config --list git config --global --list# 设置用户…