【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)

  上一篇Diffusion实战是确确实实一步一步走的公式,这回采用一个更方便的库:diffusers,来实现Diffusion模型训练。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


0、所需安装

pip install diffusers  # diffusers库
pip install datasets  

1、数据集下载

  下载地址:蝴蝶数据集
  下载好后的文件夹中包括以下文件,放在当前目录下就可以了。

在这里插入图片描述
加载数据集,并对一批数据进行可视化:

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Imagedef show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')image_size = 32
batch_size = 64# 数据增强
preprocess = transforms.Compose([transforms.Resize((image_size, image_size)),  # Resizetransforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)transforms.ToTensor(),  # Convert to tensor (0, 1)transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)]
)dataset.set_transform(transform)# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出可视化结果:

在这里插入图片描述


2、加噪调度器

  即DDPM论文中需要预定义的 β t {\beta_t } βt ,可使用DDPMScheduler类来定义,其中num_train_timesteps参数为时间步 t {t} t

from diffusers import DDPMScheduler# βt值
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

根据定义的 β t {\beta_t } βt ,可视化 α ˉ t {\sqrt {{{\bar \alpha }_t}}} αˉt 1 − α ˉ t {\sqrt {1 - {{\bar \alpha }_t}}} 1αˉt

在这里插入图片描述

  通过设置beta_start、beta_end和beta_schedule三个参数来控制噪声调度器的超参数 β t {\beta_t } βt

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)

在这里插入图片描述

  beta_schedule可以通过一个函数映射来为模型推理的每一步生成一个 β t {\beta_t } βt值。

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

在这里插入图片描述

x t = α ˉ t x 0 + 1 − α ˉ t ε {{x_t} = \sqrt {{{\bar \alpha }_t}} {x_0} + \sqrt {1 - {{\bar \alpha }_t}} \varepsilon } xt=αˉt x0+1αˉt ε 加噪前向过程可视化:

timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出为:

在这里插入图片描述


3、扩散模型定义

  diffusers库中模型的定义也非常简洁:

# 创建模型
from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(64, 128, 128, 256),  # More channels -> more parametersdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D",  # a regular ResNet upsampling block),
)model.to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同

4、扩散模型训练

  定义优化器,和传统模型一样的训练写法:

# 定义噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)losses = []for epoch in range(30):for step, batch in enumerate(train_dataloader):clean_images = batch["images"].to(device)# 为图像添加随机噪声noise = torch.randn(clean_images.shape).to(clean_images.device)  # epsbs = clean_images.shape[0]# 为每一张图像随机选择一个时间步timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()  # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * epsnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)# 获得模型预测结果noise_pred = model(noisy_images, timesteps, return_dict=False)[0]# 计算损失, 损失回传loss = F.mse_loss(noise_pred, noise)  loss.backward(loss)losses.append(loss.item())# 更新模型参数optimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

30个epoch训练过程如下所示:

在这里插入图片描述

可用以下代码查看损失曲线:

# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()

损失曲线可视化:

在这里插入图片描述


5、图像生成

  (1)通过建立pipeline生成图像:

# 图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)pipeline_output = image_pipe()
plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()# 保存pipeline
image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

生成的蝴蝶图像如下:

在这里插入图片描述

生成的my_pipeline文件夹如下:

在这里插入图片描述

  (2)通过随机采样循环生成图像:

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps):# 获得模型预测结果with torch.no_grad():residual = model(sample, t).sample# 根据预测结果更新图像sample = noise_scheduler.step(residual, t, sample).prev_sampleshow_images(sample)

8张生成图像如下:
在这里插入图片描述


6、代码汇总

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Imagedef show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}# --------------------------------------------------------------------------------
# 1、数据集加载与可视化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')image_size = 32
batch_size = 64# 数据增强
preprocess = transforms.Compose([transforms.Resize((image_size, image_size)),  # Resizetransforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)transforms.ToTensor(),  # Convert to tensor (0, 1)transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)]
)dataset.set_transform(transform)# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 2、噪声调度器
from diffusers import DDPMScheduler# 加噪声的系数βt
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 加噪声可视化
timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 3、创建模型
from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(64, 128, 128, 256),  # More channels -> more parametersdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D",  # a regular ResNet upsampling block),
)model.to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 4、扩散模型训练
# 定义噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)losses = []for epoch in range(30):for step, batch in enumerate(train_dataloader):clean_images = batch["images"].to(device)# 为图像添加随机噪声noise = torch.randn(clean_images.shape).to(clean_images.device)  # epsbs = clean_images.shape[0]# 为每一张图像随机选择一个时间步timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()  # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * epsnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)# 获得模型预测结果noise_pred = model(noisy_images, timesteps, return_dict=False)[0]# 计算损失, 损失回传loss = F.mse_loss(noise_pred, noise)  loss.backward(loss)losses.append(loss.item())# 更新模型参数optimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()
# --------------------------------------------------------------------------------# --------------------------------------------------------------------------------
# 5、图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)pipeline_output = image_pipe()plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps):# 获得模型预测结果with torch.no_grad():residual = model(sample, t).sample# 根据预测结果更新图像sample = noise_scheduler.step(residual, t, sample).prev_sampleshow_images(sample)grid_im = show_images(sample).resize((8 * 64, 64), resample=Image.NEAREST)
plt.figure(dpi=300)
plt.imshow(grid_im)
plt.axis('off')
plt.show()
# --------------------------------------------------------------------------------

  参考资料:扩散模型从原理到实践. 人民邮电出版社. 李忻玮, 苏步升等.

  diffusers确实很方便使用,有点子PyCaret的感觉了~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/829380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web题目实操 5(备份文件和关于MD5($pass,true)注入的学习)

1.[ACTF2020 新生赛]BackupFile (1)打开页面后根据提示是备份文件 (2)查看源码发现啥都没有 (3)这里啊直接用工具扫描,可以扫描到一个文件名为:/index.php.bak的文件 (…

ArcGIS Pro 和 Python — 分析全球主要城市中心的土地覆盖变化

第一步——设置工作环境 1–0. 地理数据库 在下载任何数据之前,我将创建几个地理数据库,在其中保存和存储所有数据以及我将创建的后续图层。将为我要分析的五个城市中的每一个创建一个地理数据库,并将其命名为: “Phoenix.gdb” “Singapore.gdb” “Berlin.gdb” “B…

安卓悬浮窗权限检查

目录 悬浮窗权限代码检测悬浮窗功能 悬浮窗权限 请求了这个权限后&#xff0c;app的权限管理中会有「显示悬浮窗」的权限选项。后面会引导用户去开启这个权限。 <uses-permission android:name"android.permission.SYSTEM_ALERT_WINDOW" />代码检测悬浮窗功能…

Windows系统下将MySQL数据库表内的数据全量导入Elasticsearch

目录 下载安装Logstash 配置Logstash配置文件 运行配置文件 查看导入结果 使用Logstash将sql数据导入Elasticsearch 下载安装Logstash 官网地址 选择Windows系统&#xff0c;需下载与安装的Elasticsearch相同版本的&#xff0c;下载完成后解压安装包。 配置Logstash配…

贪吃蛇大作战【纯c语言】

如果有看到不懂的地方或者对c语言某些知识忘了的话&#xff0c;可以找我之前的文章哦&#xff01;&#xff01;&#xff01; 个人主页&#xff1a;小八哥向前冲~-CSDN博客 所属专栏&#xff1a;c语言_小八哥向前冲~的博客-CSDN博客 贪吃蛇游戏演示&#xff1a; 贪吃蛇游戏动画演…

第一阶段--Day2--信息安全法律法规、网络安全相关标准

目录 1. 针对信息安全的规定 2. 网络安全相关标准 1. 针对信息安全的规定 《中华人民共和国计算机信息系统安全保护条例》1994年2月18日颁布并实施 中华人民共和国计算机信息系统安全保护条例__增刊20111国务院公报_中国政府网 《中华人民共和国国际联网安全保护管理…

笔记:编写程序,分别采用面向对象和 pyplot 快捷函数的方式绘制正弦曲线 和余弦曲线。 提示:使用 sin()或 cos()函数生成正弦值或余弦值。

文章目录 前言一、面向对象和 pyplot 快捷函数的方式是什么&#xff1f;二、编写代码面向对象的方法&#xff1a;使用 pyplot 快捷函数的方法&#xff1a; 总结 前言 本文将探讨如何使用编程语言编写程序&#xff0c;通过两种不同的方法绘制正弦曲线和余弦曲线。我们将分别采用…

图像处理ASIC设计方法 笔记18 轮廓跟踪算法的硬件加速方案

目录 1排除伪孤立点(断裂链表)方法1 限制链表的长度方法2 增加判断条件排除断裂链表方法3 排除不必要跟踪的轮廓(推荐用这个方法)P129 轮廓跟踪算法的硬件加速方案 1排除伪孤立点(断裂链表) 如果图像中某区域存在相邻像素之间仅有对角连接的部位,则对包围该区域的像素…

SOLIDWORKS Electrical 3D--精准的三维布线

相信很多工程师在实际生产的时候都会遇到线材长度不准确的问题&#xff0c;从而导致线材浪费甚至整根线材报废的问题&#xff0c;这基本都是由于人工测量长度所导致的&#xff0c;因此本次和大家简单介绍一下SOLIDWORKS Electrical 3D布线的功能&#xff0c;Electrical 3D布线能…

伙伴匹配(后端)-- 用户登录

文章目录 登录逻辑设计登录业务代码实现用户登录态如何知道是哪个用户登录了&#xff1f;cookie与session 逻辑删除配置添加TableLogic注解 &#xff08;现在做单机登录&#xff09; 后面修改为redis单点登录 登录逻辑设计 接收参数&#xff1a;用户接账户&#xff0c;密码 请…

【数据标注】使用LabelImg标注YOLO格式的数据(案例演示)

文章目录 LabelImg介绍LabelImg安装LabelImg界面标注常用的快捷键标注前的一些设置案例演示检查YOLO标签中的标注信息是否正确参考文章 LabelImg介绍 LabelImg是目标检测数据标注工具&#xff0c;可以标注两种格式&#xff1a; VOC标签格式&#xff0c;标注的标签存储在xml文…

目标检测——蔬菜杂草数据集

引用 亲爱的读者们&#xff0c;您是否在寻找某个特定的数据集&#xff0c;用于研究或项目实践&#xff1f;欢迎您在评论区留言&#xff0c;或者通过公众号私信告诉我&#xff0c;您想要的数据集的类型主题。小编会竭尽全力为您寻找&#xff0c;并在找到后第一时间与您分享。 …

架构师系列- 消息中间件(12)-kafka基础

1、应用场景 1.1 kafka场景 Kafka最初是由LinkedIn公司采用Scala语言开发&#xff0c;基于ZooKeeper&#xff0c;现在已经捐献给了Apache基金会。目前Kafka已经定位为一个分布式流式处理平台&#xff0c;它以 高吞吐、可持久化、可水平扩展、支持流处理等多种特性而被广泛应用…

22年全国职业技能大赛——Web Proxy配置(web 代理)

前言&#xff1a;原文在我的博客网站中&#xff0c;持续更新数通、系统方面的知识&#xff0c;欢迎来访&#xff01; 系统服务&#xff08;22年国赛&#xff09;—— web Proxy服务&#xff08;web代理&#xff09;https://myweb.myskillstree.cn/114.html 目录 RouterSrv …

强复购、循环消费:排队复购模式助您在市场中脱颖而出

尊敬的各位读者&#xff0c;今天我很高兴向大家介绍一种新颖而又引人入胜的商业模式——排队复购模式。这个模式因其强大的复购属性和循环消费特性而备受瞩目&#xff0c;被誉为电商领域的新宠儿。 为何要介绍排队复购模式&#xff1f;因为它不仅操作简单、容易引起消费者的兴…

BUUCTF_[BSidesCF 2020]Had a bad day

[BSidesCF 2020]Had a bad day 1.一看题目直接尝试文件包含 2.直接报错&#xff0c;确实是存在文件包含漏洞 http://307b4461-36d6-443f-879a-68803a57f721.node5.buuoj.cn:81/index.php?categoryphp://filter/convert.base64-encode/resourceindex strpos() 函数查找字符串…

安卓玩机工具推荐----MTK芯片 简单制作线刷包 备份分区 备份基带 去除锁类 推荐工具操作解析

工具说明 在前面几期mtk芯片类玩机工具中解析过如何无官方固件从手机抽包 制作线刷包的步骤&#xff0c;类似的工具与操作有很多种。演示的只是本人片面的理解与一些步骤解析。mtk芯片机型抽包关键点在于..mt*****txt的分区地址段引导和 perloader临时分区引导。前面几期都是需…

【嵌入式Linux】STM32P1开发环境搭建

要进行嵌入式Linux开发&#xff0c;需要在Windows、Linux和嵌入式Linux3个系统之间来回跑&#xff0c;需要使用多个软件工具。经过了4小时的安装&#xff08;包括下载时间&#xff09;&#xff0c;我怕以后会忘记&#xff0c;本着互利互助的原则&#xff0c;我打算把这些步骤详…

java接口加密解密

这里写目录标题 controller加解密工具类加密&#xff08;本质是对ResponseBody加密&#xff09;解密&#xff08;本质是对RequestBody传参解密&#xff09;注解 controller Controller public class PathVariableController {GetMapping(value "/test")ResponseBod…

IDEA pom.xml依赖警告

IDEA中&#xff0c;有时 pom.xml 中会出现如下提示&#xff1a; IDEA 2022.1 升级了检测易受攻击的 Maven 和 Gradle 依赖项&#xff0c;并建议修正&#xff0c;通过插件 Package Checker 捆绑到 IDE 中。 这并不是引用错误&#xff0c;不用担心。如果实在强迫症不想看到这个提…