生成对抗网络(Generative Adversarial Networks ,GAN)

生成对抗网络是深度学习领域最具革命性的生成模型之一。

一 GAN框架

1.1组成

构造生成器(G)与判别器(D)进行动态对抗,实现数据的无监督生成。

G(造假者):接收噪声 $z \sim p_z$​,生成数据 $G(z)$ 。 

D(鉴定家):接收真实数据 $x \sim p_{data}$ 和生成数据 $G(z)$,输出概率 $D(x)$ 或 $D(G(z))$

1.2核心原理

对抗目标:

$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$

该公式为极大极小博弈,D和G互为对手,在动态博弈中驱动模型逐步提升性能。

其中:

第一项(真实性强化) $\mathbb{E}_{x \sim p_{data}}[\log D(x)]$ 的目标为:

让判别器 D 将真实数据 x 识别为“真”(即让 $D(x) \to 1$),最大化这一项可使D对真实数据的判断置信度更高。

第二项(生成性对抗):$\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$的目标为:

生成器G希望生成的假数据$G(z)$被判别器D判为“真”(即让 $D(G(z)) \to 1$)从而最小化$\log(1 - D(G(z)))$,判别器D则希望判假数据为“假”(即让$D(G(z)) \to 0$),从而最大化$\log(1 - D(G(z)))$

生成器和判别器在此项上存在直接对抗。

数学原理

(1)最优判别器理论

固定生成器G时,最大化 $V(D,G)$ 得到最优判别器 $D^*$

$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$

其中, $p_g(x)$ 是生成数据的分布。当生成数据完美匹配真实分布时 $p_{g} = p_{data}$,判别器无法区分真假(输出$D(x)=0.5 \quad \forall x$)。

推导一下:

$V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$                                      (1)

转换为积分形式,设真实数据分布为 $p_{data}(x)$,生成数据分布为$p_g(x)$,则(1)可以改写为:

$V(D, G) = \int_{x} p_{data}(x) \log D(x) dx + \int_{x} p_g(x) \log(1 - D(x)) dx$                             (2) 

逐点最大化对于每个样本 x,单独最大化以下函数:

$f(D(x)) = p_{data}(x) \log D(x) + p_g(x) \log(1 - D(x))$                                                  (3)

求导并解方程: 

$f'(D(x)) = \frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0$                                                                            (4)

求得:

$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$                                                                                                 (5) 

(2) 目标化简:JS散度(Jensen-Shannon Divergence) 

将最优判别器 $D^*$ 代入原目标函数,可得:

$V(G,D^*) = 2 \cdot JSD(p_{data} \| p_g) - 2\log 2$

最小化目标即等价于最小化 $p_{data}$$p_g$ 的JS散度。

JS散度特性:对称、非负,衡量两个分布的相似性。

1.3训练过程解释

每个训练步骤包含两阶段:

(1)判别器更新(固定G,最大化 $V(D,G)$

$\nabla_D \left[ \log D(x) + \log(1 - D(G(z))) \right]$

通过梯度上升优化D的参数,提升判别能力。

(2)生成器更新(固定D,最小化 $V(D,G)$

$\nabla_G \left[ \log(1 - D(G(z))) \right]$

实际训练中常用 $\nabla_G \left[ -\log D(G(z)) \right]$代替以增强梯度稳定性。

训练中出现的问题

(1)JS散度饱和导致梯度消失

(2)参数空间的非凸优化(存在无数个局部极值,优化算法极易陷入次优解,而非全局最优解)使训练难以收敛

二 经典GAN架构

DCGAN(GAN+卷积)

特性原始GANDCGAN
网络结构全连接层(MLPs)卷积生成器 + 卷积判别器
稳定性容易梯度爆炸/消失,难以收敛通过BN和特定激活函数稳定训练
生成图像分辨率低分辨率(如32x32)支持64x64及以上分辨率的清晰图像生成
图像质量轮廓模糊,缺乏细节细粒度纹理(如毛发、砖纹)
计算效率参数量大,训练速度慢卷积结构参数共享,效率提升

(1)生成器架构(反卷积)

class Generator(nn.Module):def __init__(self, noise_dim=100, output_channels=3):super().__init__()self.main = nn.Sequential(# 输入:100维噪声,输出:1024x4x4nn.ConvTranspose2d(noisel_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 上采样至8x8nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 输出层:3通道RGB图像nn.ConvTranspose2d(64, output_channels, 4, 2, 1, bias=False),nn.Tanh()  # 将输出压缩到[-1,1])

(2) 判别器架构(卷积)

class Discriminator(nn.Module):def __init__(self, input_channels=3):super().__init__()self.main = nn.Sequential(# 输入:3x64x64nn.Conv2d(input_channels, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 下采样至32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 输出层:二分类概率nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())

(3)双重优化问题

保持生成器和判别器动态平衡的核心机制

for epoch in range(num_epochs):# 更新判别器optimizer_D.zero_grad()real_loss = adversarial_loss(D(real_imgs), valid)fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()# 更新生成器optimizer_G.zero_grad()g_loss = adversarial_loss(D(gen_imgs), valid)  # 欺诈判别器g_loss.backward()optimizer_G.step()

三 应用场景 

图像合成引擎(语义图到照片)、医学影像增强、语音与音频合成。

GAN作为生成式AI的基石模型,其核心价值不仅在于数据生成能力,更在于构建了一种全新的深度学习范式——通过对抗博弈驱动模型持续进化。

四 一个完整DCGAN代码示例

MNIST数据集

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 参数设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
image_size = 64
num_epochs = 50
latent_dim = 100
lr = 0.0002
beta1 = 0.5# 数据准备
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST是单通道
])dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)# 可视化辅助函数
def show_images(images):plt.figure(figsize=(8,8))images = images.permute(1,2,0).cpu().numpy()plt.imshow((images * 0.5) + 0.5)  # 反归一化plt.axis('off')plt.show()# 权重初始化
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)# 生成器定义
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.main = nn.Sequential(# 输入:latent_dim x 1 x 1nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),# 输出:512 x 4 x 4nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 输出:256 x 8 x 8nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 输出:128 x 16x16nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 输出:64 x 32x32nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),nn.Tanh()  # 输出范围[-1,1]# 最终输出:1 x 64x64)def forward(self, input):return self.main(input)# 判别器定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(# 输入:1 x 64x64nn.Conv2d(1, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 输出:64 x32x32nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),# 输出:128x16x16nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),# 输出:256x8x8nn.Conv2d(256, 512, 4, 2, 1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),# 输出:512x4x4nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input).view(-1, 1).squeeze(1)# 初始化模型
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)# 应用权重初始化
generator.apply(weights_init)
discriminator.apply(weights_init)# 损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))# 训练过程可视化准备
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)# 训练循环
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 准备数据real_images = real_images.to(device)batch_size = real_images.size(0)# 真实标签和虚假标签real_labels = torch.full((batch_size,), 0.9, device=device)  # label smoothingfake_labels = torch.full((batch_size,), 0.0, device=device)# ========== 训练判别器 ==========optimizer_D.zero_grad()# 真实图片的判别结果outputs_real = discriminator(real_images)loss_real = criterion(outputs_real, real_labels)# 生成假图片noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)fake_images = generator(noise)# 假图片的判别结果outputs_fake = discriminator(fake_images.detach())loss_fake = criterion(outputs_fake, fake_labels)# 合并损失并反向传播loss_D = loss_real + loss_fakeloss_D.backward()optimizer_D.step()# ========== 训练生成器 ==========optimizer_G.zero_grad()# 更新生成器时的判别结果outputs = discriminator(fake_images)loss_G = criterion(outputs, real_labels)  # 欺骗判别器# 反向传播loss_G.backward()optimizer_G.step()# 打印训练状态if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")# 每个epoch结束时保存生成结果with torch.no_grad():test_images = generator(fixed_noise)grid = torchvision.utils.make_grid(test_images, nrow=8, normalize=True)show_images(grid)# 保存模型检查点if (epoch+1) % 5 == 0:torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')print("训练完成!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/81198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

httpclient请求出现403

问题 httpclient请求对方服务器报403,用postman是可以的 解决方案: request.setHeader( “User-Agent” ,“Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:50.0) Gecko/20100101 Firefox/50.0” ); // 设置请求头 原因: 因为没有设置为浏览器形式&#…

嵌入式硬件篇---IIC

文章目录 前言1. IC协议基础1.1 物理层特性两根信号线SCLSDA支持多主多从 标准模式电平 1.2 通信流程起始条件(Start Condition)从机地址(Slave Address)应答(ACK/NACK)数据传输:停止条件&#…

深入探讨 Java 注解:从基础到高级应用

Java 注解自 Java 5 引入以来,已成为现代 Java 开发中不可或缺的一部分。它们通过为代码添加元数据,简化了配置、增强了代码可读性,并支持了从编译时验证到运行时动态行为的多种功能。本文将全面探讨 Java 注解的使用、定义和处理方式,并通过一个实际的插件系统示例展示其强…

力扣-105.从前序与中序遍历序列构造二叉树

题目描述 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 class Solution { public:TreeNode* buildTree(vector<int>& preorder, vecto…

NoSQL数据库技术与应用复习总结【看到最后】

第1章 初识NoSQL 1.1 大数据时代对数据存储的挑战 1.高并发读写需求 2.高效率存储与访问需求 3.高扩展性 1.2 认识NoSQL NoSQL--非关系型、分布式、不提供ACID的数据库设计模式 NoSQL特点 1.易扩展 2.高性能 3.灵活的数据模型 4.高可用 NoSQL拥有一个共同的特点&am…

【ios越狱包安装失败?uniapp导出ipa文件如何安装到苹果手机】苹果IOS直接安装IPA文件

问题场景&#xff1a; 提示&#xff1a;ipa是用于苹果设备安装的软件包资源 设备&#xff1a;iphone 13(未越狱) 安装包类型&#xff1a;ipa包 调试工具&#xff1a;hbuilderx 问题描述 提要&#xff1a;ios包无法安装 uniapp导出ios包无法安装 相信有小伙伴跟我一样&…

php数据导出pdf,然后pdf转图片,再推送钉钉群

public function takePdf($data_plan, $data_act, $file_name, $type){$pdf new \TCPDF(L); // L - 横向 P-竖向// 设置文档信息//$file_name 外协批价单;$pdf->SetCreator($file_name);$pdf->SetAuthor($file_name);$pdf->SetTitle($file_name);$pdf->SetSubjec…

每日算法-250513

每日算法 - 2024-05-13 记录今天学习的算法题解。 2335. 装满杯子需要的最短总时长 题目 思路 贪心 这道题的关键在于每次操作尽可能多地减少杯子的数量。我们每次操作可以装一杯或两杯&#xff08;不同类型&#xff09;。为了最小化总时间&#xff0c;应该优先选择装两杯不同…

城市生命线综合管控系统解决方案-守护城市生命线安全

一、政策背景 国务院办公厅《城市安全风险综合监测预警平台建设指南》‌要求&#xff1a;将燃气、供水、排水、桥梁、热力、综合管廊等纳入城市生命线监测体系&#xff0c;建立"能监测、会预警、快处置"的智慧化防控机制。住建部‌《"十四五"全国城市基础…

分布式AI推理的成功之道

随着AI模型逐渐成为企业运营的核心支柱&#xff0c;实时推理已成为推动这一转型的关键引擎。市场对即时、可决策的AI洞察需求激增&#xff0c;而AI代理——正迅速成为推理技术的前沿——即将迎来爆发式普及。德勤预测&#xff0c;到2027年&#xff0c;超半数采用生成式AI的企业…

auto.js面试题及答案

以下是常见的 Auto.js 面试题及参考答案&#xff0c;涵盖基础知识、脚本编写、运行机制、权限、安全等方面&#xff0c;适合开发岗位的技术面试准备&#xff1a; 一、基础类问题 什么是 Auto.js&#xff1f;它的主要用途是什么&#xff1f; 答案&#xff1a; Auto.js 是一个…

C语言中的指定初始化器

什么是指定初始化器? C99标准引入了一种更灵活、直观的初始化语法——指定初始化器(designated initializer), 可以在初始化列表中直接引用结构体或联合体成员名称的语法。通过这种方式,我们可以跳过某些不需要初始化的成员,并且可以以任意顺序对特定成员进行初始化。这…

高德地图在Vue3中的使用方法

1.地图初始化 容器创建&#xff1a;通过 <div> 标签定义地图挂载点。 <div id"container" style"height: 300px; width: 100%; margin-top: 10px;"></div> 密钥配置&#xff1a;绑定高德地图安全密钥&#xff0c;确保 API 合法调用。 参…

RabbitMQ发布订阅模式深度解析与实践指南

目录 RabbitMQ发布订阅模式深度解析与实践指南1. 发布订阅模式核心原理1.1 消息分发模型1.2 核心组件对比 2. 交换机类型详解2.1 交换机类型矩阵2.2 消息生命周期 3. 案例分析与实现案例1&#xff1a;基础广播消息系统案例2&#xff1a;分级日志处理系统案例3&#xff1a;分布式…

中小型培训机构都用什么教务管理系统?

在教育培训行业快速发展的今天&#xff0c;中小型培训机构面临着学员管理复杂、课程体系多样化、教学效果难以量化等挑战。一个高效的教务管理系统已成为机构运营的核心支撑。本文将深入分析当前市场上适用于中小型培训机构的教务管理系统&#xff0c;重点介绍爱耕云这一专业解…

C++虚函数食用笔记

虚函数定义与作用&#xff1a; virtual关键字声明虚函数&#xff0c;虚函数可被派生类override(保证返回类型与参数列表&#xff0c;名字均相同&#xff09;&#xff0c;从而通过基类指针调用时&#xff0c;实现多态的功能 virtual关键字: 将函数声明为虚函数 override关键…

运算放大器相关的电路

1运算放大器介绍 解释&#xff1a;运算放大器本质就是一个放大倍数很大的元件&#xff0c;就如上图公式所示 Vp和Vn相差很小但是放大后输出还是会很大。 运算放大器不止上面的三个引脚&#xff0c;他需要独立供电&#xff1b; 如图比较器&#xff1a; 解释&#xff1a;Vp&…

华为OD机试真题——通信系统策略调度(用户调度问题)(2025B卷:100分)Java/python/JavaScript/C/C++/GO最佳实现

2025 B卷 100分 题型 本专栏内全部题目均提供Java、python、JavaScript、C、C++、GO六种语言的最佳实现方式; 并且每种语言均涵盖详细的问题分析、解题思路、代码实现、代码详解、3个测试用例以及综合分析; 本文收录于专栏:《2025华为OD真题目录+全流程解析+备考攻略+经验分…

Ubuntu 系统默认已安装 python,此处只需添加一个超链接即可

步骤 1&#xff1a;确认 Python 3 的安装路径 查看当前 Python 3 的路径&#xff1a; which python3 输出类似&#xff1a; /usr/bin/python3 步骤 2&#xff1a;创建符号链接 使用 ln -s 创建符号链接&#xff0c;将 python 指向 python3&#xff1a; sudo ln -s /usr/b…

深度学习-分布式训练机制

1、分布式训练时&#xff0c;包括train.py的全部的代码都会在每个gpu上运行吗&#xff1f; 在分布式训练&#xff08;如使用 PyTorch 的 DistributedDataParallel&#xff0c;DDP&#xff09;时&#xff0c;每个 GPU 上运行的进程会执行 train.py 的全部代码&#xff0c;但通过…