深度学习论文阅读之【Distilling the Knowledge in a Neural Network】提炼神经网络中的知识

论文:link
代码:link

摘要

  提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测非常麻烦,并且计算成本可能太高,无法部署到大量用户,尤其是在单个模型是大型神经网络的情况下。 Caruana 和他的合作者 [1] 已经证明,可以将集成中的知识压缩到单个模型中,该模型更容易部署,并且我们使用不同的压缩技术进一步开发了这种方法。我们在 MNIST 上取得了一些令人惊讶的结果,并且表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进频繁使用的商业系统的声学模型。我们还引入了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类别。与专家的混合不同,这些专业模型可以快速并行地进行训练。

1.Introduction

  许多昆虫都有幼虫形态和完全不同的成虫形态,幼虫形态可以经过优化,可以从环境中获取能量和营养,成虫形态可以满足不同的旅行和繁殖要求,在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同,对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它并不需要这样做,需要实时操作,并且会使用大量的计算量。
  然而,部署到大量用户对延迟和计算资源有更严格的要求。与昆虫的类比表明,如果可以更轻松地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可能是单独训练的模型的集合,也可能是使用非常强大的正则化器(例如 dropout)训练的单个非常大的模型[9]。一旦繁琐的模型经过训练,我们就可以使用不同类型的训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型。 Rich Caruana 及其合作者已经率先提出了该策略的一个版本 [1]。在他们的重要论文中,他们令人信服地证明,通过大型模型集合获得的知识可以转移到单个小型模型中。通常认为模型学习到的参数代表了知识,无法直接迁移,但教师网络预测结果中各类别概率的相对大小也隐式包含知识。

2.Distillation

  神经网络通常通过使用“softmax”输出层来生成类别概率,该输出层通过将 z i z_i zi与其他logits进行比较。将为每一个类别计算的logits的 z i z_i zi转换为概率 p i p_i pi.
q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) {q_i} = \frac{{\exp \left( {{z_i}/T} \right)}}{{{\sum _j}\exp \left( {{z_j}/T} \right)}} qi=jexp(zj/T)exp(zi/T)
其中 T 是温度,通常设置为 1。使用较高的 T 值会在类别上产生较软的概率分布。在最简单的蒸馏形式中,知识被转移到蒸馏模型中,方法是在转移集上进行训练,并使用转移集中每种情况的软目标分布,该软目标分布是通过使用其 softmax 中温度较高的繁琐模型生成的。训练蒸馏模型时使用相同的高温,但训练后它使用温度 1。
一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,并且该交叉熵是使用蒸馏模型的 softmax 中与用于从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是具有正确标签的交叉熵。这是在蒸馏模型的 softmax 中使用完全相同的对数计算的,但温度为 1。我们发现,通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。由于软目标产生的梯度大小为 1/T 2 ,因此在使用硬目标和软目标时将其乘以 T 2 非常重要。这确保了如果在元参数实验时用于蒸馏的温度发生变化,硬目标和软目标的相对贡献保持大致不变。

2.1 匹配logits是蒸馏的一个特例

  传输集中每个案例都贡献一个交叉熵梯度 d C / d z i dC/d{z_i} dC/dzi,相当于蒸馏模型的每个logit z i z_i zi,并且繁琐的模型具有产生软目标概率 p i p_i pi的logits v i v_i vi,并且转移训练是在温度T下完成的,则该梯度由下式给出:
∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \frac{{\partial C}}{{\partial {z_i}}} = \frac{1}{T}\left( {{q_i} - {p_i}} \right) = \frac{1}{T}\left( {\frac{{{e^{{z_i}/T}}}}{{{\sum _j}{e^{{z_j}/T}}}} - \frac{{{e^{{v_i}/T}}}}{{{\sum _j}{e^{{v_j}/T}}}}} \right) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)
如果温度与logits的大小相比比较高,我们可以近似:
∂ C ∂ z i ≈ 1 T ( 1 + e z i / T N + ∑ j z j / T − 1 + e v i / T N + ∑ j v j / T ) \frac{{\partial C}}{{\partial {z_i}}} \approx \frac{1}{T}\left( {\frac{{1 + {e^{{z_i}/T}}}}{{N + {\sum _j}{z_j}/T}} - \frac{{1 + {e^{{v_i}/T}}}}{{N + {\sum _j}{v_j}/T}}} \right) ziCT1(N+jzj/T1+ezi/TN+jvj/T1+evi/T)
如果我们现在假设每个转移情况的logits都是零均值的,则 ∑ j z j = ∑ j v j = 0 {\sum _j}{z_j} = \sum {}_j{v_j} = 0 jzj=jvj=0,原式可简化为:
∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{{\partial C}}{{\partial {z_i}}} \approx \frac{1}{{N{T^2}}}\left( {{z_i} - {v_i}} \right) ziCNT21(zivi)
  因此,在高温极限下,蒸馏相当于最小化 1 / 2 ( z i − v i ) 2 1/2(z_i-v_i)^2 1/2(zivi)2 ,前提是每个分动箱的 logits 分别为零均值。在较低的温度下,蒸馏很少关注比平均值负得多的匹配 logits。这是潜在的优势,因为这些逻辑几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常负的逻辑可能会传达有关通过繁琐模型获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负对数可能会有所帮助。

3.MNIST初步实验

  为了了解蒸馏的效果如何,我们在所有 60,000 个训练案例上训练了一个大型神经网络,该神经网络具有两个隐藏层,每个隐藏层包含 1200 个校正线性隐藏单元。该网络使用 dropout 和权重约束进行了强烈正则化,如 [5] 中所述。 Dropout 可以被视为训练共享权重的指数级大模型集合的一种方式。此外,输入图像在任何方向上抖动最多两个像素。该网络出现了 67 个测试错误,而具有两个隐藏层(由 800 个校正线性隐藏单元且无正则化)的较小网络出现了 146 个错误。但是,如果仅通过添加在 20 ℃ 的温度下匹配大网络产生的软目标的附加任务来对较小的网络进行正则化,则它会出现 74 个测试错误。这表明软目标可以将大量知识转移到蒸馏模型中,包括如何概括从翻译的训练数据中学到的知识,即使转移集不包含任何翻译。当蒸馏网络的两个隐藏层中每个都有 300 个或更多单位时,所有高于 8 的温度都会给出相当相似的结果。但当这从根本上减少到每层 30 个单位时,2.5 至 4 范围内的温度明显优于更高或更低的温度。然后,我们尝试从传输集中省略数字 3 的所有示例。所以从蒸馏模型的角度来看,3是一个它从未见过的神话数字。尽管如此,蒸馏模型仅出现 206 个测试错误,其中 133 个位于测试集中的 1010 个三元组上。大多数错误是由于第 3 类的学习偏差太低而引起的。如果此偏差增加 3.5(这会优化测试集的整体性能),则蒸馏模型会出现 109 个错误,其中 14 个错误位于 3 上。因此,在正确的偏差下,尽管在训练期间从未见过 3,但蒸馏模型在测试 3 中的正确率达到 98.6%。如果传输集仅包含训练集中的 7 和 8,则蒸馏模型的测试误差为 47.3%,但当 7 和 8 的偏差减少 7.6 以优化测试性能时,测试误差将降至 13.2%。

discussion

  我们已经证明,蒸馏对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型非常有效。在 MNIST 上,即使用于训练蒸馏模型的传输集缺少一个或多个类的任何示例,蒸馏也能表现得非常好。对于 Android 语音搜索所使用的深度声学模型版本,我们已经证明,通过训练深度神经网络集合所实现的几乎所有改进都可以被提炼为相同大小的单个神经网络,部署起来要容易得多。对于非常大的神经网络,甚至训练一个完整的集合也是不可行的,但是我们已经证明,经过很长时间训练的单个非常大的网络的性能可以通过学习大量的专家来显着提高网络,每个网络都学会区分高度混乱的集群中的类别。我们还没有证明我们可以将专家的知识提炼回单一的大网络中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/777781.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML文本信息

标题 使用h1~h6标签定义标题。通常一个HTML网页只有一个主标题和副标题&#xff0c;主标题和副标题分别使用h1和h2表示。 <h1>主标题</h1> <h2>副标题</h2><p>正文</p>段落 p元素用来表示段落文本。通常用来显示大片的文字。每一个p元素…

中国信通院 X StarRocks金融用户社区正式成立

在国家战略的推动下&#xff0c;开源技术正逐渐成为金融行业创新发展的重要驱动力。2024 年 3 月 26 日&#xff0c;中国信息通信研究院 X StarRocks 金融用户社区&#xff08;以下简称“社区”&#xff09;正式成立&#xff0c;这一举措旨在深化国内金融领域的开源生态建设&am…

粗略总结AI大模型学习需要了解的要点

目录 一、概念简介 二、兴起原因 三、相关要点 四、不足之处 五、总结 一、概念简介 AI大模型学习是指利用大规模数据集和强大计算能力进行深度学习模型的训练。随着数据的爆炸式增长和计算资源的提升&#xff0c;AI大模型学习成为了现代人工智能研究的重要方向。 二、兴起…

单元测试11213123231313131231231231

使用技术 junit Mockito s[romg 示例代码&#xff1a; SpringBootTest(classes启动类.class) public class AbstractTes{ MockBean protected A a; } AutoConfigureMockMvc(printOnlyOnFailure false) public abstract class AbstractWebTes extends AbstractTes imple…

使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型&#xff0c;来了解GAN的基本原理。本片为此系列的第一篇&#xff0c;实现起来很简单&#xff0c;所以不要期待有很好的效果出来。 第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN&#xff0c;使用MINIST数据…

精准测试——BCEL字节码检测

精准测试是通过源代码变更分析&#xff0c;确定改动代码影响的范围&#xff0c;从而进行针对性测试&#xff0c;进一步提升测试效率。不仅如此&#xff0c;精准测试还可以将测试用例与程序代码之间的逻辑映射关系建立起来&#xff0c;采集测试过程执行的代码逻辑及测试数据。怎…

Android--重构

重构不是一朝一夕的事情&#xff0c;是一个持续的过程 要注重代码注释&#xff0c;对创建的每一个页面&#xff0c;类&#xff0c;方法&#xff0c;关键变量都要有对应的注释&#xff0c;对于类要写明作者是谁&#xff0c;创建修改时间&#xff0c;还有是做什么。 这样对后面的…

入门指南|营销中人工智能生成内容的主要类型 [新数据、示例和技巧]

由于人工智能技术的进步&#xff0c;内容生成不再是一项令人头疼的任务。随着人工智能越来越多地接管手动内容制作任务&#xff0c;营销人员明智的做法是了解现有的不同类型的人工智能生成内容&#xff0c;以及哪些内容从中受益最多。这些工具可以帮助我们制作对您的受众和品牌…

Synchronized锁、公平锁、悲观锁乐观锁、死锁等

悲观锁 认为自己在使用数据的时候一定会有别的线程来修改数据,所以在获取数据前会加锁,确保不会有别的线程来修改 如: Synchronized和Lock锁 适合写操作多的场景 乐观锁 适合读操作多的场景 总结: 线程8锁🔐 调用 声明 结果:先打印发送短信,后打印发送邮件 结论…

【WPF应用16】WPF如何让Canvas上的元素响应鼠标点击事件?

在WPF中&#xff0c;要让Canvas上的元素响应鼠标点击事件&#xff0c;你需要为这些元素添加事件处理程序来处理MouseLeftButtonDown事件。这个事件会在鼠标左键被按下时触发。下面是一篇详细的博客&#xff0c;展示了如何在Canvas上的元素上添加鼠标点击事件处理程序。 1. Can…

AI大模型学习和实践

目录 第一章:AI大模型概述 1.1 什么是AI大模型? 1.2 AI大模型的发展历程 1.3 AI大模型的应用领域 1.4 AI大模型的挑战与机遇 第二章:数学基础与模型理论 2.1 数学在AI大模型学习中的重要性 2.1.1 线性代数 2.2.2 微积分 2.2.3 概率论与统计学 2.2、模型理论的基础…

机器学习(三)

神经网络: 神经网络是由具有适应性的简单单元组成的广泛并行互连的网络&#xff0c;它的组织能够模拟生物神经系统对真实世界物体所作出的交互反应。 f为激活(响应)函数: 理想激活函数是阶跃函数&#xff0c;0表示抑制神经元而1表示激活神经元。 多层前馈网络结构: BP(误差逆…

OpenPLC_Editor 在Ubuntu 虚拟机安装记录

1. OpenPLC_Editor在虚拟机上费劲的装了一遍&#xff0c;有些东西已经忘了&#xff0c;主要还是python3 的缺失库版本对应问题&#xff0c;OpenPLC_Editor使用python3编译的&#xff0c;虚拟机的Ubuntu 18.4 有2.7和3.6两个版本&#xff0c;所以需要注意。 2. OpenPLC_Editor …

Svg Flow Editor 原生svg流程图编辑器(四)

系列文章 Svg Flow Editor 原生svg流程图编辑器&#xff08;一&#xff09; Svg Flow Editor 原生svg流程图编辑器&#xff08;二&#xff09; Svg Flow Editor 原生svg流程图编辑器&#xff08;三&#xff09; Svg Flow Editor 原生svg流程图编辑器&#xff08;四&#xf…

Mac命令行查找SDK/JDK安装位置

要在命令行中查询 Android SDK Platform Tools 的安装位置,可以使用以下步骤: 使用 which 命令: 在命令行中执行以下命令: which adb这将输出 adb 命令的安装路径,通常情况下,它会在 Android SDK 的 platform-tools 目录下。 手动查找: 如果 which adb 没有输出,可以手…

unity中判断方向 用 KeyVertical ,KeyHorizontal 判断ui物体的 方向

float KeyVertical Input.GetAxis("Vertical"); float KeyHorizontal Input.GetAxis("Horizontal"); // 假设 UI 物体在竖直方向上为 Y 轴&#xff0c;水平方向上为 X 轴 Vector2 direction new Vector2(KeyHorizontal, KeyVertical); if (direction…

贪心算法--最大数

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 本题链接https://leetcode.cn/problems/largest-number/description/ class Solution { public:bool static compare(int a, int b){return (to_string(a) to_string(b)) > (to_string(b) to_string(a));}bool operato…

幽默记忆TCP/UDP/DNS/三次握手

三次握手 把客户端和服务端比作两个小孩想象一下&#xff0c;你正在和朋友一起玩“猜拳”游戏&#xff0c;但是你们之间的通信线路不够稳定&#xff0c;为了确保游戏开始前大家都准备好了&#xff0c;你们进行了这样一段对话&#xff1a; 第一次握手&#xff1a;你对朋友说&am…

探索 2024 年 Web 开发最佳前端框架

前端框架通过简化和结构化的网站开发过程改变了 Web 开发人员设计和实现用户界面的方法。随着 Web 应用程序变得越来越复杂&#xff0c;交互和动画功能越来越多&#xff0c;这是开发前端框架的初衷之一。 在网络的早期&#xff0c;网页相当简单。它们主要以静态 HTML 为特色&a…

数据库---PDO

以pikachu数据库为例&#xff0c;数据库名&#xff1a; pikachu 1.连接数据库 <?php $dsn mysql:hostlocalhost; port3306; dbnamepikachu; // 这里的空格比较敏感 $username root; $password root; try { $pdo new PDO($dsn, $username, $password); var_dump($pdo)…