《Learning to Reweight Examples for Robust Deep Learning》笔记

[1] 用 meta-learning 学样本权重,可用于 class imbalance、noisy label 场景。之前对其 (7) 式中 ϵ i , t = 0 \epsilon_{i,t}=0 ϵi,t=0对应 Algorithm 1 第 5 句、代码 ex_wts_a = tf.zeros([bsize_a], dtype=tf.float32))不理解:如果 ϵ \epsilon ϵ 已知是 0,那 (4) 式的加权 loss 不是恒为零吗?(5) 式不是优化了个吉而 θ ^ t + 1 ( ϵ ) ≡ θ t \hat\theta_{t+1}(\epsilon) \equiv \theta_t θ^t+1(ϵ)θt ?有人在 issue 提了这个问题[2],但其人想通了没解释就关了 issue。

看到 [3] 代码中对 ϵ \epsilon ϵ 设了 requires_grad=True 才反应过来:用编程的话说, ϵ \epsilon ϵ 不应理解成常数,而是变量; 用数学的话说,(5) 的求梯度( ∇ \nabla )是算子,而不是函数,即 (5) 只是在借梯度下降建立 θ ^ t + 1 \hat\theta_{t+1} θ^t+1 ϵ \epsilon ϵ 之间的函数(或用 TensorFlow 的话说,只是在建图),即 θ ^ t + 1 ( ϵ ) \hat\theta_{t+1}(\epsilon) θ^t+1(ϵ),而不是基于常数 θ t \theta_t θt ϵ = 0 \epsilon=0 ϵ=0 算了一步 SGD 得到一个常数 θ ^ t + 1 \hat\theta_{t+1} θ^t+1

一个符号细节:无 hat 的 θ t + 1 \theta_{t+1} θt+1 指由 (3) 用无 perturbation 的 loss 经 SGD 从 θ t \theta_t θt 优化一步所得; θ ^ t + 1 \hat\theta_{t+1} θ^t+1 则是用 (4) perturbed loss。文中 (6)、(7) 有错用作 θ t + 1 \theta_{t+1} θt+1 的嫌疑。

所以大思路是用 clean validation set 构造一条关于 ϵ \epsilon ϵ 的 loss J ( ϵ ) J(\epsilon) J(ϵ),然后用优化器求它,即 ϵ t ∗ = arg ⁡ min ⁡ ϵ J ( ϵ ) \epsilon_t^*=\arg\min_\epsilon J(\epsilon) ϵt=argminϵJ(ϵ)。由 (4) - (6) 有: J ( ϵ ) = 1 M ∑ j = 1 M f j v ( θ ^ t + 1 ( ϵ ) ) ( 6 ) = 1 M ∑ j = 1 M f j v ( θ t − α [ ∇ θ ∑ i = 0 n f i , ϵ ( θ ) ] ∣ θ = θ t ⏟ g 1 ( ϵ ; θ t ) ) ( 5 ) = 1 M ∑ j = 1 M f j v ( θ t − α [ ∇ θ ∑ i = 0 n ϵ i f i ( θ ) ] ∣ θ = θ t ) ( 4 ) = g 2 ( ϵ ; θ t ) \begin{aligned} J(\epsilon) &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\hat\theta_{t+1}(\epsilon) \right) & (6) \\ &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\theta_t - \alpha \underbrace{\left[ \nabla_{\theta} \sum_{i=0}^n f_{i,\epsilon}(\theta) \right] \bigg|_{\theta=\theta_t}}_{g_1(\epsilon; \theta_t)} \right) & (5) \\ &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\theta_t - \alpha \left[ \nabla_{\theta} \sum_{i=0}^n \epsilon_i f_i(\theta) \right] \bigg|_{\theta=\theta_t} \right) & (4) \\ &= g_2(\epsilon; \theta_t) \end{aligned} J(ϵ)=M1j=1Mfjv(θ^t+1(ϵ))=M1j=1Mfjv θtαg1(ϵ;θt) [θi=0nfi,ϵ(θ)] θ=θt =M1j=1Mfjv(θtα[θi=0nϵifi(θ)] θ=θt)=g2(ϵ;θt)(6)(5)(4) 要注意的就是 (5) 那求导式,本质是个函数,而不是常数,其中 ϵ \epsilon ϵ 是自由的, θ \theta θ 由于被 ∣ θ = θ t |_{\theta=\theta_t} θ=θt 指定了,所以看成常数,所以记为 g 1 ( ϵ ; θ t ) g_1(\epsilon;\theta_t) g1(ϵ;θt),于是整个 J ( ϵ ) J(\epsilon) J(ϵ) 也可以看成一个 g 2 ( ϵ ; θ t ) g_2(\epsilon; \theta_t) g2(ϵ;θt)

按 (6) 求 ϵ t ∗ \epsilon_t^* ϵt 的思路就是:

  1. 随机初始化 ϵ t ( 0 ) \epsilon_t^{(0)} ϵt(0)
  2. ϵ t s + 1 ← ϵ t s − η ∇ ϵ J ( ϵ ) ∣ ϵ = ϵ t s \epsilon^{s+1}_t \leftarrow \epsilon^s_t - \eta \nabla_{\epsilon} J(\epsilon) \big|_{\epsilon=\epsilon^s_t} ϵts+1ϵtsηϵJ(ϵ) ϵ=ϵts,即 (7) 右边。可能由于 J ( ϵ ) J(\epsilon) J(ϵ) 形式上是带梯度的表达式, § \S § 3.3 就称此为「unroll the gradient graph」,而求 ϵ t ( s + 1 ) \epsilon^{(s+1)}_t ϵt(s+1) 的这一步就称为「backward-on-backward」吧。

而文章的 online approximation 就是:

  • ϵ t ( 0 ) = 0 \epsilon^{(0)}_t=0 ϵt(0)=0
  • ϵ t ∗ ≈ ϵ t ( 1 ) \epsilon^*_t \approx \epsilon^{(1)}_t ϵtϵt(1)

初始化为 0 可能不是最好的初始化方法,但不影响后续迭代优化,可参考 LoRA[7],它也用到全零初始化。

References

  1. (ICML’18) Learning to Reweight Examples for Robust Deep Learning - paper, code
  2. gradients of noisy loss w.r.t parameter \theta #2
  3. (PyTorch 复现 1)TinfoilHat0/Learning-to-Reweight-Examples-for-Robust-Deep-Learning-with-PyTorch-Higher
  4. (PyTorch 复现 2)danieltan07/learning-to-reweight-examples
  5. facebookresearch/higher
  6. Stateful vs stateless
  7. (ICLR’22) LoRA: Low-Rank Adaptation of Large Language Models - paper, code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/640665.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 每日一题 Day 47 - 50

2171. 拿出最少数目的魔法豆 给定一个 正整数 数组 beans ,其中每个整数表示一个袋子里装的魔法豆的数目。 请你从每个袋子中 拿出 一些豆子(也可以 不拿出),使得剩下的 非空 袋子中(即 至少还有一颗 魔法豆的袋子&a…

数据结构课程设计 仓储管理系统

仓储管理系统 【基本功能】 把货品信息表抽象成一个线性表,货品信息(包括ID、货品名、定价、数量等)作为线性表的一个元素,实现:按ID、货品名分别查找某货品信息(包括ID、货品名、定价、数量等&#xff0…

C++版QT:电子时钟

digiclock.h #ifndef DIGICLOCK_H #define DIGICLOCK_H ​ #include <QLCDNumber> ​ class DigiClock : public QLCDNumber {Q_OBJECT public:DigiClock(QWidget* parent 0);void mousePressEvent(QMouseEvent*);void mouseMoveEvent(QMouseEvent*); public slots:voi…

JVM常量池详解

欢迎大家关注我的微信公众号&#xff1a; 目录 Class常量池与运行时常量池 字符串常量池 字符串常量池的设计思想 三种字符串操作(Jdk1.7 及以上版本) 字符串常量池位置 字符串常量池设计原理 String常量池问题的几个例子 八种基本类型的包装类和对象池 Class常量…

防范水坑攻击:了解原理、类型与措施

水坑攻击是一种常见的网络攻击方式&#xff0c;它利用了人类在互联网上的行为习惯&#xff0c;诱导用户访问恶意网站或下载恶意软件&#xff0c;从而获取用户的个人信息或控制用户的计算机系统。本文将介绍水坑攻击的原理、类型和防范措施。 一、水坑攻击的原理 水坑攻击&…

Cyber RT 服务通信

场景&#xff1a; 用户乘坐无人出租车过程中&#xff0c;可能临时需要切换目的地&#xff0c;用户可以通过车机系统完成修改&#xff0c;路径规划模块需要根据新的目的地信息重新规划路径&#xff0c;并反馈修正后的结果给用户&#xff0c;那么用户的修正请求数据与修正结果是如…

使用STM32的SPI接口实现与外部传感器的数据交互

一、引言 外部传感器是嵌入式系统中常用的外设&#xff0c;用于检测环境参数、采集数据等。通过STM32微控制器的SPI接口&#xff0c;可以与外部传感器进行数据交互&#xff0c;从而实现数据的采集和控制。本文将介绍如何使用STM32的SPI接口实现与外部传感器的数据交互&#xff…

Web 安全之水坑攻击(Watering Hole Attack)详解

目录 什么是水坑攻击&#xff08;Watering Hole Attack&#xff09; 水坑攻击的原理 水坑攻击的实施案例 水坑攻击的防范方法 小结 什么是水坑攻击&#xff08;Watering Hole Attack&#xff09; 水坑攻击&#xff08;Watering Hole Attack&#xff09;是一种精心策划的网…

常用芯片学习——HC245芯片

HC245三态输出八路总线收发器 使用说明 这些八路总线收发器专为数据总线之间的异步双向通信而设计。控制功能实现可更大限度地减少外部时序要求。根据方向控制 (DIR) 输入上的逻辑电平&#xff0c;此类器件将数据从 A 总线发送至 B 总线&#xff0c;或者将数据从 B 总线发送至…

Windows安装Anaconda教程

windows环境搭建专栏&#x1f517;点击跳转 win系统环境搭建&#xff08;十八&#xff09;——Windows安装Anaconda教程 本文是我实践后写的&#xff0c;无脑跟随安装即可 在我看来&#xff0c;Anaconda的图标如同一只灵蛇咬住了自己的尾巴&#xff0c;优美而神秘。 全称&…

网络安全全栈培训笔记(55-服务攻防-数据库安全RedisHadoopMysqla未授权访问RCE)

第54天 服务攻防-数据库安全&Redis&Hadoop&Mysqla&未授权访问&RCE 知识点&#xff1a; 1、服务攻防数据库类型安全 2、Redis&Hadoop&Mysql安全 3、Mysql-CVE-2012-2122漏洞 4、Hadoop-配置不当未授权三重奏&RCE漏洞 3、Redis-配置不当未授权…

JVM的组成部分(类加载器、运行时数据区、执行引擎、本地库接口)

目录 JVM作用 JVM构成 1.类加载器 类加载子系统&#xff1a; 类加载器的分类&#xff1a; 双亲委派机制&#xff1a; 2.运行时数据区 程序计数器 虚拟机栈 本地方法栈 堆 方法区 3.执行引擎 4.本地库接口 JVM作用 jvm是将字节码文件加载到虚拟机中&#xff0c;…

lc11 盛最多水的容器

问题&#xff1a;给一个整数数组&#xff0c;数组中的元素值为高&#xff0c;数组元素之间的距离为边&#xff0c;计算任意两个元素之间的面积&#xff08;以元素值低的为高&#xff09;&#xff0c;求最大面积 题解&#xff1a;双指针题解 //设计算法&#xff1a;先计算索引…

2024年【广东省安全员B证第四批(项目负责人)】新版试题及广东省安全员B证第四批(项目负责人)作业模拟考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 广东省安全员B证第四批&#xff08;项目负责人&#xff09;新版试题参考答案及广东省安全员B证第四批&#xff08;项目负责人&#xff09;考试试题解析是安全生产模拟考试一点通题库老师及广东省安全员B证第四批&…

一.Winform使用Webview2(Edge浏览器核心) 创建demo(Demo1)实现回车导航到指定地址

Winform使用Webview2创建demo1实现回车导航到指定地址 往期目录参考文档实现1.安装visual studio2.创建单窗口应用3.修改项目中的窗体名称MainForm4.添加按钮5.添加窗口Demo16.在Demo1中添加WebView2 SDK7.在Demo1窗体中选择添加textbox和webview28.在MainForm.cs窗体中添加but…

[ComfyUI进阶教程] lcm+Lora+Controlnet工作流工作流

这是一个使用了LCMlora加载器CN&#xff08;depthtile&#xff09;的工作流。 工作流特性&#xff1a; LCM lora加载器&#xff0c;加快生成图片的时间。 配置了3个lora加载器&#xff0c;用来进行人物和风格设定。 提示词编辑器&#xff0c;预制了默认的动态提示词。 使用了…

【RabbitMQ】交换机详解看这一篇就够了

&#x1f389;&#x1f389;欢迎来到我的CSDN主页&#xff01;&#x1f389;&#x1f389; &#x1f3c5;我是Java方文山&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f31f;推荐给大家我的专栏《RabbitMQ实战》。&#x1f3af;&#x1f3af; &am…

PointNet系列【语义分割】自定义数据的模型训练

目录 一、平台 二、数据 三、代码 3.1 文件组织结构 3.2 lasDataLoader.py 读取数据 3.3 修改原始模型的通道数量 3.4 lasTrainSS.py【训练】 3.5 lasTestSS.py【预测】 一、平台 Windows 10 GPU RTX 3090 CUDA 11.1 cudnn 8.9.6 Python 3.9 Torch 1.9.1 cu111…

每个人都可以是架构师,每个人都需要培养架构思维

您好&#xff0c; 如果喜欢我的文章或者想上岸大厂&#xff0c;可以关注公众号「量子前端」&#xff0c;将不定期关注推送前端好文、分享就业资料秘籍&#xff0c;也希望有机会一对一帮助你实现梦想 什么是架构 “架构”&#xff0c;即架设、构建。完成对于平台的合理架设&am…

VMware安装Linux-Redhat7.9 详细步骤

目录 一、安装准备二、安装步骤 一、安装准备 Redhat 7.9 镜像下载 VMware安装步骤可查看文章&#xff1a;https://blog.csdn.net/a2279338659/article/details/126346345 可去官网下载&#xff0c;或者加群下载镜像资源。 二、安装步骤 创建新的虚拟机&#xff1a; 我这边…