【PyTorch][chapter 20][李宏毅深度学习]【无监督学习][ GAN]【实战】

前言

 本篇主要是结合手写数字例子,结合PyTorch 介绍一下Gan 实战

第一轮训练效果

第20轮训练效果,已经可以生成数字了

68 轮


目录: 

  1.   谷歌云服务器(Google Colab)
  2.   整体训练流程
  3.   Python 代码

一  谷歌云服务器(Google Colab)

     个人用的一直是联想小新笔记本,虽然非常稳定方便。但是现在跑深度学习,性能确实有点跟不上. 

   1.1    打开谷歌云服务器(Google Colab)

      https://colab.research.google.com/

    1. 2  新建笔记

                 

1

 1.4  选择T4GPU 

1.5  点击运行按钮

可以看到当前硬件的情况

     


二  整体训练流程


三    PyTorch 例子

# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 13:27:49 2024@author: chengxf2
"""
import torch.optim as optim #优化器
import numpy as np 
import matplotlib.pyplot  as plt
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn#第一步加载手写数字集
def loadData():#同时归一化数据集(-1,1)style = transforms.Compose([transforms.ToTensor(),   #0-1 归一化0-1, channel,height,widthtransforms.Normalize(mean=0.5, std=0.5) #变成了-1,1 ])trainData = torchvision.datasets.MNIST('data',train=True,transform=style,download=True)dataloader = torch.utils.data.DataLoader(trainData,batch_size= 16,shuffle=True)imgs,_ = next(iter(dataloader))#torch.Size([64, 1, 28, 28])print("\n imgs shape ",imgs.shape)return dataloaderclass Generator(nn.Module):'''定义生成器输入:z 随机噪声[batch, input_size]输出:x: 图片 [batch, height, width, channel]'''def __init__(self,input_size):super(Generator,self).__init__()self.net = nn.Sequential(nn.Linear(in_features = input_size , out_features =256),nn.ReLU(),nn.Linear(in_features = 256 , out_features =512),nn.ReLU(),nn.Linear(in_features = 512 , out_features =28*28),nn.Tanh())def forward(self, z):# z 随机输入[batch, dim]x = self.net(z)#[batch, height, width, channel]#print(x.shape)x = x.view(-1,28,28,1)return xclass Discriminator(nn.Module):'''定义鉴别器输入:x: 图片 [batch, height, width, channel]输出:y:  二分类图片的概率: BCELoss 计算交叉熵损失'''def __init__(self):super(Discriminator,self).__init__()#开始的维度和终止的维度,默认值分别是1和-1self.flatten = nn.Flatten()self.net = nn.Sequential(nn.Linear(in_features = 28*28 , out_features =512),nn.LeakyReLU(), #负值的时候保留梯度信息nn.Linear(in_features = 512 , out_features =256),nn.LeakyReLU(),nn.Linear(in_features = 256 , out_features =1),nn.Sigmoid())def forward(self, x):x = self.flatten(x)#print(x.shape)out =self.net(x)return outdef gen_img_plot(model, epoch, test_input):out = model(test_input).detach().cpu()out = out.numpy()imgs = np.squeeze(out)fig = plt.figure(figsize=(4,4))for i in range(out.shape[0]):plt.subplot(4,4,i+1)img = (imgs[i]+1)/2.0#[-1,1]plt.imshow(img)plt.axis('off')plt.show()def train():#1 初始化参数device ='cuda' if torch.cuda.is_available() else 'cpu'#2 加载训练数据dataloader = loadData()test_input  = torch.randn(16,100,device=device)#3 超参数maxIter = 20 #最大训练次数input_size = 100batchNum = 16input_size =100#4 初始化模型gen = Generator(100).to(device)dis = Discriminator().to(device)#5 优化器,损失函数d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)loss_fn = torch.nn.BCELoss()#6 loss 变化列表D_loss =[]G_loss= []for epoch in range(0,maxIter):d_epoch_loss = 0.0g_epoch_loss  =0.0#count = len(dataloader)for step ,(realImgs, _) in enumerate(dataloader):realImgs = realImgs.to(device)random_noise = torch.randn(batchNum, input_size).to(device)#先训练判别器d_optim.zero_grad()real_output = dis(realImgs)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))d_real_loss.backward()#不要训练生成器,所以要生成器detachfake_img = gen(random_noise)fake_output = dis(fake_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossd_optim.step()#优化生成器g_optim.zero_grad()fake_output = dis(fake_img.detach())g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss+= d_lossg_epoch_loss+= g_losscount = 16       with torch.no_grad():d_epoch_loss/=countg_epoch_loss/=countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)gen_img_plot(gen, epoch, test_input)print("Epoch: ",epoch)print("-----finised-----")if __name__ == "__main__":train()

参考:

10.完整课程简介_哔哩哔哩_bilibili

理论【PyTorch][chapter 19][李宏毅深度学习]【无监督学习][ GAN]【理论】-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/711894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux学习-字符串数组和字符串

目录 使用场景 字符型数组定义: 初始化 数组储存 打印 字符型数组常见函数 常见操作 strcpy:字符串拷贝 strcat(str1,str2)字符串拼接 strcmp:字符串比较 注意: 二维字符型数…

Open CASCADE学习|曲线曲面连续性

1、曲线的连续性 曲线的连续性是三维建模、动画设计等领域中非常重要的一个概念,它涉及到曲线在不同点之间的连接方式和光滑程度。下面将详细介绍曲线的连续性,包括C连续性和G连续性。 1.1C连续性(参数连续性) C连续性是指曲线…

使用MyBatisPlus实现向数据库中存储List类型的数据

使用MyBatisPlus实现向数据库中存储List类型的数据 问题描述 建表时,表中的这五个字段为json类型 但是在入库的时候既不能写入数据,也不能查询出数据。 解决方案: 1.首先明确,数据存入的时候是经过了数据类型转化&#xff0c…

中国电子学会2020年06月真题C语言软件编程等级考试三级(含详细解析答案)

中国电子学会考评中心历届真题(含解析答案) C语言软件编程等级考试三级 2020年06月 编程题五道 总分:100分一、最接近的分数(20分) 分母不超过N且小于A/B的最大最简分数是多少? 时间限制: 1000ms 内存限制: 65536kb 输入…

数据之光:探索数据库技术的演进之路

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua,在这里我会分享我的知识和经验。&#x…

喜讯!持安科技CEO何艺获评安全419《2023年度十大优秀创业者》

近日,由网络安全产业资讯媒体安全419主办的《年度策划》2023年度十大优秀创业者正式出炉,零信任办公安全技术创新企业持安科技创始人兼CEO何艺,获评十大优秀创业者。 这是安全419第二届推出该项目的评选活动,安全419编辑老师在多年…

抽象类、模板方法模式

抽象类概述 在Java中abstract是抽象的意思,如果一个类中的某个方法的具体实现不能确定,就可以申明成abstract修饰的抽象方法(不能写方法体了),这个类必须用abstract修饰,被称为抽象类。 抽象方法定义&…

【解决】修改 UI界面渲染层级 的常见误区

开发平台:Unity 2021版本   问题描述 Unity 中管理 UI 上显示元素的前后层级关系大致为以下两种方式: 方式一:修改UI元素队列顺序与层级方式二:使用 Canvas 组件中的 Override Sort 属性配置 方式二 对应复杂的 UI 层级关系将常…

这些单片机汇编语言的错误,你还在犯错吗?

在单片机开发中,很多工程师会选择汇编语言来作为底层编程,来直接控制硬件和高校执行命令,然而因为汇编语言是直接与硬件交互,所以很容易出现错误,本文将基于Keil C51汇编器的环境总结单片机汇编语言常见的错误&#xf…

人工智能_大模型010_Centos7.9中CPU安装ChatGLM3-6B大模型_安装使用_010---人工智能工作笔记0145

从一个空的虚拟机开始安装: https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files 可以看到这里有很多的数据文件,那么这里 这里点击模型文件就可以下载,这个就是chatglm3-6B的文件,需要点击每个文件,然后点击右边的下载,把文件都下载下来 右侧有下载按钮.点击下载可…

使用Fabric创建的canvas画布背景图片,自适应画布宽高

之前的文章写过vue2使用fabric实现简单画图demo,完成批阅功能;但是功能不完善,对于很大的图片就只能显示一部分出来,不符合我们的需求。这就需要改进,对我们设置的背景图进行自适应。 有问题的canvas画布背景 修改后的…

Unity2023.1.19_ECS

Unity2023.1.19_ECS 在学习的路上一往无前的遇到了好东西,官方的EntityComponnentSystemSamples的Repository,这是一个包含实体,图形,网络,物理案例的全方位案例教程。 又找见接下来要干的事情了!学习永无…

【rust】11、所有权

文章目录 一、背景二、Stack 和 Heap2.1 Stack2.2 Heap2.3 性能区别2.4 所有权和堆栈 三、所有权原则3.1 变量作用域3.2 String 类型示例 四、变量绑定背后的数据交互4.1 所有权转移4.1.1 基本类型: 拷贝, 不转移所有权4.1.2 分配在 Heap 的类型: 转移所有权 4.2 Clone(深拷贝)…

Quartz 任务调度框架源码阅读解析

概念: quartz 是一个基于JAVA的定时任务调度框架 案例: <dependency><groupId>org.quartz-scheduler</groupId><artifactId>quartz</artifactId><version>2.3.0</version></dependency>JobDetail job JobBuilder.newJob(Sc…

每日一练 | 华为认证真题练习Day191

1、在没有启用BGP路径负载分担的情况下&#xff0c;哪种BGP路由会发送BGP邻居? A. 从所有邻居学到的所有BGP路由。 B. 只有从IBGP学到的路由。 C. 只有从EBGP学到的路由。 D. 只有被BGP优选的最佳路由。 2、第三类LSA的LINK ID是 A. 生成这条LSA的路由器的ROUTER ID B. …

LeetCode 刷题 [C++] 第236题.二叉树的最近公共祖先

题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以…

大数据分析案例-基于SVM支持向量机算法构建手机价格分类预测模型

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

矩阵爆破逆向之条件断点的妙用

不知道你是否使用过IDA的条件断点呢&#xff1f;在IDA进阶使用中&#xff0c;它的很多功能都有大作用&#xff0c;比如&#xff1a;ida-trace来跟踪调用流程。同时IDA的断点功能也十分强大&#xff0c;配合IDA-python的输出语句能够大杀特杀&#xff01; 那么本文就介绍一下这…

【JAVA】JDK内置工具之appletviewer

下载java 下载java的时候会先下载Java jdk&#xff0c;Java Development Kit Java开发工具包。 然后会下载jre&#xff0c;也就是Java Runtime Environment Java运行环境。什么是JDK、JRE&#xff1f;_java中的jdk,jre代表什么-CSDN博客 下载之后先找到java下的bin文件&#x…

yolov9 tensorRT 的 C++ 部署

yolov9 tensorRT C 部署 本示例中&#xff0c;包含完整的代码、模型、测试图片、测试结果。 完整的代码、模型、测试图片、测试结果【github参考链接】 TensorRT版本&#xff1a;TensorRT-7.1.3.4 导出onnx模型 导出适配本实例的onnx模型参考【yolov9 瑞芯微芯片rknn部署、地平…