Pytorch实现扩散模型【DDPM代码解读篇1】

本篇内容属于对DDPM 原理-代码 项目的解读。

具体内容参考一篇推文,里面对DDPM讲解相对细致:

扩散模型的原理及实现(Pytorch)

下面主要是对其中源码的细致注解,帮助有需要的朋友更好理解代码。

目录

ConvNext块

 正弦时间戳嵌入

时间多层感知器

注意力

整合


ConvNext块

class ConvNextBlock(nn.Module):def __init__(self,in_channels,out_channels,mult=2,  # 输出通道数相对于输入通道数的倍数time_embedding_dim=None,  # 表示时间嵌入的维度norm=True,  # 是否使用归一化层group=8,  # 卷积操作的分组数量,默认为8):super().__init__()# 多层感知机(MLP),用于处理时间嵌入。# 如果time_embedding_dim不为None,则创建一个包含GELU激活函数和线性层的序列;否则为None。self.mlp = (nn.Sequential(nn.GELU(), nn.Linear(time_embedding_dim, in_channels))if time_embedding_dimelse None)# in_conv是一个输入卷积层,对输入进行卷积操作。self.in_conv = nn.Conv2d(in_channels, in_channels, 7, padding=3, groups=in_channels)# block是一个序列模块,包含一系列卷积操作和归一化层。这里使用了GELU作为激活函数。self.block = nn.Sequential(nn.GroupNorm(1, in_channels) if norm else nn.Identity(),nn.Conv2d(in_channels, out_channels * mult, 3, padding=1),nn.GELU(),nn.GroupNorm(1, out_channels * mult),nn.Conv2d(out_channels * mult, out_channels, 3, padding=1),)# residual_conv是一个残差连接的卷积层,用于调整输入和输出的通道数,# 如果输入通道数和输出通道数不同,则使用1x1卷积进行调整;否则为恒等映射。self.residual_conv = (nn.Conv2d(in_channels, out_channels, 1)if in_channels != out_channelselse nn.Identity())def forward(self, x, time_embedding=None):h = self.in_conv(x)  # 首先对输入x进行输入卷积操作。# 如果mlp不为None且time_embedding不为None,则对时间嵌入进行处理并与输入相加。if self.mlp is not None and time_embedding is not None:assert self.mlp is not None, "MLP is None"h = h + rearrange(self.mlp(time_embedding), "b c -> b c 1 1")  # 然后将处理后的特征输入到块中进行卷积操作。h = self.block(h)return h + self.residual_conv(x)  # 最后将卷积结果与输入进行残差连接,并返回。

 正弦时间戳嵌入

# 通常用于为序列数据添加位置信息。
class SinusoidalPosEmb(nn.Module):def __init__(self, dim, theta=10000):super().__init__()self.dim = dim  # 位置编码的维度。self.theta = theta  # theta是用于计算位置编码的参数,默认值为10000。def forward(self, x):device = x.device  # 首先获取输入x的设备信息。half_dim = self.dim // 2  # 然后计算位置编码的维度一半的值half_dimemb = math.log(self.theta) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, device=device) * -emb)  # 生成位置编码矩阵emb,其中每一行对应一个位置的编码,使用正弦和余弦函数计算。emb = x[:, None] * emb[None, :]emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb# DownSample & UpSample 上下采样
class DownSample(nn.Module):def __init__(self, dim, dim_out=None):super().__init__()self.net = nn.Sequential(# 用于对输入进行重新排列,将2x2的空间块转换为通道数的维度,从而将空间维度减小为原来的四分之一。'''Rearrange层用于对输入的张量进行重新排列,将其从四维张量(batch size、通道数、高度、宽度)转换为新的形状b:表示batch size,保持不变。c:表示通道数,保持不变。(h p1)和(w p2):表示对高度和宽度进行的操作。p1和p2是两个额外的参数,用于指定在高度和宽度上的扩展倍数。这意味着将输入的高度和宽度分别扩展为原来的p1倍和p2倍。b (c p1 p2) h w:表示输出张量的形状,其中通道数乘以高度和宽度。这样做的效果是将原始的空间维度拼接到通道维度后面,使得输出的张量变为三维(batch size、新的通道数、高度、宽度)。'''Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),# 接着是一个1x1的卷积层,用于将输入通道数变换为dim_out或者保持不变nn.Conv2d(dim * 4, default(dim_out, dim), 1),)def forward(self, x):return self.net(x)  # 将输入通过Sequential网络进行前向传播,返回处理后的结果。class Upsample(nn.Module):def __init__(self, dim, dim_out=None):super().__init__()self.net = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),  # 用于对输入进行上采样,采用最近邻插值的方式,并将图像沿着两个维度放大两倍。nn.Conv2d(dim, dim_out or dim, kernel_size=3, padding=1), # 接着是一个3x3的卷积层,用于将输入通道数变换为dim_out或者保持不变)def forward(self, x):return self.net(x)

时间多层感知器

sinu_pos_emb = SinusoidalPosEmb(dim, theta=10000)  # 用于生成正弦和余弦位置编码time_dim = dim * 4  # 四倍维度,通常用于增加时间信息的表示能力time_mlp = nn.Sequential(sinu_pos_emb,  # 将输入的时间信息进行正弦和余弦位置编码nn.Linear(dim, time_dim),  # 一个线性层,将输入维度dim映射为time_dim,以增加时间信息的表示能力。nn.GELU(),  # 激活函数,用于引入非线性。nn.Linear(time_dim, time_dim),  # 另一个线性层,将time_dim映射回time_dim,以保持输出维度不变。)

注意力

class BlockAttention(nn.Module):# gate_in_channel:门输入的通道数, residual_in_channel:残差输入的通道数, scale_factor:尺度因子,用于初始化门和残差卷积层的权重。def __init__(self, gate_in_channel, residual_in_channel, scale_factor):super().__init__()self.gate_conv = nn.Conv2d(gate_in_channel, gate_in_channel, kernel_size=1, stride=1)self.residual_conv = nn.Conv2d(residual_in_channel, gate_in_channel, kernel_size=1, stride=1)self.in_conv = nn.Conv2d(gate_in_channel, 1, kernel_size=1, stride=1)  # 输入卷积层,将门和残差的输出进行卷积处理,将结果映射为范围在0到1之间的值self.relu = nn.ReLU()self.sigmoid = nn.Sigmoid()# 前向传播方法接受两个张量作为输入:x表示残差输入,g表示门输入。def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:in_attention = self.relu(self.gate_conv(g) + self.residual_conv(x))in_attention = self.in_conv(in_attention)in_attention = self.sigmoid(in_attention)return in_attention * x

整合

将前面讨论的所有块(不包括注意力块)整合到一个Unet中。
每个块都包含两个残差连接,而不是一个。
这个修改是为了解决潜在的过度拟合问题。

# 这个模块实现了一个双重残差 U 型网络,用于图像处理任务,如图像去噪、超分辨率等。
class TwoResUNet(nn.Module):def __init__(self,dim,init_dim=None,out_dim=None,dim_mults=(1, 2, 4, 8),channels=3,sinusoidal_pos_emb_theta=10000,convnext_block_groups=8,  # 卷积块中的分组数):super().__init__()self.channels = channelsinput_channels = channels# init_dim 不为 None,则返回 init_dim;否则返回 dim。这样做的目的是提供了一种灵活的方式,允许用户在初始化模型时选择是否指定初始维度,如果未指定,则使用输入的维度作为初始维度。self.init_dim = default(init_dim, dim)self.init_conv = nn.Conv2d(input_channels, self.init_dim, 7, padding=3)dims = [self.init_dim, *map(lambda m: dim * m, dim_mults)]'''使用 map 函数对 dim_mults 中的每个值 m 进行操作,将其乘以 dim。这样可以得到一个新的列表,其中的每个值都是 dim 与 dim_mults 中的相应值相乘得到的结果。* 运算符用于解包操作,将 map 函数生成的结果解包成单独的元素。'''in_out = list(zip(dims[:-1], dims[1:]))# 使用 zip 函数将 dims 中的相邻两个元素组合成一个元组。这样可以得到一个列表,其中每个元素都是一个包含相邻两个阶段的输入通道数和输出通道数的元组。sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)time_dim = dim * 4self.time_mlp = nn.Sequential(sinu_pos_emb,nn.Linear(dim, time_dim),nn.GELU(),nn.Linear(time_dim, time_dim),)self.downs = nn.ModuleList([])self.ups = nn.ModuleList([])num_resolutions = len(in_out)  # 计算了图像的分辨率数,存储在 num_resolutions 中# 下面的循环用于创建下采样部分(downs):for ind, (dim_in, dim_out) in enumerate(in_out):is_last = ind >= (num_resolutions - 1)self.downs.append(nn.ModuleList([ConvNextBlock(in_channels=dim_in,out_channels=dim_in,time_embedding_dim=time_dim,group=convnext_block_groups,),ConvNextBlock(in_channels=dim_in,out_channels=dim_in,time_embedding_dim=time_dim,group=convnext_block_groups,),DownSample(dim_in, dim_out)if not is_lastelse nn.Conv2d(dim_in, dim_out, 3, padding=1),]))# 创建了中间残差块 mid_block1 和 mid_block2:mid_dim = dims[-1]  # 通常会将最后一个阶段的输出通道数作为中间残差块的输入通道数self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)# 下面的循环用于创建上采样部分(ups):for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):is_last = ind == (len(in_out) - 1)is_first = ind == 0self.ups.append(nn.ModuleList([ConvNextBlock(in_channels=dim_out + dim_in,out_channels=dim_out,time_embedding_dim=time_dim,group=convnext_block_groups,),ConvNextBlock(in_channels=dim_out + dim_in,out_channels=dim_out,time_embedding_dim=time_dim,group=convnext_block_groups,),Upsample(dim_out, dim_in)if not is_lastelse nn.Conv2d(dim_out, dim_in, 3, padding=1)]))default_out_dim = channelsself.out_dim = default(out_dim, default_out_dim)# 创建了最终的残差块 final_res_block 和输出卷积层 final_conv:self.final_res_block = ConvNextBlock(dim * 2, dim, time_embedding_dim=time_dim)self.final_conv = nn.Conv2d(dim, self.out_dim, 1)def forward(self, x, time):b, _, h, w = x.shapex = self.init_conv(x)  # 对输入张量进行初始卷积操作r = x.clone()  # 克隆 xt = self.time_mlp(time)  # 使用时间多层感知器处理时间信息unet_stack = []  # 创建一个空列表 unet_stack,用于存放下采样阶段的特征。# 对每个下采样模块进行操作:先执行两个卷积块,然后执行下采样,并将特征存储在 unet_stack 中。for down1, down2, downsample in self.downs:x = down1(x, t)unet_stack.append(x)x = down2(x, t)unet_stack.append(x)x = downsample(x)# 中间残差块x = self.mid_block1(x, t)x = self.mid_block2(x, t)# 对每个上采样模块进行操作:从 unet_stack 中取出特征,与当前特征拼接后执行两个卷积块,然后执行上采样。for up1, up2, upsample in self.ups:x = torch.cat((x, unet_stack.pop()), dim=1)x = up1(x, t)x = torch.cat((x, unet_stack.pop()), dim=1)x = up2(x, t)x = upsample(x)# 将初始特征 r 与最终的特征拼接后,执行最终的残差块和输出卷积层,得到最终的输出。x = torch.cat((x, r), dim=1)x = self.final_res_block(x, t)return self.final_conv(x)

 Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831905.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IoTDB 入门教程 基础篇⑦——数据库管理工具 | DBeaver 连接 IoTDB

文章目录 一、前文二、下载iotdb-jdbc三、安装DBeaver3.1 DBeaver 下载3.2 DBeaver 安装 四、安装驱动五、连接数据库六、参考 一、前文 IoTDB入门教程——导读 二、下载iotdb-jdbc 下载地址org/apache/iotdb/iotdb-jdbc:https://maven.proxy.ustclug.org/maven2/o…

stamps做sbas-insar,时序沉降图怎么画?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

【人工智能Ⅱ】实验5:自然语言处理实践(情感分类)

实验5:自然语言处理实践(情感分类) 一:实验目的与要求 1:掌握RNN、LSTM、GRU的原理。 2:学习用RNN、LSTM、GRU网络建立训练模型,并对模型进行评估。 3:学习用RNN、LSTM、GRU网络做…

AST原理(反混淆)

一、AST原理 jscode var a "\u0068\u0065\u006c\u006c\u006f\u002c\u0041\u0053\u0054";在上述代码中,a 是一个变量,它被赋值为一个由 Unicode 转义序列组成的字符串。Unicode 转义序列在 JavaScript 中以 \u 开头,后跟四个十六进…

Python学习笔记------json

json简介 JSON是一种轻量级的数据交互格式。可以按照JSON指定的格式去组织和封装数据 JSON本质上是一个带有特定格式的字符串 主要功能:json就是一种在各个编程语言中流通的数据格式,负责不同编程语言中的数据传递和交互 为了让不同的语言能够相互通…

《LTC与铁三角∶从线索到回款-人民邮电》关于铁三角不错的论述

《LTC与铁三角∶从线索到回款-人民邮电》一书中,关于铁三角不错的论述,收藏之:客户责任人的角色定义及核心价值 AR 的核心价值定位主要体现在三个方面:客户关系、 客户满意度、竞争对手 “ 压制 ” 。 维护客户关系&#x…

百川2模型解读

简介 Baichuan 2是多语言大模型,目前开源了70亿和130亿参数规模的模型。在公开基准如MMLU、CMMLU、GSM8K和HumanEval上的评测,Baichuan 2达到或超过了其他同类开源模型,并在医学和法律等垂直领域表现优异。此外,官方还发布所有预…

[数据结构]————排序总结——插入排序(直接排序和希尔排序)—选择排序(选择排序和堆排序)-交换排序(冒泡排序和快速排序)—归并排序(归并排序)

文章涉及具体代码gitee: 登录 - Gitee.com 目录 1.插入排序 1.直接插入排序 总结 2.希尔排序 总结 2.选择排序 1.选择排序 ​编辑 总结 2.堆排序 总结 3.交换排序 1.冒泡排序 总结 2.快速排序 总结 4.归并排序 总结 5.总的分析总结 1.插入排…

Unity---版本控制软件

13.3 版本控制——Git-1_哔哩哔哩_bilibili Git用的比较多 Git 常用Linux命令 pwd:显示当前所在路径 ls:显示当前路径下的所有文件 tab键自动补全 cd:切换路径 mkdir:在当前路径下创建一个文件夹 clear:清屏 vim…

Linux的socket详解

一、本机直接的进程通信方式 管道(Pipes): 匿名管道(Anonymous pipes):通常用于父子进程间的通信,它是单向的。命名管道(Named pipes,也称FIFO):允…

微星主板安装双系统不能进入Ubuntu的解决办法

在微星主板的台式机上面依次安装了Windows11和Ubuntu22.04。在Ubuntu安装完成后重启,没有出现系统选择界面,直接进入了Windows11。怎么解决?方法如下: (1)正常安装Windows11 (2)安…

《自动机理论、语言和计算导论》阅读笔记:p352-P401

《自动机理论、语言和计算导论》学习第 12 天,p352-P401总结,总计 50 页。 一、技术总结 1.Turing Machine ™ 2.undecidability ​ a.Ld(the diagonalization language) 3.reduction p392, In general, if we have an algorithm to convert insta…

Git系列:config 配置

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

Java中的枚举类型介绍

一、背景及定义 情景: 枚举是在JDK1.5以后引入的。 主要用途是: 将一组常量组织起来,在这之前表示一组常量通常使用定义常量的方式: 这种定义方式实际上并不好。 例如:如果碰巧有另一个变量也是1,那么…

笔记85:如何计算递归算法的“时间复杂度”和空间复杂度?

先上公式: 递归算法的时间复杂度 递归次数 x 每次递归消耗的时间颗粒数递归算法的空间复杂度 递归深度 x 每次递归消耗的内存空间大小 注意: 时间复杂度指的是在执行这一段程序的时候,所花费的全部的时间,即时间的总和而空间复…

以太网基础-IP、ICMP、ARP协议

一、IP协议 参考:rfc791.txt.pdf (rfc-editor.org) IP协议(Internet Protocol)是TCP/IP协议族中最核心的协议,提供不可靠的、无连接的、尽力而为的数据报传输服务。 IP报文数据头如下 Version:4bit,4表示…

网络模型与调试

网络模型 网络的体系结构 ● 网络采用分而治之的方法设计,将网络的功能划分为不同的模块,以分层的形式有机组合在一起。 ● 每层实现不同的功能,其内部实现方法对外部其他层次来说是透明的。每层向上层提供服务,同时使用下层提供…

Elasticsearch:如何使用 Java 对索引进行 ES|QL 的查询

在我之前的文章 “Elasticsearch:对 Java 对象的 ES|QL 查询”,我详细介绍了如何使用 Java 来对 ES|QL 进行查询。对于不是很熟悉 Elasticsearch 的开发者来说,那篇文章里的例子还是不能单独来进行运行。在今天的这篇文章中,我来详…

MySQL CRUD进阶

前言👀~ 上一章我们介绍了CRUD的一些基础操作,关于如何在表里进行增加记录、查询记录、修改记录以及删除记录的一些基础操作,今天我们学习CRUD(增删改查)进阶操作 如果各位对文章的内容感兴趣的话,请点点小…

【网络编程下】五种网络IO模型

目录 前言 一.I/O基本概念 1.同步和异步 2.阻塞和非阻塞 二.五种网络I/O模型 1.阻塞I/O模型 2.非阻塞式I/O模型 ​编辑 3.多路复用 4.信号驱动式I/O模型 5. 异步I/O模型 三.五种I/O模型比较​编辑 六.I/O代码示例 1. 阻塞IO 2.非阻塞I/O 3.多路复用 (1)select …