AI绘画训练一个扩散模型-上集

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。
对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。

  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。

  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

# 创建UNet和扩散模型from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torchmodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # number of steps
).cuda()# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:sampled_images = sampled_images.cpu()# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。

  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。

  • 创建一个目录来保存图片。

  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。

  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。

  • 使用tqdm显示循环进度。

# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdmdataset = load_dataset("nelorth/oxford-flowers")# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():for index, item in enumerate(tqdm(dataset[split])):image = item['image']image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。

  • 加载数据,指定训练参数:

    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。

# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainermodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # 加噪总步数
).cuda()trainer = Trainer(diffusion,'./oxford-datasets/raw-images',train_batch_size = 16,train_lr = 2e-5,train_num_steps = 20000,          # 总共训练20000步gradient_accumulate_every = 2,    # 梯度累积步数ema_decay = 0.995,                # 指数滑动平均decay参数amp = True,                       # 使用混合精度训练加速calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。save_and_sample_every = 2000      # 每隔2000步保存一次模型
)trainer.train()
# 你可以等待上面的模型训练完成后,查看生成结果from glob import globresult_images = glob(r"./results/*.png")
print(result_images)
# 可视化图像看看
from PIL import Imageimg = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/420912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

顺序串的实现

顺序串接口查找增加...测试接口 package com.lovely.string;/*** * author echo lovely* 2020年6月9日下午6:44:31** 串的接口描述*/ public interface IString {public void clear();public boolean isEmpty(); public int length(); public char charAt(int i) throws Excep…

WPF多线程UI更新——两种方法

WPF多线程UI更新——两种方法 前言 在WPF中,在使用多线程在后台进行计算限制的异步操作的时候,如果在后台线程中对UI进行了修改,则会出现一个错误:(调用线程无法访问此对象,因为另一个线程拥有该对象。&…

jdbc万能dao

jdbc万能dao一,为何封装万能dao二,代码实现一,为何封装万能dao 不用框架,纯jdbc连接数据库,会用到dao包,如果每个表都要写增删改查,一个dao至少四个方法,dao会有大量代码重复&#…

5月27日

其实前天我想说我有点理解我爸了 当年到福建的时候跟现在差不多吧 气候 方言 吃的 住的 跟自己原来习惯的完全是不同 恐怕人人都会问这到底是为了什么 能为了什么呢? 路走到这儿了 有的是自己选的 有的不是 但已经走到这里了 当时晚上到福建家里的 第二天醒来看 屋子…

二叉树的递归遍历

二叉树遍历一,什么是二叉树二,递归实现2.1 结点类描述2.2 三种递归2.2 测试一,什么是二叉树 在计算机科学中,二叉树是每个结点最多有两个子树的树结构。通常子树被称作"左子树"(left subtree)和&…

概率论的公理结构

样本点 一个随机事件出现的可能的结果叫做样本点。 类比平面几何,线、面、体也是由点组成的集合,研究的是点线面关系及性质,同样样本点也是组成事件(集合)的材料,是集合的基本元素,把这些样本…

python词云的简单使用

词云的生成所需库代码实现wordclod参数说明具体实现效果展示所需库 wordcloud, jieba, imageiowordcloud 词云库,用来统计文本文档里面出现的高频词汇,或者句子,以图片可视化的方式显示出来jieba库,分割中文的库,把较…

(一)Neo4j在Centos7虚拟机上的安装

1、什么是图数据库? 图数据库是基于数学里图论的思想和算法而实现的高效处理复杂关系网络的新型数据库系统。图形数据库善于高效处理大量的、复杂的、互连的、多变的数据。其计算效率远远高于传统的关系型数据库。图形数据库在社交网络、实时推荐、征信系统、人工智…

代码演示 .NET 4.5 自带的 ReadonlyCollection 的使用

代码如下: 1. using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace ConfigurationLibrary {public class ConfigElement{public int Id { get; set; }public string Value { get; set…

(二)Cypher语言常用方法举例

1、概述 “Cypher”是一个描述性的类Sql的图操作语言。相当于关系数据库的Sql,可见其重要性!其语法针对图的特点而设计,非常方便和灵活。没有Join,是一大特点!学好Cypher是学好Neo4j的关键,也是核心所在&a…

Java高新技术 枚举

Java高新技术 枚举 知识概要: (1)为什么要有枚举 (2)枚举的示例 (3)枚举的应用 (4)枚举的高级应用 (1)为什么需要枚举 问题:要定义星期几或性别的变量,该怎么定义? 假设用1-7分别表示星期一到星期日&am…

github 人像卡通化探索项目

把项目下载到本地 下载地址 https://github.com/minivision-ai/photo2cartoon安装依赖库 python 3.7 # 3.x版本都可 pytorch 1.4 tensorflow-gpu 1.14 # tesorflow 得是1.0版本,2.0版本语法部分改变,不然项目运行会出错 face-alignment dlibpytorch …

浅谈城市规划在移动GIS方面的应用发展

1、概述 城市建设进程加快,城市规划管理工作日趋繁重,各种来源的数据产生各种层出不穷的问题,严重影响城市规划时的准确性,为此全面合理的掌握好各方面的城市规划资料才能做出更加科学的决策。移动端的兴起为规划动态方面提供了极…

(四)Neo4j删除数据需要注意的问题

1、先删关系,再删节点 # 删除所有记录 MATCH (n) OPTIONAL MATCH (n)-[r]-() DELETE n,r 2、彻底删除节点标签名,需要删除前期对该标签名建立的索引 # 查看全部索引 :schema# 删除索引 drop index on :Person(id)# 当索引删除不掉时,可能是…

jsp阶段总结

目录web开发jsp是运行在服务器端还是客户端? 服务端 js是运行在服务器端还是客户端? 客户端 jsp的本质是什么? jsp原理 jsp的本质就是servlet jsp在服务器中,当浏览器请求该jsp时,jsp文件在服务器中会经历什么过程? 转译:将jsp文件转译成java文件 编译:将转译后的java文…

CVE-2013-3897漏洞成因与利用分析

CVE-2013-3897漏洞成因与利用分析 1. 简介 此漏洞是UAF(Use After Free)类漏洞,即引用了已经释放的内存。攻击者可以利用此类漏洞实现远程代码执行。UAF漏洞的根源源于对对象引用计数的处理不当,比如在编写程序时忘记AddRef或者多…

(三)Neo4j自带northwind案例--Cypher语言应用

0、概述 通过该案例,应用Cypher查询语言,感受Neo4j套路。官方的用此案例的用意: The Northwind Graph demonstrates how to migrate(迁移) from a relational database to Neo4j(把一个负责的多表关系数据…

servlet 源码分析

servlet源码分析1. servlet接口1.1 看servlet源码1.2 直接用类实现servlet接口,来写servlet类2. servlet子类GenericServlet2.1 servlet子类实现GenericServlet抽象类2.2 继承GenericServelt抽象类3. httpServelt类分析4. 这么多搬来的代码,最后总结1. s…

RDIFramework.NET 中多表关联查询分页实例

RDIFramework.NET 中多表关联查询分页实例 RDIFramework.NET 中多表关联查询分页实例 RDIFramework.NET,基于.NET的快速信息化系统开发、整合框架,给用户和开发者最佳的.Net框架部署方案。该框架以SOA范式作为指导思想,作为异质系统整合与互操…