【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

关于

 

近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)

工具

 数据集

方法实现

加载必要的库函数和自定义函数

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Ffrom torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):"""save sample 100 images"""z = torch.randn(100, n_noise).to(DEVICE)y_hat = G(z).view(100, 28, 28) # (100, 28, 28)result = y_hat.cpu().data.numpy()img = np.zeros([280, 280])for j in range(10):img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)return img

定义判别模型

class Discriminator(nn.Module):"""Convolutional Discriminator for MNIST"""def __init__(self, in_channel=1, num_classes=1):super(Discriminator, self).__init__()self.conv = nn.Sequential(# 28 -> 14nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 14 -> 7nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),# 7 -> 4nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.AvgPool2d(4),)self.fc = nn.Sequential(# reshape input, 128 -> 1nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x, y=None):y_ = self.conv(x)y_ = y_.view(y_.size(0), -1)y_ = self.fc(y_)return y_

定义生成模型

class Generator(nn.Module):"""Convolutional Generator for MNIST"""def __init__(self, input_size=100, num_classes=784):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_size, 4*4*512),nn.ReLU(),)self.conv = nn.Sequential(# input: 4 by 4, output: 7 by 7nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),# input: 7 by 7, output: 14 by 14nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(),# input: 14 by 14, output: 28 by 28nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),nn.Tanh(),)def forward(self, x, y=None):x = x.view(x.size(0), -1)y_ = self.fc(x)y_ = y_.view(y_.size(0), 512, 4, 4)y_ = self.conv(y_)return y_

 模型超参数定义配置

batch_size = 64criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

 模型训练

for epoch in range(max_epoch):for idx, (images, labels) in enumerate(data_loader):# Training Discriminatorx = images.to(DEVICE)x_outputs = D(x)D_x_loss = criterion(x_outputs, D_labels)z = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))D_z_loss = criterion(z_outputs, D_fakes)D_loss = D_x_loss + D_z_lossD.zero_grad()D_loss.backward()D_opt.step()if step % n_critic == 0:# Training Generatorz = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))G_loss = criterion(z_outputs, D_labels)D.zero_grad()G.zero_grad()G_loss.backward()G_opt.step()if step % 500 == 0:print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))if step % 1000 == 0:G.eval()img = get_sample_image(G, n_noise)imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')G.train()step += 1

测试生成效果

# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

 

模型和状态参量保存

def save_checkpoint(state, file_name='checkpoint.pth.tar'):torch.save(state, file_name)# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')

应用

DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/777246.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

aws使用记录

数据传输(S3) 安装命令行 安装awscli: https://docs.aws.amazon.com/zh_cn/cli/latest/userguide/getting-started-install.html#getting-started-install-instructions 直到 aws configure list 可以运行 身份验证: 运行: aws config…

【QGIS从shp文件中筛选目标区域导出为shp】

文章目录 1、写在前面2、QGIS将shp文件中目标区域输出为shp2.1、手动点选2.2、高级过滤 3、上述shp完成后,配合python的shp文件,即可凸显研究区域了 1、写在前面 利用shp文件制作研究区域mask,Matlab版本,请点击 Matlab利用shp文…

网络编程综合项目-多用户通信系统

文章目录 1.项目所用技术栈本项目使用了java基础,面向对象,集合,泛型,IO流,多线程,Tcp字节流编程的技术 2.通信系统整体分析主要思路(自己理解)1.如果不用多线程2.使用多线程3.对多线…

uniapp-Form示例(uviewPlus)

示例说明 Vue版本&#xff1a;vue3 组件&#xff1a;uviewPlus&#xff08;Form 表单 | uview-plus 3.0 - 全面兼容nvue的uni-app生态框架 - uni-app UI框架&#xff09; 说明&#xff1a;表单组建、表单验证、提交验证等&#xff1b; 截图&#xff1a; 示例代码 <templat…

O2OA(翱途)开发平台-快速入门开发一个门户实例

O2OA(翱途)开发平台[下称O2OA开发平台或者O2OA]拥有门户页面定制与集成的能力&#xff0c;平台通过门户定制&#xff0c;可以根据企业的文化&#xff0c;业务需要设计符合企业需要的统一信息门户&#xff0c;系统首页等UI界面。本篇主要介绍通过门户管理系统如何快速的进行一个…

学点儿Java_Day12_IO流

1 IO介绍以及分类 IO: Input Output 流是一组有顺序的&#xff0c;有起点和终点的字节集合&#xff0c;是对数据传输的总称或抽象。即数据在两设备间的传输称为流&#xff0c;流的本质是数据传输&#xff0c;根据数据传输特性将流抽象为各种类&#xff0c;方便更直观的进行数据…

C++取经之路(其二)——含数重载,引用。

含数重载: 函数重载是指&#xff1a;在c中&#xff0c;在同一作用域&#xff0c;函数名相同&#xff0c;形参列表不相同(参数个数&#xff0c;或类型&#xff0c;或顺序)不同&#xff0c;C语言不支持。 举几个例子&#xff1a; 1.参数类型不同 int Add(int left, int right)…

【任职资格】某大型制造型企业任职资格体系项目纪实

该企业以业绩、责任、能力为导向&#xff0c;确定了分层分类的整体薪酬模式&#xff0c;但是每一名员工到底应该拿多少工资&#xff0c;同一个岗位的人员是否应该拿同样的工资是管理人员比较头疼的事情。华恒智信顾问认为&#xff0c;通过任职资格评价能实现真正的人岗匹配&…

基于Transformer的医学图像分类研究

医学图像分类目前面临的挑战 医学图像分类需要研究人员同时具备医学图像分析和数字图像的知识背景。由于图像尺度、数据格式和数据类别分布的影响&#xff0c;现有的模型方法&#xff0c;如传统的机器学习的识别方法和基于深度卷积神经网络的方法&#xff0c;取得的识别准确度…

微软AI 程序员AutoDev,自主执行工程任务生成代码

全球首个 AI 程序员 Devin 的横空出世&#xff0c;可能成为软件和 AI 发展史上一个重要的节点。它掌握了全栈的技能&#xff0c;不仅可以写代码 debug&#xff0c;训模型&#xff0c;还可以去美国最大求职网站 Upwork 上抢单。 Devin 诞生之后&#xff0c;让码农纷纷恐慌。最近…

智慧光伏:企业无纸化办公

随着科技的快速发展&#xff0c;光伏技术不仅成为推动绿色能源革命的重要力量&#xff0c;更在企业办公环境中扮演起引领无纸化办公的重要角色。智慧光伏不仅为企业提供了清洁、可持续的能源&#xff0c;更通过智能化的管理方式&#xff0c;推动企业向无纸化办公转型&#xff0…

滑动窗口_水果成篮_C++

题目&#xff1a; 题目解析&#xff1a; fruits[i]表示第i棵树&#xff0c;这个fruits[i]所表示的数字是果树的种类例如示例1中的[1,2,1]&#xff0c;表示第一棵树 的种类是 1&#xff0c;第二个树的种类是2 第三个树的种类是1随后每一个篮子只能装一种类型的水果&#xff0c;我…

SQL Server事务复制操作出现的错误 进程无法在“xxx”上执行sp_replcmds

SQL Server事务复制操作出现的错误 进程无法在“xxx”上执行“sp_replcmds” 无法作为数据库主体执行&#xff0c;因为主体 "dbo" 不存在、无法模拟这种类型的主体&#xff0c;或您没有所需的权限

术语技巧:如何格式化网页中的术语

术语是语言服务中的核心语言资产。快速处理英汉对照的术语是我们在翻译技术学习过程中需要掌握的必备技能。 通常&#xff0c;我们需要把在权威网站上收集到的术语放到word当中&#xff0c;调整正左右对齐的样式&#xff0c;便于打印学习或者转化为Excel表。 如何快速实现这一…

加密流量分类torch实践5:TrafficClassificationPandemonium项目更新3

加密流量分类torch实践5&#xff1a;TrafficClassificationPandemonium项目更新3 更新日志 代码已经推送开源至露露云的github&#xff0c;如果能帮助你&#xff0c;就给鼠鼠点一个star吧&#xff01;&#xff01;&#xff01; 我的CSDN博客 我的Github Page博客 3/23日更新…

iOS - Runtime-API

文章目录 iOS - Runtime-API1. Runtime应用1.1 字典转模型1.2 替换方法实现1.3 利用关联对象给分类添加属性1.4 利用消息转发机制&#xff0c;解决方法找不到的异常问题 2. Runtime-API2.1 Runtime API01 – 类2.1.1 动态创建一个类&#xff08;参数&#xff1a;父类&#xff0…

【Pt】马灯贴图绘制过程 02-制作锈迹

目录 一、边缘磨损效果 二、刮痕效果 三、边缘磨损与刮痕的混合 四、锈迹效果 本篇效果&#xff1a; 一、边缘磨损效果 将智能材质“Iron Forge Old” 拖入图层 打开“Iron Forge Old” 文件夹&#xff0c;选中“Sharpen”&#xff08;锐化&#xff09;&#xff0c;增大“…

2010-2021年银行网点及员工信息数据

2010-2021年银行网点及员工信息数据 1、时间&#xff1a;2010-2021年 2、来源&#xff1a;整理自csmar 3、指标&#xff1a;银行代码、股票代码、银行中文简称、统计截止日期、分行数量、机构网点数量、其中&#xff1a;境内网点数量、其中&#xff1a;境外网点数量、在职员…

Linux集群

目录 一、什么是集群&#xff1f; 二、 搭建(tomcatnginxkeepalived)集群 一、JDK安装 二、Tomcat安装 三、Nginx 3.1、什么是Nginx&#xff1f; 3.2、下载Nginx 3.3、安装 四、搭建NginxTomcat的实现集群 配置nginx.comf文件 五&#xff1a;Nginx搭建图片服务器 …

【Java程序设计】【C00392】基于(JavaWeb)Springboot的校园生活服务平台(有论文)

基于&#xff08;JavaWeb&#xff09;Springboot的校园生活服务平台&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;已经做了六年的毕业设计程序开发&#xff0c;开发过…