MoCo 算法阅读记录

论文地址:🐰

何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。

一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模的数据集;另一方面 本次只是测试和阅读算法原理的实现,并不完整使用。因此,重写了一个低配版(流程不变,超参数没有严格要求设置,单GPU跑,数据集自己配置,几十张图片, no Shuffling BN)。

queue 即文中所构建的字典,起名为这个就是因为 C++ 中 的queue 容器,因为它是一种先进先出的数据结构。

目录

一、数据预处理

二、前向传播

网络结构

算法流程


一、数据预处理

对同一张图片进行数据增强操作,得到 query 和 key。

增强操作包括

transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),transforms.RandomGrayscale(p=0.2),transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),transforms.RandomHorizontalFlip(),normalize,

所以,dataloader中的每个输入样本是一个样本对儿。

通过下列方法实现

class TwoCropsTransform:"""Take two random crops of one image as the query and key."""def __init__(self, base_transform):self.base_transform = base_transformdef __call__(self, x):q = self.base_transform(x)k = self.base_transform(x)return [q, k]

二、前向传播

网络结构

代码中 encoder q 和 encoder k的网络结构用的都是ReNet 。ResNet最终的输出层包含了

(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=128, bias=True)

所以,输出的特征向量维度为 (N,C)。N为文中的Mini batch大小,代码中的超参数为batch size。C应该没有什么具体的含义,只是经验的设置为这一长度了(没找出来C的大小关乎什么)。

其输出还经过了L2归一化。 

算法流程

1、 q 送入 encoder q 得到输出,并经过L2归一化, (N,C)

2、 momentum 更新 key encoder。

3、 Shuffling BN(当然我重写的代码并没有实现这个,因为它需要多GPU,但这并不妨碍认识它的作用)

文中所述

大致意思由于ResNet使用了BN操作,因此由于Batch 数据之间的交互,使得模型利用它欺骗预设任务从而简单的找到一个低损失的解决方案,然而这个解决方案效果并不好,使得模型学习不到好的表征能力。

其提出的Shuffling BN

首先,把所有进程的Tensor的收集起来(如果分布式训练,一般每个GPU包含一个进程,所以收集的数据总量大小为 num GPUs * batch size),参考这里🤖

x_gather = concat_all_gather(x)

接下来制作打乱的索引,整个过程如下所示

    def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)  # 将所有进程的数据收集起来batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda()  # torch.randperm 将[0,n)数随机排列# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0)  # 将这个信息广播到所有其他进程# index for restoringidx_unshuffle = torch.argsort(idx_shuffle)  # 按照值大小顺序返回下标# shuffled index for this gpugpu_idx = torch.distributed.get_rank()  # 返回当前的进程idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]  # idx_shuffle view 后 (num_gpus, batch size) 但是batch size中的索引是打乱顺序的return x_gather[idx_this], idx_unshuffle

最终返回 随机打乱顺序后挑选的当前进程的 batch size 大小的数据,也就是说进行 BN归一化后的数据已经不在 同一个原来的批 中了。

4、k 送入 encoder k 中,在经过L2 归一化, 和q一样。  (N,C)

5、Shuffling BN 对齐 q 和 k

如下面举例

# idx_shuffle
tensor([10, 16, 13,  2,  4,  0,  6, 21, 22, 31, 29,  3, 19, 17, 14, 30, 28, 12,24, 26,  8, 25, 11, 18,  5,  7, 27,  1, 15, 23, 20,  9])# idx_unshuffle
tensor([ 5, 27,  3, 11,  4, 24,  6, 25, 20, 31,  0, 22, 17,  2, 14, 28,  1, 13,23, 12, 30,  7,  8, 29, 18, 21, 19, 26, 16, 10, 15,  9])# q 的 idx_this
tensor([10, 16, 13,  2,  4,  0,  6, 21])# k 的 idx_this
tensor([ 5, 27,  3, 11,  4, 24,  6, 25])

这里主要关注的点是 这步是为了使 k对齐打乱顺序的q。q之前是打乱了顺序从而改变了每个batch的内容,相当于从所有的batch中随机挑选了 batch size的q,从而保证去除BN的影响。

而 k 不需要 再打乱了, 只需要从原有的batch size 数据分布中挑选出与q对应的数据即可。所以才在 shuffle BN q的过程中记录了indx unshuffle。

这里的对应关系举例,比如 index shuffle 中的 0 现在位于原来没打乱状态的索引 5处, 类似的 1 -->27, 2-->3, 以此类推。

注:不要被上面单进程的(即idx this)不对齐所迷惑,上面的只是分进程处理的,分布式训练最终会把所有进程的数据拼接起来一起处理,所以所有进程的数据对齐就行。

6、计算损失,即文中公式1

其中 用到的计算方法举例如下,分别用爱因斯坦求和公式实现,参考这里🤖

a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 2]])
b = torch.tensor([[2, 2, 2], [2, 2, 2], [1, 1, 1]])
print(a)
print(b)
c = torch.einsum("nc, nc->n", [a, b])  # (3)
d = c.unsqueeze(-1)  # (3,1)
print(c)#=== 输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[2, 2, 2],[2, 2, 2],[1, 1, 1]])
tensor([12,  6,  7])
tensor([[12],[ 6],[ 7]])
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 3]])  # (3,3)
a1 = torch.tensor([[1, 2], [1, 1], [2, 2]])  # (3,2)
c = torch.einsum("nc,ck->nk", [a, a1])
print(a)
print(a1)
print(c)# ===输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[1, 2],[1, 1],[2, 2]])
tensor([[ 9, 10],[ 4,  5],[10, 12]])

这里的self.queue 即文中的字典 queue,初始化为

self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)

K为字典的长度,默认设置65536。这里为什么设置为这个可能是由于ImageNet数据集比较大,所以设置的字典比较长,具体的长度设置好像没有做固定的要求,

来源于github官网。但代码中有要求,K必须是batch size 的倍数,这个为了确保字典的更新,方便执行入栈和弹出操作。这个字典像是C++的 queue容器的FIFO数据结构,即先进先出

self.K % batch_size == 0
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)  #  (8,1)  对应元素相乘并第一维加和# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])  # (8,65536)  矩阵相乘# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)  # (8,65537)# apply temperaturelogits /= self.Tlabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()  # (8,)loss = criterion(output, target)

这里看标签都是0,即第一个也就是0维数据为正样本。因为在拼接cat的时候正样本是在前面的。

7、更新字典

按mini batch 更新。具体地,如果 训练次数*mini batch size 小于字典长度,则字典queue每次都会填充新的key。若训练次数*mini batch size 大于 字典长度,则之前的被替换掉。

ptr = (ptr + batch_size) % self.K  # move pointer  8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/805282.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 20.04.06 PCL C++学习记录(二十一)【切记使用rm * -rf前先确认是否是对应文件夹】

[TOC]PCL中点云分割模块的学习 学习背景 参考书籍:《点云库PCL从入门到精通》以及官方代码PCL官方代码链接,,PCL版本为1.10.0,CMake版本为3.16,测试点云下载地址 学习内容 根据欧几里得距离和需要保持的用户可自定义条件对点进…

在服务器部署MySQL

在服务器opt 新建文件夹 mysql/data,新建文件 mysql/conf.d/my.cnf 其中my.cnf 内容如下 [mysqld] log_timestampsSYSTEM default-time-zone8:00server-id1log-binmysql-binbinlog-do-db mall # 要监听的库binlog_formatROW 配置解读: ① server-id &…

【分析 GClog 的吞吐量和停顿时间、heapdump 内存泄漏分析】

文章目录 🔊博主介绍🥤本文内容GClog分析以优化吞吐量和停顿时间步骤1: 收集GClog步骤2: 分析GClog步骤3: 优化建议步骤4: 实施优化 Heapdump内存泄漏分析步骤1: 获取Heapdump步骤2: 分析Heapdump步骤3: 定位泄漏对象步骤4: 分析泄漏原因步骤5: 修复泄漏…

2024.4.3力扣每日一题——找出克隆二叉树中的相同节点

2024.4.3 题目来源我的题解方法一 深度优先搜索方法二 广度优先遍历 题目来源 力扣每日一题;题序:1379 我的题解 方法一 深度优先搜索 同时对二叉树 original 与 cloned 进行深度优先搜索,如果 original当前搜索的节点的引用等于 target 节…

预训练的启蒙:浅谈BERT、RoBERTa、ALBERT、T5

文章目录 Transformer揭开预训练序幕为什么RNN/LSTM需要从头训练? BERT核心特点预训练任务架构应用和影响 RoBERTa改进点BERT和RoBERTa的MASK策略对比BERT的静态MASK策略RoBERTa的动态MASK策略效果 总结 ALBERT改进点参数共享因式分解嵌入参数和LoRa对比 总结 T5核心…

Electron打包vue+java+nginx 踩坑记录

记录下遇到的问题: ⚠注意:64位系统和32位系统的配置不太一样 1、运行npm run packager失败 原因:在package.json没有对应命令 解决:在package.json 中添加对应命令,其中testApp是你想要的输入的项目名称&#xff0…

编程:不只是工作,是我生活的一部分

开篇 大家好,今天想聊聊我怎么把对编程的爱好变成了自己的饭碗。是的,我现在是个程序员,每天的工作就是和代码打交道。但说实话,这工作对我来说,不只是敲敲键盘那么简单,它是我对生活的一种态度&#xff0…

element用户上传头像组件带大图预览,和删除功能

element 用户上传组件不支持大图预览&#xff0c;拿组件的简单修改一些&#xff0c;发表上来主要是记一下&#xff0c;以后可以用 效果图 <template><div class"flex-img"><div class"el-upload-list el-upload-list--picture-card" v-sh…

word从零基础到高手【办公】

第1课 - word基础操作快速入门第2课 - 让你效率10倍提升的快捷操作第3课 - word排版快速入门第4课 - 排版实战案例讲解第5课 - 搞定论文排版全过程第6课 - 让你的word更强大的神技第7课 - 提高工作效率必备的批量操作 资料截图如下: 发送: "word办公" 获取提取码

动态规划-入门理解

一、什么情况可以使用动态规划 动态规划 最优子结构 重叠子问题 转移方程 最优子结构&#xff1a;保证能从局部解推出全局解&#xff0c;也就是保证能够写出转移方程 重叠子问题&#xff1a;说明暴力解法太耗时&#xff0c;我们可以使用动态规划进行优化 转移方程&#xff…

基于GAN的图像补全实战

数据与代码地址见文末 论文地址:http://iizuka.cs.tsukuba.ac.jp/projects/completion/data/completion_sig2017.pdf 1.概述 图像补全,即补全图像中的覆盖和缺失部分, 网络整体结构如下图所示,整体网络结构还是采取GAN,对于生成器,网络结构采取Unet的形式,首先使用卷积…

深入浅出 -- 系统架构之负载均衡Nginx跨域配置

一、Nginx跨域配置 跨域问题在之前的单体架构开发中&#xff0c;其实是比较少见的问题&#xff0c;除非是需要接入第三方SDK时&#xff0c;才需要处理此问题。但随着现在前后端分离、分布式架构的流行&#xff0c;跨域问题也成为了每个Java开发必须要懂得解决的一个问题。 跨域…

rac数据库默认网关不通导致集群异常

集群CSSD进程reconfiguration完成&#xff0c;显示2个节点都在线。但ora.net1.network服务启动失败&#xff0c;且有依赖关系的资源随后启动失败并且已经达到上限。 查看两个节点的网络信息&#xff0c;发现两个节点的默认网关是不一致的。 修改故障节点网关 在RAC中&#xff0…

基于springboot+vue+Mysql的职称评审管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

在线人数统计功能怎么实现?

一、前言 大家好&#xff01;我是sum墨&#xff0c;一个一线的底层码农&#xff0c;平时喜欢研究和思考一些技术相关的问题并整理成文&#xff0c;限于本人水平&#xff0c;如果文章和代码有表述不当之处&#xff0c;还请不吝赐教。 在线人数统计这个功能相信大家一眼就明白是…

基于奇异值分解(Singular Value Decomposition,SVD)的信号去噪算法

01.基于奇异值分解(SVD)去噪原理 奇异值分解&#xff08;Singular Value Decomposition, SVD&#xff09;是线性代数中一种重要的矩阵分解方法&#xff0c;它可以用于信号处理、图像去噪、数据压缩等多种应用。在图像去噪的过程中&#xff0c;SVD可以用来分离图像中的信号和噪…

Transformer详解和知识点总结

目录 1. 注意力机制1.1 注意力评分函数1.2 多头注意力&#xff08;Multi-head self-attention&#xff09; 2. Layer norm3. 模型结构4. Attention在Transformer中三种形式的应用 论文&#xff1a;https://arxiv.org/abs/1706.03762 李沐B站视频&#xff1a;https://www.bilibi…

SpringBoot Starter子模块下无法生成spring-configuration-metadata.json文件

一.SpringBoot Starter的作用 Starter的机制极大的方便了业务系统接入相关能力&#xff0c;它有一个非常友好的能力就是引入starter后&#xff0c;在配置相关的配置项时&#xff0c;能自动提示&#xff0c;极大的提升了使用的友好度。 二.遇到的问题 我在为Juggle开发系统star…

【图论】链式前向星实现图的BFS搜索

&#x1f4ab;【图论】链式前向星–BFS宽搜遍历 &#x1f44f;宽搜背景和实现的功能 输入: n m n:结点数量m:边的数量 输出:到达结点编号为n的最短路径, 每条路长度为1(宽度搜索的前前提条件) &#x1f914;思路&#xff1a; 采用链式前向星存图数组模拟队列的方法只要队列不…

[C++/Linux] Linux线程详解

目录 一.什么是线程&#xff1f; 并发&#xff08;Concurrency&#xff09; 并行&#xff08;Parallelism&#xff09; 1.1 线程的概念 1.2 线程的基本函数 1.3 线程的基本使用例子&#xff1a; 二.线程的属性 2.1线程属性使用例子 三.线程互斥 3.1互斥锁 3.2互斥锁常用函…