multi task训练torch_Multi-task Learning的三个小知识

a3168cef16209be56750fcf3fec14ce2.png

本文译自Deep Multi-Task Learning – 3 Lessons Learned by Zohar Komarovsky

在过去几年里,Multi-Task Learning (MTL)广泛用于解决多个Taboola(公司名)的业务问题。在这些业务问题中, 人们使用一组相同的特征以及深度学习模型来解决MTL相关问题。在这里简单分享一下我们做MTL时学习到的一些小知识。

小知识第一条: 整合损失函数

MTL模型中的第一个挑战: 如何为multiple tasks定义一个统一的损失函数?
最简单的办法,我们可以整合不同tasks的loss function,然后简单求和。这种方法存在一些不足,比如当模型收敛时,有一些task的表现比较好,而另外一些task的表现却惨不忍睹。其背后的原因是不同的损失函数具有不同的尺度,某些损失函数的尺度较大,从而影响了尺度较小的损失函数发挥作用。这个问题的解决方案是把多任务损失函数“简单求和”替换为“加权求和”。加权可以使得每个损失函数的尺度一致,但也带来了新的问题:加权的超参难以确定。

幸运的是,有一篇论文《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》通过“不确定性(uncertainty)”来调整损失函数中的加权超参,使得每个任务中的损失函数具有相似的尺度。该算法的keras版本实现,详见github。

小知识第二条:调整学习率 learning rate

在神经网络的参数中,learning rate是一个非常重要的参数。在实践过程中,我们发现某一个learnig rate=0.001能够把任务A学习好,而另外一个learning rate=0.1能够把任务B学好。选择较大的learning rate会导致某个任务上出现dying relu;而较小的learning rate会使得某些任务上模型收敛速度过慢。怎么解决这个问题呢?对于不同的task,我们可以采用不同的learning rate。这听上去很复杂,其实非常简单。通常来说,训练一个神经网络的tensorflow代码如下:

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

其中AdamOptimizer定义了梯度下降的方式,minimize则计算梯度并最小化损失函数。我们可以通过自定义一个minimize函数来对某个任务的变量设置合适的learning rate。

all_variables = shared_vars + a_vars + b_vars
all_gradients = tf.gradients(loss, all_variables)shared_subnet_gradients = all_gradients[:len(shared_vars)]
a_gradients = all_gradients[len(shared_vars):len(shared_vars + a_vars)]
b_gradients = all_gradients[len(shared_vars + a_vars):]shared_subnet_optimizer = tf.train.AdamOptimizer(shared_learning_rate)
a_optimizer = tf.train.AdamOptimizer(a_learning_rate)
b_optimizer = tf.train.AdamOptimizer(b_learning_rate)train_shared_op = shared_subnet_optimizer.apply_gradients(zip(shared_subnet_gradients, shared_vars))
train_a_op = a_optimizer.apply_gradients(zip(a_gradients, a_vars))
train_b_op = b_optimizer.apply_gradients(zip(b_gradients, b_vars))train_op = tf.group(train_shared_op, train_a_op, train_b_op)

值得一提的是,这样的trick在单任务的神经网络上效果也是很好的。

小知识第三条:任务A的评估作为其他任务的特征

当我们构建了一个MTL的神经网络时,该模型对于任务A的估计可以作为任务B的一个特征。在前向传播时,这个过程非常简单,因为模型对于A的估计就是一个tensor,可以简单的将这个tensor作为另一个任务的输入。但是后向传播时,存在着一些不同。因为我们不希望任务B的梯度传给任务A。幸运的是,Tensorflow提供了一个API tf.stop_gradient。当计算梯度时,可以将某些tensor看成是constant常数,而非变量,从而使得其值不受梯度影响。代码如下:

all_gradients = tf.gradients(loss, all_variables, stop_gradients=stop_tensors)

再次值得一提的是,这个trick不仅仅可以在MTL的任务中使用,在很多其他任务中也都发挥着作用。比如,当训练一个GAN模型时,我们不需要将梯度后向传播到对抗样本的生成过程中。

感谢阅读,希望本文对您有所帮助! 谢谢!

如果觉得文章对您有帮助,可以关注本人的微信公众号:机器学习小知识

9fd9a2d41c786a3a1c003eb9966175b8.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/332899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java8多线程运行程序_线程,代码和数据–多线程Java程序实际运行的方式

java8多线程运行程序有些事情是您在学术或培训班上没有学到的,经过几年的工作经验后才逐渐了解,然后才意识到,这是非常基本的事情,我为什么错过了这么多年。 了解多线程Java程序的执行方式就是其中之一。 您肯定已经听说过线程&am…

zsh命令行界面/zsh终端界面粘贴卡顿的问题

因为安装了某些zsh插件导致,在zsh命令行中粘贴文本非常卡顿,解决方案就是把下面的代码复制到 ~/.zshrc 文件中: pasteinit() {OLD_SELF_INSERT${${(s.:.)widgets[self-insert]}[2,3]}zle -N self-insert url-quote-magic # I wonder if youd…

java连接mongodb的jar包_Java实战之管家婆记账系统(1)——项目简述

项目简述:该项目是一个通过JavaFX实现的管家婆记账系统,具有记账的功能。使用软件:IntelliJ IDEA 2018.3.5(Ultim ate Edition):编写Java项目代码。JavaFX Scene Builder 2.0:生成fxml界面文件。Navicat for MySQL&…

oauth2.0协议流程_正确的工作流程:我应该使用哪个OAuth 2.0流程?

oauth2.0协议流程什么是OAuth 2.0 OAuth 2.0是一个已被广泛采用的委托授权框架,已经存在了很多年,并且似乎已经存在。 如果您不熟悉OAuth 2.0的基本概念,可以使用 川崎孝彦写的优秀文章 。 这只是OAuth 2.0各方的简要提醒: 资源…

远程Linux主机安装zsh插件zsh-syntax-highlighting

安装说明: https://github.com/zsh-users/zsh-syntax-highlighting/blob/master/INSTALL.md 根据安装说明: 1.Clone this repository in oh-my-zsh’s plugins directory: git clone https://github.com/zsh-users/zsh-syntax-highlightin…

scare机器人如何手眼标定_基于视觉伺服的工业机器人系统研究(摄像机标定、手眼标定、目标单目定位)...

击上方“新机器视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达标定技术常见的机器人视觉伺服中要实现像素坐标与实际坐标的转换,首先就要进行标定,对于实现视觉伺服控制,这里的标定不仅包括摄像机标定…

单元测试junit参数_使用Junit参数在更少的时间内编写更好的单元测试

单元测试junit参数大多数人都知道单元测试的重要性和好处,以及为什么要在进行的项目中使用它们。 而且,大多数人不喜欢在他们从事的项目中编写单元测试。 TDD的人当然处于另一面,但根据我的经验,他们在IT行业中是少数派。 说到我…

Linux CentOS安装zsh插件提示/usr/bin/env: python: No such file or directory。

执行 ./install.py 文件时,提示: /usr/bin/env: python: No such file or directory查看系统已安装的 python 版本: ➜ ~ ls -l /usr/bin | grep python lrwxrwxrwx 1 root root 36 11月 2 18:11 python -> /etc/alternativ…

3l如何使用_慢阻肺患者如何选购呼吸机和制氧机,需要注意哪些?

慢阻肺(COPD)是慢性阻塞性肺疾病的简称,进一步发展为肺心病和呼吸衰竭的常见慢性疾病。与有害气体及有害颗粒的异常炎症反应有关,致残率和病死率很高,全球大约有2.1亿人患有慢阻肺,中国大概约有4000-8000万人。慢阻肺已成为全球范…

apache ignite_通过示例获取Apache Ignite Baseline拓扑

apache ignite点燃基准拓扑或BLT代表群集中的一组服务器节点,这些服务器节点将数据持久存储在磁盘上。 其中,N1-2和N5服务器节点是具有本机持久性的Ignite集群的成员,该集群使数据能够持久存储在磁盘上。 N3-4和N6服务器节点是Ignite群集的…

自定义报错返回_MybatisPlus基础篇学习笔记(五)------自定义sql及分页查询

本章目录自定义sql分页查询1. 自定义sql在dao文件中编写自定义接口,并在方法上使用注解形式注入SQL,如图所示:第一种:第二种① application.yml加入下面配置mybatis-plus:mapper-locations: com/ethan/mapper/*② MemberMapper.ja…

精简jdk包_具有JDK 12精简数字格式的自定义精简数字模式

精简jdk包帖子“ 紧凑数字格式出现在JDK 12中 ”已经成为有关Java subreddit线程的讨论主题 。 在那个线程中表达的与紧凑数字格式表示有关的问题涉及显示的精度数字和显示的紧凑数字模式。 可以通过使用CompactNumberFormat.setMinimumFractionDigits(int)来解决精度数字问题&…

两个数相乘积一定比每个因数都大_人教版五年级数学:因数、倍数与分数的整理与复习...

写在前面的话:因数与倍数和分数基本性质之间存在紧密的联系,可以将之放在一起学习,对分数基本性质的学习有促进作用,分数的基本性质对分数的加法和减法也非常重要,因此可以放在一起学习、复习。【整理与复习】因数与倍…

Linux中在zsh下如何安装autojump

文章目录介绍安装介绍 autojump is a faster way to navigate your filesystem. It works by maintaining a database of the directories you use the most from the command line. Directories must be visited first before they can be jumped to. 关于 autojump 有以下几个…

fork join框架_Java中的Fork / Join框架的简要概述

fork join框架Fork / Join框架是使用并发分治法解决问题的框架。 引入它们是为了补充现有的并发API。 在介绍它们之前,现有的ExecutorService实现是运行异步任务的流行选择,但是当任务同质且独立时,它们会发挥最佳作用。 运行依赖的任务并使用…

3模型大小_Github推荐一个国内牛人开发的超轻量级通用人脸检测模型

Ultra-Light-Fast-Generic-Face-Detector-1MB1MB轻量级通用人脸检测模型作者表示该模型设计是为了边缘计算设备以及低功耗设备(如arm)设计的实时超轻量级通用人脸检测模型。它可以用于arm等低功耗计算设备,实现实时的通用场景人脸。检测推理同…

macOS如何使用命令启动服务/停止服务/查看服务

文章目录开启服务停止服务查看服务是否启动开启服务 使用命令开启 sshd 服务: $ sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist注:成功启动不会有任何输出 停止服务 $ sudo launchctl unload -w /System/Library/LaunchDaemons/s…

如何在用例之间传递值_如何从0搭建自己的自动化测试体系

大家好,我是爱吃面条,今天给大家讲讲如何从0搭建自己的自动化测试体系1. 需求和目标在我开展自动化测试之前,其实该项目以前的测试人员也已经写了很多的接口测试用例,但是大多数用例处于“半瘫痪”状态,在CI上无人维护…

envoy api 网关_为Envoy构建控制平面的指南-特定于域的配置API

envoy api 网关建立您的控制平面交互点和API表面 一旦考虑了哪些组件可能构成控制平面体系结构(请参见上一章),您将要确切考虑用户将如何与控制平面进行交互,甚至更重要的是, 用户将是谁? 要回答这个问题&a…

异步非阻塞_细说同步异步、阻塞非阻塞

同步、异步同步、异步分别指的是一种通讯方式,当 cpu 不需要执行线程上下文切换就能完成任务,此时便认为这种通讯方式是同步的,相对的如果存在cpu 上下文切换,这种方式便是异步。这里通过一个去食堂打饭的示例来理解什么是同步、异…