ICLR盲审阶段就被评审赞不绝口的论文:会是Transformer架构的一大创新吗?

96b3accdc71199a4fbc5b583b309e75d.jpeg

编|杜伟、陈萍
源|机器之心

首次!无残差连接或归一化层,也能成功训练深度transformer。

尽管取得了很多显著的成就,但训练深度神经网络(DNN)的实践进展在很大程度上独立于理论依据。大多数成功的现代 DNN 依赖残差连接和归一化层的特定排列,但如何在新架构中使用这些组件的一般原则仍然未知,并且它们在现有架构中的作用也依然未能完全搞清楚。

残差架构是最流行和成功的,最初是在卷积神经网络(CNN)的背景下开发的,后来自注意力网络中产生了无处不在的 transformer 架构。残差架构之所以取得成功,一种原因是与普通 DNN 相比具有更好的信号传播能力,其中信号传播指的是几何信息通过 DNN 层的传输,并由内核函数表示。

最近,使用信号传播原则来训练更深度的 DNN 并且残差架构中没有残差连接和 / 或归一化层的参与,成为了社区感兴趣的领域。原因有两个:首先验证了残差架构有效性的信号传播假设,从而阐明对 DNN 可解释性的理解;其次这可能会实现超越残差范式的 DNN 可训练性的一般原则和方法。

对于 CNN,Xiao et al. (2018)的工作表明,通过更好初始化提升的信号传播能够高效地训练普通深度网络,尽管与残差网络比速度显著降低。Martens et al. (2021) 的工作提出了 Deep Kernel Shaping (DKS),使用激活函数转换来控制信号传播,使用 K-FAC 等强二阶优化器在 ImageNet 上实现了普通网络和残差网络的训练速度相等。Zhang et al. (2022) 的工作将 DKS 扩展到了更大类的激活函数,在泛化方面也实现了接近相等。

信号传播中需要分析的关键量是 DNN 的初始化时间内核,或者更准确地说,是无限宽度限制下的近似内核。对于多层感知机(MLP)以及使用 Delta 初始化的 CNN,该内核可以编写为仅包含 2D 函数的简单层递归,以便于进行直接分析。跨层 transformer 的内核演化更加复杂,因此 DKS 等现有方法不适用 transformer 或实际上任何包含自注意力层的架构。

在 MLP 中,信号传播是通过查看(一维)内核的行为来判断的,而 transformer 中的信号传播可以通过查看(高维)内核矩阵在网络层中的演化来判断。

该研究必须避免一种情况:对角线元素随深度增加快速增长或收缩,这与不受控制的激活范数有关,可能导致饱和损失或数值问题。避免秩崩溃(rank collapse)对于深度 transformer 的可训练性是必要的,而是否可以训练深度无残差 transformer 仍是一个悬而未决的问题。

ICLR 2023 盲审阶段的这篇论文解决了这个问题,首次证明了无需残差连接或归一化层时也可能成功训练深度 transformer。为此,他们研究了深度无残差 transformer 中的信号传播和秩崩溃问题,并推导出三种方法来阻止它们。具体而言,方法中使用了以下组合:参数初始化、偏置矩阵和位置相关的重缩放,并强调了 transformer 中信号传播特有的几种复杂性,包括与位置编码和因果掩蔽的交互。研究者实证证明了他们的方法可以生成可训练的深度无残差 transformer。

在实验部分,在 WikiText-103 和 C4 数据集上,研究者展示了使用他们主要的方法——指数信号保持注意力(Exponential Signal Preserving Attention, E-SPA),可以通过延长大约五倍的训练时间使得标准 transformer 与文中无残差 transformer 的训练损失相当。此外通过将这一方法与残差连接结合,研究者还表明无归一化层的 transformer 能够实现与标准 transformer 相当的训练速度。

5d8bd4e5a499dd0c75f4b15152b54ffa.png

论文地址:https://openreview.net/pdf?id=NPrsUQgMjKK

对于这篇论文,Google AI 首席工程师 Rohan Anil 认为是 Transformer 架构向前迈出的一大步,还是一个基础性的改进。

d22fadec19cb3aa6f8ce00e349f928e7.png

构造无捷径可训练的深层 Transformer

迄今为止,纠正 Transformer 秩崩溃(rank collapse)的唯一策略依赖于残差连接,该方式跳过了自注意力层固有的可训练性问题。与此相反,该研究直接解决这个问题。首先通过注意力层更好地理解信号传播,然后根据见解(insights)进行修改,以在深度 transformer 中实现对忠实信号的传输,无论是否使用残差连接,都可以对信号进行训练。

具体而言,首先,该研究对仅存在注意力的深度 vanilla transformer 进行了一下简单设置,之后他们假设该 transformer 具有单一头(h = 1)设置或具有多头设置,其中注意力矩阵 A 在不同头之间不会变化。如果块 l≤L 初始化时有注意力矩阵 A_l,则最终块的表示形式为 X_L:

3a84e1ec9b0835088a1b1a13aa848ac4.png

对于上式而言,如果和采用正交初始化,那么就可以在初始化时正交。

在上述假设下,如果采用表示跨位置输入核矩阵,经过一些简化处理后,可以得到如下公式:

e1a763f1e9c513d7f40804e73e36fe2c.png

从这个简化公式(深度仅注意力 transformer 中的核矩阵)中,可以确定对 (A_l)_l 的三个要求:

  1. 必须在每个块中表现良好,避免退化情况,如秩崩溃和爆炸 / 消失的对角线值;

  2. A_l 必须是元素非负 ∀l;

  3. A_l 应该是下三角∀l,以便与因果掩码注意力兼容。

在接下来的 3.1 和 3.2 节中,该研究专注于寻找满足上述需求的注意力矩阵,他们提出了 3 种方法 E-SPA、U-SPA 和 Value-Skipinit,每种方法都用来控制 transformer 的注意力矩阵,即使在很深的深度也能实现忠实的信号传播。此外,3.3 节演示了如何修改 softmax 注意力以实现这些注意力矩阵。

下图中,该研究对提出的两个 SPA 方案进行了验证,U-SPA 和 E-SPA,结果显示即使在网络较深时也能成功地避免仅注意力 vanilla transformers 中的秩崩溃现象。

a7077fbc5acd7cb47e8f1dce750696b5.png

实验

WikiText-103 基线:首先,该研究验证了没有残差连接的标准深度 transformer 是不可训练的,即使它们有归一化层 (LN) 和 transformed 激活,但本文的方法可以解决这个问题。如图 2 所示,可以清楚地看到,从标准 transformer 中移除残差连接使其不可训练,训练损失稳定在 7.5 左右。正如图 1 所示,标准 transformer 遭受了秩崩溃。

16d9c8e7f77bf167aef42bdfe3892807.png

另一方面,该研究提出的 E-SPA 方法优于 U-SPA 和 Value-Skipinit。然而,与本文无残差方法相比,带有残差和 LN 的默认 transformer 仍然保持训练速度优势。

在表 1 中,该研究使用提出的方法评估了 MLP 块中不同激活函数的影响,以及 LN 在无残差 transformer 的使用。可以看到在深度为 36 处,本文方法针对一系列激活实现了良好的训练性能:DKS-transformed GeLU、TAT-transformed Leaky ReLU 以及 untransformed GeLU ,但不是 untransformed Sigmoid。通过实验还看到,层归一化对于训练速度而言相对不重要,甚至在使用 SPA 时对 transformed activation 的激活有害,因为 SPA 已经具有控制激活规范的内置机制。

10afdbfa51f6408a1b4448660b83377a.png

在图 3 中,我们看到一种不需要更多迭代就能匹配默认 transformer 训练损失的方法是使用归一化残差连接。

12fe90103194ef6be3b12aea25324cb0.png

表 2 显示带有归一化残差和 LN 的 E-SPA 优于默认的 PreLN transformer。

c5bb2b5efa688932b84e0be0f0b75443.png

下图 4(a)表明 E-SPA 再次优于其他方法;4(b)表明训练损失差距可以通过简单地增加训练时间来消除。

0b298fe74ec213f89005aba90f5251b0.png

d6f805a806a87d1353e287c245971563.jpeg后台回复关键词【入群

加入卖萌屋NLP、CV、搜推广与求职讨论群

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/476570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python自带的shell、其性能优于ipython_python3.4 shell

实验1 目的和要求(1) (2) (3) (4) (5) (6) (7) 开始 python 编程 了解什么是 python? 了解 python 的特性 学习下载和安装 python 学习执行 python ...... (2) 了解 python 的特性 (3) 学习下载和安装 python (4) 学习执行 python 命令和脚本文件的方法 (5) 学习 python 语音的…

程序员面试金典 - 面试题 08.13. 堆箱子(DP)

1. 题目 堆箱子。给你一堆n个箱子,箱子宽 wi、深 di、高 hi。 箱子不能翻转,将箱子堆起来时,下面箱子的宽度、高度和深度必须大于上面的箱子。 实现一种方法,搭出最高的一堆箱子。箱堆的高度为每个箱子高度的总和。 输入使用数组…

一篇文章带你了解国企程序员(超详细)

源|Hollis、国企程序锅最近互联网大厂裁员不断,今天为大家带来了一篇超详细国企程序员攻略。文章目录入职国企心得体会一、入职前二、入职后三、工作开发内容四、钱总结:北京户口相关问题一、北京户口咋获得?二、北京户口有啥用&a…

Delphi作为客户端调用.Net写的WCF服务端?

这方面的文章太少了,查了半天也只看到一两篇,关键点 1.wcf的Binding要配成 basicHttpBinding,否则Delphi无法通过WebService的方式调用服务 2.Delphi IDE中的导入WSDL和安装目录下的wsdlimp.exe都无法正确识别WCF消息,需要到网上下…

html表单代码例子_关于React的这些细节,你知道吗?-表单

在 React 里,HTML 表单元素的工作方式和其他的 DOM 元素有些不同,这是因为表单元素通常会保持一些内部的 state。例如这个纯 HTML 表单只接受一个名称:此表单具有默认的 HTML 表单行为,即在用户提交表单后浏览到新页面。如果你在 …

程序员面试金典 - 面试题 17.23. 最大黑方阵(DP)

1. 题目 给定一个方阵,其中每个单元(像素)非黑即白。 设计一个算法,找出 4 条边皆为黑色像素的最大子方阵。 返回一个数组 [r, c, size] ,其中 r, c 分别代表子方阵左上角的行号和列号,size 是子方阵的边长。 若有多个满足条件的…

js的oop方式和this指针问题

js的oop其实很简单,function本身就充当了类和构造函数的角色。然后通过传给构造函数的参数,完成类属性的赋值,从而实际化不同的对象。可是,js的oop也有很让人头疼的地方,其中之一就是this的指向。在js中,普…

史诗级学术骗局!一博士狂编 200 多篇论文,被揭发后畏罪自杀....

源|募格学术一博士狂编200多篇论文,被揭发后畏罪自杀,可他造成的撤稿影响直至今日还在继续,更有人称其的造假为科学史上最大的学术骗局之一。狂编200多篇论文发表,这个博士有点狠在著名学术打假网站 Retraction Watch …

docker安装_Docker安装

简介:Docker是一个供开发人员和系统管理员通过容器的方式构建、运行和共享应用程序的平台,通过容器的方式部署应用(打包成标准化单元,类似于一个集装箱),具有安全、灵活、轻量、松耦合、可移植、可扩展等特点。概念:仓…

LeetCode 1139. 最大的以 1 为边界的正方形(DP)

1. 题目 给你一个由若干 0 和 1 组成的二维网格 grid,请你找出边界全部由 1 组成的最大 正方形 子网格,并返回该子网格中的元素数量。如果不存在,则返回 0。 示例 1: 输入:grid [[1,1,1],[1,0,1],[1,1,1]] 输出&…

大幅超越DALL·E 2和Imagen,斯坦福发布RA-CM3模型,融合检索与生成

文|QvQ最近,DALL-E和CM3等模型在多模态任务尤其是图文理解上表现出色。然而,这些模型似乎需要将所有学到的知识存储都存储在模型参数中,这就不得不需要越来越大的模型和训练数据来获取更多的知识,俨然将bigger and bet…

什么是域名服务器(DNS)

问题:什么是域名服务器?域名服务器是什么意思? 域名服务器即DNS,全称是Domain Name Server,一种程序,它保存了一张域名(domain name)和与之相对应的IP地址 (IP address)的表,以解析消息的域名。…

python判断正负的函数_Python |在计算操作的函数内将负数转换为正数?

我一直在寻找将负数转换为正数,我发现了一些东西,但没有成功.. 这是一个来自在线Python页面的练习,我正在学习Python。 我希望你明白这一点。 这是去洛杉矶旅行,我用功能计算钱,但现在有一个问题,我“从洛杉…

LeetCode 1325. 删除给定值的叶子节点(递归)

1. 题目 给你一棵以 root 为根的二叉树和一个整数 target ,请你删除所有值为 target 的 叶子节点 。 注意,一旦删除值为 target 的叶子节点,它的父节点就可能变成叶子节点; 如果新叶子节点的值恰好也是 target ,那么…

[翻译] python Tutorial 之一

声明:本文做为IronPython-2.0 B3的Tutorial 中文译文,内容全部来自其英文原文,其中本人认为存在疑问的或翻译不当之处会用原文中的内容加以标记,且本文内容完全用于研 究和学习IronPython 之用,限于本人英文翻译功底有…

用python控制钉钉软件_Python—实现钉钉后台开发

二、实现钉钉免登流程 免登流程分四步:1、前端获取钉钉免登授权码code;2、后端获取access_token;3、使用授权码code和access_token换取用户userid;4、通过access_token和userid换取用户详情userinfo。 前端获取授权码code。// 获取…

LeetCode 1123. 最深叶节点的最近公共祖先(递归比较子树高度)

1. 题目 给你一个有根节点的二叉树,找到它最深的叶节点的最近公共祖先。 回想一下: 叶节点 是二叉树中没有子节点的节点树的根节点的 深度 为 0,如果某一节点的深度为 d,那它的子节点的深度就是 d1如果我们假定 A 是一组节点 S…

万字综述:目标检测模型YOLOv1-v7深度解析

文|Rocky Ding源|WeThinkln大家好,我是Rocky。近年来YOLO系列层出不穷,更新不断,已经到v7版本。Rocky认为不能简单用版本高低来评判一个系列的效果好坏,YOLOv1-v7不同版本各有特色,在不同场景&a…

python手枪_Python入门,爬虫训练——枪械查询

一、效果图:二、怎么做到的? 1,首先安装requests、bs4. 这两个第三方模块。 我们按住winR 在弹出来的窗口上输入cmd,来到命令窗口,输入pip install requests、pip install bs4即可,网速慢的可以切换至国内源…

LeetCode 865. 具有所有最深结点的最小子树(递归)

1. 题目 给定一个根为 root 的二叉树,每个结点的深度是它到根的最短距离。 如果一个结点在整个树的任意结点之间具有最大的深度,则该结点是最深的。 一个结点的子树是该结点加上它的所有后代的集合。 返回能满足“以该结点为根的子树中包含所有最深的…