Pix2Seq:谷歌大脑提出 CV 任务统一接口!

edef627d82a868c2a83008ae38bfa084.png

文 | 青豆

最近一个大趋势就是将各类任务统一在一个大一统框架下。大规模预训练语言模型已成功打通各类文本任务,使得不同的NLP任务上,都可以用这种统一的sequence生成框架作为基础模型,只需要通过prompt的方式,指导模型生成目标结果。

这种大一统的sequence生成框架在NLP任务成功的关键是任务描述和任务输出都可以序列化成text tokens。

但CV任务输入输出都更加多样,那不是得为不同的任务定制不同的模型和损失函数?这也是CV任务大一统框架的瓶颈。

以自然语言为输出的任务,比如image captioning、visual question answering这类任务,天然可以转化为生成text token sequence。但模型的输出形式还存在很多其他的形式,例如bounding box、dense masks等。

Pix2Seq在这样的动机下诞生了:既然输出形式不同是难点,能否将各类输出形式都统一成token sequence?

去年Google Brain提出的Pix2Seq就以目标检测作为出发点,建立Pixel-to-Sequence的映射,探索了这种可能性(戳《图灵奖大佬+谷歌团队,为通用人工智能背书!CV 任务也能用 LM 建模!》)。

目前的Pix2Seq v2进一步统一了四个完全不同的视觉任务:目标检测(object detection)、实例分割(instance segmentation)、人体关键点检测(keypoint detection)、图像描述生成(image captioning),尽管他们的输出可以是bounding boxes,也可以是dense masks,都可以表示成token sequence。

这种离散的、统一化的表示,使得多种CV任务能够统一在一个模型架构或损失函数下。

对单个任务,不再需要对模型或损失函数做定制,而是只需要将任务描述放在prompt中,控制output sequence变成所需要的输出格式。

这种大一统的Pix2Seq框架,已经能够在这四个核心视觉任务上,媲美那些专门为各任务定制的state-of-the-art。

论文题目
A Unified Sequence Interface for Vision Tasks

论文链接:
https://arxiv.org/abs/2206.07669

94f552904e9f9665e404fb163d642439.png

436d6435c337630ed43e8fa743c870da.png背景介绍:4个视觉核心任务4d3693e7ba5a4e9426df33a7dae5ed7d.png

  • 目标检测(object detection):输入是一张图片,输出是所有object的bounding box和class label。

  • 实例分割(instance segmentation):输入是一张图片和其中的objects,输出是对每个object的dense pixel-wise mask。

  • 人体关键点检测(keypoint detection):输入是一张图片和其中的person objects,输出是keypoint坐标点来表示head、eyes等person instances。

  • 图像描述生成(image captioning):输入是一张图片,输出是一句话。

0ffca98b7debf94b322aed4b3fd38982.pngSequence建模四步走38df113e604c3b2223218c4a9fd650fd.png

要将CV任务统一建模成sequence生成,主要包括以下几步:

1. 统一输入输出:Tokenization序列化

输入是一张image;输出是一个离散的token sequence:task prompt + task output,其中task prompt用于描述具体任务(一般是任务指令+additional input tokens),task output是需要model生成的部分,是目标结果的序列化描述。例如对上述四个任务:

  • 目标检测(object detection):task prompt是detect指令,task output包括每个object的bounding box两个坐标点和object label。

  • 实例分割(instance segmentation):task prompt包括segment指令和给定的object instance,task output是segmentation多边形的坐标。

  • 人体关键点检测(keypoint detection):task prompt包括keypoint指令和给定的object instance,task output是一些keypoint坐标点。

  • 图像描述生成(image captioning):task prompt是Describe指令,task output是image caption sentence。

bf5cac7a22f4be36b225ecfa7806040d.png

2. 统一损失函数

现在数据变成了统一的image input和sequence output,那么input image可以自然地用一个vision encoder表示(CovNet、Transformer等都可),output sequence可以用一个sequence decoder建模,即给定encoder hidden state和之前生成的sequence,预测下一个token:

402 Payment Required

这里x代表image,y1:j-1是之前生成的sequence,yj是下一个token。但由于output sequence包括两个部分task prompt和task input,其中task prompt是给定的,不需要生成,因此不需要加到generation loss中。

所以,这里引入wj权重,当yj在task prompt中,wj设置成0,不参与loss计算。

3. 多任务联合训练

由于输入输出形式、损失都是统一的,在优化时可以选择两种联合训练的方式:

(1)直接混合所有数据,随机采样,进行优化:

2caeb9387ccf32102466363ee52e872e.png

(2)对各task分别计算loss,然后合并所有task的梯度,优化模型:

7d9cf0570dc0d06eb9713ce770cf7558.png

第一种更为简单,但涉及到image augmentations对不同output sequence可能是不同的。

同时,第二种可以控制每个task的权重,作者通过贪心策略逐个添加task并调整权重,确定最终的各个task权重。

4. 最终输出:反序列化Detokenization

反序列化就是把token再次数字化,例如对与objection detection,将output token sequence变成5个token一组,每组前4个token代表坐标,第5个token代表object class label。

其中,序列的生成和Pix2Seq第一个版本一样,都采用nucleus sampling。

e6696d8b46f4df37aafab65cad43e780.png实验结果c41c9c25f0217695350faa1f13ba74d9.png

实验的架构和Pix2Seq是一样的,采用了Vision Transformer (ViT-B) encoder和Transformer autoregressive decoder,共有132M的参数。

cc44ddd3355705f65a9f1b49d6ab43ad.png

值得注意的是该论文没有使用大规模图片-文本预训练。模型的初始化来自于Pix2Seq,是在Object Detection数据集上预训练得到的(因此image captioning的结果受限,加入图片-文本数据应该会有提升)。

图片的大小有640x640和1024x1024两种大小。同时作者比较了两个变种:single task单独训练各任务,multi-task会同时一起训练所有任务,即多任务联合训练。

主要的结论包括:

  • 该模型在4个任务中都取得了与主流模型相当的效果。

  • 多任务训练的影响:并不统一。

  • 图片大小的影响:图片越大,结果越好。

9a9eb8802f90d72f57f27167a27d06d6.png结论6b2a8181abfe93043b0273872837e405.png

这篇工作的模型架构和第一版的Pix2Seq基本一致,重点在于怎样将这种框架adapt到多种不同输出形式的CV任务上。目前对各个CV任务的序列化非常直观简单,但效果却是不错的。

b9397c36151a8a84bb1b87cded66e290.png最后的话2e185a0ffaf26ccad479a61f1a0d16e0.png

大一统模型近期层出不穷,而这种离散的token序列的表示方式,小编认为是非常有希望的一个方向,因为这种方式同时可以尝试把NLP和CV并入一个框架,同时离散token的方式也天然能够加入speech的处理。

因此,小编也很期待这种统一接口可以加入更多模态(modality),例如video、audio等。

小编在读的时候,主要的concern是这种localization真的可以准确吗?这个quantilize和dequantilize的过程把number变成了token,失去精度不准确怎么办?

作者在实验中针对这个问题,也做了简单的处理,对instance segmentation任务,通过nucleus sampling生成多个结果,并取平均。

但对数值化的token表示应该是需要更多思考的,这种token在未来是否可以具备计算能力,也是很有意思的议题。

cde63f6d4133caa01afb7ca3bc19ae01.jpeg后台回复关键词【入群

加入卖萌屋NLP、CV、搜广推与求职讨论群

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/477143.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美团针对Redis Rehash机制的探索和实践

背景 Squirrel(松鼠)是美团技术团队基于Redis Cluster打造的缓存系统。经过不断的迭代研发,目前已形成一整套自动化运维体系:涵盖一键运维集群、细粒度的监控、支持自动扩缩容以及热点Key监控等完整的解决方案。同时服务端通过Doc…

剑指Offer - 面试题59 - II. 队列的最大值(deque模拟单调栈)

1. 题目 请定义一个队列并实现函数 max_value 得到队列里的最大值,要求函数max_value、push_back 和 pop_front 的时间复杂度都是O(1)。 若队列为空,pop_front 和 max_value 需要返回 -1 示例 1: 输入: ["MaxQueue","push…

行业现状令人失望,工作之后我又回到UC伯克利读博了

文 | SHREYA SHANKAR编 | 小舟、陈萍源 | 机器之心很多同学在面临读博和工作的选择时会犹豫不决,这篇文章也许能给你一点启发。机器学习领域近来受到大模型的冲击,很多小公司表示难以承担大模型的训练费用。但行业中机器学习工程的发展具体是怎样的&…

前端遇上Go: 静态资源增量更新的新实践

为什么要做增量更新 美团金融的业务在过去的一段时间里发展非常快速。在业务增长的同时,我们也注意到,很多用户的支付环境,其实是在弱网环境中的。 大家知道,前端能够服务用户的前提是 JavaScript 和 CSS 等静态资源能够正确加载。…

剑指Offer - 面试题26. 树的子结构(双重递归)

1. 题目 输入两棵二叉树A和B,判断B是不是A的子结构。(约定空树不是任意一个树的子结构) B是A的子结构, 即 A中有出现和B相同的结构和节点值。 例如: 给定的树 A:3/ \4 5/ \1 2 给定的树 B:4 /1 返回 true,因为 B 与 A 的一…

给1万帧视频做目标分割,显存占用还不到1.4GB | ECCV2022

文 | 明敏 发自 凹非寺源 | 量子位 | 公众号 QbitAI咦,怎么好好的藤原千花,突然变成了“高温红色版”?这大紫手,难道是灭霸在世??如果你以为上面的这些效果只是对物体后期上色了,那还真是被AI给…

互联网公司数据安全保护新探索

近年来,数据安全形势越发严峻,各种数据安全事件层出不穷。在当前形势下,互联网公司也基本达成了一个共识:虽然无法完全阻止攻击,但底线是敏感数据不能泄漏。也即是说,服务器可以被挂马,但敏感数…

剑指Offer - 面试题47. 礼物的最大价值(动态规划)

1. 题目 在一个 m*n 的棋盘的每一格都放有一个礼物,每个礼物都有一定的价值(价值大于 0)。你可以从棋盘的左上角开始拿格子里的礼物,并每次向右或者向下移动一格、直到到达棋盘的右下角。给定一个棋盘及其上面的礼物的价值&#…

大佬在线复盘:我在训练 DALL·E 时犯过的错

文 | jxyxiangyu在写了一周的业务代码后,沏一杯绿茶,总算可以有时间看看鸽了一个月的素材了。好的,小伙伴们,废话不多说,今天我们将跟随 Boris Dayma 大佬,看看他在训练 DALLE-Mega 时遇到的一系列问题。据…

Toast与Snackbar的那点事

背景 Toast是Android平台上的常用技术。从用户角度来看,Toast是用户与App交互最基本的提示控件;从开发者角度来看,Toast是开发过程中常用的调试手段之一。此外,Toast语法也非常简单,仅需一行代码。基于简单易用的优点&…

LintCode 1683. 杀怪兽(队列)

1. 题目 有 n 只怪兽和一个奥特曼,奥特曼和怪兽都有5个属性值。 当且仅当奥特曼的5个属性值都不小于怪兽时,奥特曼可以杀死怪兽。 当一个怪兽被杀掉时,这个怪兽的5个属性会增加到奥特曼身上。 请问奥特曼最多可以杀死多少怪兽? 样例 1: 输…

聊聊大火的多模态

多模态机器学习,英文全称 MultiModal Machine Learning (MMML),旨在通过机器学习的方法实现处理和理解多源模态信息的能力。每一种信息的来源或者形式,都可以称为一种模态。例如,人有触觉,听觉,视觉&#x…

2018开春大礼:750页电子书 + 33场技术沙龙资料 + 17场线上课程分享

2017年,美团成长为中国领先的生活服务电子商务平台,在吃喝玩乐住行等200多个品类,2800多个城区县,服务了亿万消费者、数百万商家,日订单数超过2200万,年度交易总额达到了3600亿。2017年10月,美团…

LintCode 1677. 石头(自定义优先队列)

1. 题目 给定数组 p 代表 n 个石头的位置和数组 d 代表这 n 块石头能够扔的距离。 从左(0位置)往右走。当你第 k 次碰到一个石头时, 如果 k 是奇数, 把这个石头往右扔; 如果 k 是偶数,跳过这个石头。 返回不再会碰到石头时&…

手机上也能训练BERT和ResNet了?!

源 | 机器之心研究者表示,他们将边缘训练看作一个优化问题,从而发现了在给定内存预算下实现最小能耗的最优调度。目前,智能手机和嵌入式平台等边缘设备上已经广泛部署深度学习模型来进行推理。其中,训练仍然主要是在具有 GPU 等高…

LintCode 125. 背包问题 II(DP)

1. 题目 有 n 个物品和一个大小为 m 的背包. 给定数组 A 表示每个物品的大小 数组 V 表示每个物品的价值. 问最多能装入背包的总价值是多大? 样例 1: 输入: m 10, A [2, 3, 5, 7], V [1, 5, 2, 4] 输出: 9 解释: 装入 A[1] 和 A[3] 可以得到最大价值, V[1] V[3] 9 样例…

大众点评App的短视频耗电量优化实战

前言 美团测试团队负责App的质量保证工作,日常除了App的功能测试以外,还会重点关注App的性能测试。现在大家对手机越来越依赖,而上面各App的耗电量,直接影响了手机的待机时间,是用户非常关心的一点。本文主要通过一个典…

解决CNN固有缺陷!通用 CNN 架构CCNN来了| ICML2022

文 | David W. Romero等源丨机器之心在 VGG、U-Net、TCN 网络中... CNN 虽然功能强大,但必须针对特定问题、数据类型、长度和分辨率进行定制,才能发挥其作用。我们不禁会问,可以设计出一个在所有这些网络中都运行良好的单一 CNN 吗&#xff1…

境外业务性能优化实践

本文根据第16期美团技术线上沙龙OnLine演讲内容整理而成。 前言 性能问题简介 应用性能是产品用户体验的基石,性能优化的终极目标是优化用户体验。当我们谈及性能,最直观能想到的一个词是“快”,Strangeloop在对众多的网站做性能分析之后得出…

LeetCode 第 21 场双周赛(779/1913,前40.7%)

文章目录1. 比赛结果2. 题目LeetCode 5336. 上升下降字符串 easyLeetCode 5337. 每个元音包含偶数次的最长子字符串 mediumLeetCode 5338. 二叉树中的最长交错路径 mediumLeetCode 5339. 二叉搜索子树的最大键值和 hard1. 比赛结果 只做出来了第1题,第3题有一个例子…