AIGC笔记--Diffuser的训练pipeline

1--简单训练pipeline

import time
import numpy as np
import torch
from PIL import Image
import torchvision
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from matplotlib import pyplot as plt
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline# 数据增广
def transform(examples):preprocess = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}def process_dataset(batch_size):# 加载数据集# dataset = load_dataset("huggan/smithsonian_butterflies_subset", split = "train")dataset = load_dataset("/data-home/liujinfu/Diffuser/Data/smithsonian_butterflies_subset", split = "train")# 调用自定义的transform函数dataset.set_transform(transform)# 设置dataloadertrain_dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = True)return train_dataloaderdef train_loop(train_dataloader, noise_scheduler, model, num_epoches, device):# 优化器optimizer = torch.optim.AdamW(model.parameters(), lr = 4e-4)losses = []start_time = time.time() for epoch in range(num_epoches):for _, batch in enumerate(train_dataloader): # 遍历clean_images = batch["images"].to(device) # B C H W# sample noise to add to the imagesnoise = torch.randn(clean_images.shape).to(clean_images.device) # B C H Wbs = clean_images.shape[0] # 64# sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs, ), device = clean_images.device).long() # B# Add noise to the clean images according to the noise magnitude at each timestepnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) # 加噪# Get model predictionnoise_pred = model(noisy_images, timesteps, return_dict=False)[0]# Calculate the lossloss = F.mse_loss(noise_pred, noise) # 计算预测噪音和真实噪音之间的损失loss.backward(loss)losses.append(loss.item())# Update the model parameters with the optimizeroptimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")end_time = time.time()elapsed_time = end_time - start_time # 记录训练时间print("time cost: ", elapsed_time)return lossesdef vis(losses):# 可视化 lossfig, axs = plt.subplots(1, 2, figsize=(12, 4))axs[0].plot(losses)axs[1].plot(np.log(losses))return figdef generate(model, noise_scheduler):# 1. create a pipelineimage_pipe = DDPMPipeline(unet = model, scheduler = noise_scheduler)pipeline_output = image_pipe()return pipeline_output.images[0]# 可视化生成图像
def show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef make_grid(images, size=64):"""Given a list of PIL images, stack them together into a line for easy viewing"""output_im = Image.new("RGB", (size * len(images), size))for i, im in enumerate(images):output_im.paste(im.resize((size, size)), (i * size, 0))return output_imdef main():# 获取训练集image_size = 32 batch_size = 64train_dataloader = process_dataset(batch_size = batch_size)# 设置Schedulernoise_scheduler = DDPMScheduler(num_train_timesteps = 1000, beta_schedule = "squaredcos_cap_v2") # 创建Unet modeldevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = UNet2DModel(sample_size = image_size, # target image resolutionin_channels = 3,out_channels = 3,layers_per_block = 2, # how many resnet layers to use per Unet blockblock_out_channels = (64, 128,128, 256),down_block_types = ("DownBlock2D","DownBlock2D","AttnDownBlock2D","AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D", # a regular ResNet upsampling block),  ).to(device)# 开始训练losses = train_loop(train_dataloader = train_dataloader, noise_scheduler = noise_scheduler,model = model,num_epoches = 30, device = device)fig = vis(losses)fig.savefig("./loss.png")# 生成一张图片gen_img = generate(model, noise_scheduler)gen_img.save("./generate1.png")# 随机初始化噪音生成图片sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps): # 反向去噪# Get model predwith torch.no_grad():residual = model(sample, t).sample# Update sample with stepsample = noise_scheduler.step(residual, t, sample).prev_sample# 可视化生成的图片grid_im = show_images(sample)grid_im.save("./genearate2.png")print("All Done!")if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频监控平台:交通运输标准JTT808设备SDK接入源代码函数分享

目录 一、JT/T 808标准简介 (一)概述 (二)协议特点 1、通信方式 2、鉴权机制 3、消息分类 (三)协议主要内容 1、位置信息 2、报警信息 3、车辆控制 4、数据转发 二、代码和解释 (一…

《ESP8266通信指南》13-Lua 简单入门(打印数据)

往期 《ESP8266通信指南》12-Lua 固件烧录-CSDN博客 《ESP8266通信指南》11-Lua开发环境配置-CSDN博客 《ESP8266通信指南》10-MQTT通信(Arduino开发)-CSDN博客 《ESP8266通信指南》9-TCP通信(Arudino开发)-CSDN博客 《ESP82…

AJAX知识点(前后端交互技术)

原生AJAX AJAX全称为Asynchronous JavaScript And XML,就是异步的JS和XML,通过AJAX可以在浏览器中向服务器发送异步请求,最大的优势:无需刷新就可获取数据。 AJAX不是新的编程语言,而是一种将现有的标准组合在一起使用的新方式 …

C语言【文件操作 2】

文章目录 前言顺序读写函数的介绍fputc && fgetcfputcfgetc fputs && fgetsfputsfgets fprintf && fscanffprintffscanf fwrite && freadfwritefread 文件的随机读写fseek函数偏移量ftell函数rewind函数 文件的结束判断被错误使用的feof 结语 …

Linux与windows网络管理

文章目录 一、TCP/IP1.1、TCP/IP概念TCP/IP是什么TCP/IP的作用TCP/IP的特点TCP/IP的工作原理 1.2、TCP/IP网络发展史1.3、OSI网络模型1.4、TCP/IP网络模型1.5、linux中配置网络网络配置文件位置DNS配置文件主机名配置文件常用网络查看命令 1.6、windows中配置网络CMD中网络常用…

认识卷积神经网络

我们现在开始了解卷积神经网络,卷积神经网络是深度学习在计算机视觉领域的突破性成果,在计算机视觉领域,往往我们输入的图像都很大,使用全连接网络的话,计算的代价较高,图像也很难保留原有的特征&#xff0…

python 和 MATLAB 都能绘制的母亲节花束!!

hey 母亲节快到了,教大家用python和MATLAB两种语言绘制花束~这段代码是我七夕节发的,我对代码进行了简化,同时自己整了个python版本 MATLAB 版本代码 function roseBouquet_M() % author : slandarer% 生成花朵数据 [xr,tr]meshgrid((0:24).…

jQuery-1.语法、选择器、节点操作

jQuery jQueryJavaScriptQuery&#xff0c;是一个JavaScript函数库&#xff0c;为编写JavaScript提供了更高效便捷的接口。 jQuery安装 去官网下载jQuery&#xff0c;1.x版本练习就够用 jQuery引用 <script src"lib/jquery-1.11.2.min.js"></script>…

我的Transformer专栏来啦

五一节前吹的牛&#xff0c;五一期间没完成&#xff0c;今天忙里偷闲&#xff0c;给完成了。 那就是初步拟定了一个《Transformer最后一公里》的写作大纲。 之前一直想写一系列Transformer架构的算法解析文章&#xff0c;但因为一直在忙&#xff08;虽然不知道在忙啥&#xf…

倍思|西圣开放式耳机哪个好用?热门机型深度测评!

在数字化生活的浪潮中&#xff0c;耳机已成为我们不可或缺的伴侣。然而&#xff0c;长时间佩戴传统的耳机容易导致的耳道疼痛等问题&#xff0c;严重的话将影响听力。许多人开始寻找更为舒适的佩戴体验。开放式耳机因为不需要需直接插入耳道的设计&#xff0c;逐渐受到大众的青…

Apipost使用心得,让接口文档变得更清晰,更快捷

Idea和Apipost结合使用 Idea 安装插件Apipost-Helper-2.0 在【file】–>【settings】–>【Plugins】搜索 “Apipost-Helper-2.0”–>【install】&#xff0c;重启Idea 编写controller接口 在idea中编写业务功能及接口之后&#xff0c;在controller中鼠标【右键】单…

Linux下的SPI通信

SPI通信 一. 1.SPI简介: SPI 是一种高速,全双工,同步串行总线。 SPI 有主从俩种模式通常由一个主设备和一个或者多个从设备组从。SPI不支持多主机。 SPI通信至少需要四根线,分别是 MISO(主设备数据输入,从设备输出),MOSI (主设数据输出从设备输入),SCLK(时钟信号),CS/SS…

安卓开发--新建工程,新建虚拟手机,按键事件响应

安卓开发--新建工程&#xff0c;新建虚拟手机&#xff0c;按键事件响应 1.前言2.运行一个工程2.1布局一个Button2.2 button一般点击事件2.2 button属性点击事件2.2 button推荐点击事件 本篇博客介绍安卓开发的入门工程&#xff0c;通过使用按钮Buton来了解一个工程的运作机制。…

【SpringBoot记录】自动配置原理(1):依赖管理

前言 我们都知道SpringBoot能快速创建Spring应用&#xff0c;其核心优势就在于自动配置功能&#xff0c;它通过一系列的约定和内置的配置来减少开发者手动配置的工作。下面通过最简单的案例分析SpringBoot的功能特性&#xff0c;了解自动配置原理。 SpringBoot简单案例 根据S…

第 129 场 LeetCode 双周赛题解

A 构造相同颜色的正方形 枚举&#xff1a;枚举每个 3 3 3\times 3 33的矩阵&#xff0c;判断是否满足条件 class Solution {public:bool canMakeSquare(vector<vector<char>>& grid) {for (int i 0; i < 2; i)for (int j 0; j < 2; j) {int c1 0, c…

hypertherm海宝EDGE控制器显示屏工控机维修

海宝工控机维修V3.0/4.0/5.0&#xff1b;hypertherm数控切割机系统MICRO EDGE系统显示屏维修&#xff1b; 美国hypertherm公司mirco edge数控系统技术标准如下&#xff1a; 1&#xff09; p4处理器 2&#xff09; 512mb内存 3&#xff09; 80g硬盘&#xff0c;1.44m内置软驱…

IOS Xcode证书配置和ipa打包流程(附详细图文教程)

IOS Xcode证书配置和ipa打包流程&#xff08;附图文教程&#xff09; 前言ipa文件简介证书文件简介Provisioning Profile描述文件简介当前环境版本Xcode证书配置和ipa打包流程生成Apple Distribution Certificates证书创建描述文件&#xff08;Provisioning Profiles&#xff0…

Goland开发者软件激活使用教程

Goland开发者工具&#xff1a; Goland是由JetBrains公司推出的专门针对Go语言设计的集成开发环境&#xff08;IDE&#xff09;。这款工具具有智能的代码补全、强大的代码导航和重构功能&#xff0c;同时提供了丰富的调试工具&#xff0c;能够满足Golang开发者的各种需求。 Gol…

pwn(一)前置技能

以下是pwn中的题目&#xff08;漏洞&#xff09;类型&#xff1a; 关于pwn的学习&#xff1a; 一.什么是pwn&#xff1f;&#xff08;二进制的漏洞&#xff09; "Pwn"是一个俚语&#xff0c;起源于电子游戏社区&#xff0c;经常在英语中用作网络或电子游戏文化中的…