pytorch实现变分自编码器

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。

VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。

VAE 的结构:

  1. 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
  2. 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
  3. 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。

VAE 的目标函数:

VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:

  • 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
  • KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。

其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。

优点:

  • VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
  • 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。

应用:

  • 生成任务:如图像生成、文本生成等。
  • 数据重构:如去噪、自编码等。
  • 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像center = size // 2radius = size // 4for y in range(size):for x in range(size):if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:image[0, y, x] = 1  # 在圆内的点设置为白色return image# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像padding = size // 4image[0, padding:size - padding, padding:size - padding] = 1  # 设置方形区域为白色return image# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):def __init__(self, num_samples=1000, size=64):self.num_samples = num_samplesself.size = sizeself.data = []# 生成数据:一半是圆形图像,一半是方形图像for i in range(num_samples // 2):self.data.append(generate_circle_image(size))self.data.append(generate_square_image(size))def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx].float()  # 直接返回 PyTorch Tensor 格式的数据# VAE模型定义
class VAE(nn.Module):def __init__(self, latent_dim=2):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.fc1 = nn.Linear(64 * 64, 400)self.fc21 = nn.Linear(400, latent_dim)  # 均值self.fc22 = nn.Linear(400, latent_dim)  # 方差# 解码器self.fc3 = nn.Linear(latent_dim, 400)self.fc4 = nn.Linear(400, 64 * 64)def encode(self, x):h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))return self.fc21(h1), self.fc22(h1)  # 返回均值和方差def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = torch.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64)  # 重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')# KL 散度return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
def train(epoch):model.train()train_loss = 0for batch_idx, data in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')# 测试并显示一些真实图像和生成的图像
def test():model.eval()with torch.no_grad():# 获取一批真实的图像(原始图像)real_images = next(iter(train_loader))[:64]  # 只取前64个图像real_images = real_images.cpu().numpy()# 从潜在空间随机生成一些样本sample = torch.randn(64, latent_dim).to(device)generated_images = model.decode(sample).cpu().numpy()# 显示真实图像和生成的图像,分别标明fig, axes = plt.subplots(8, 8, figsize=(8, 8))axes = axes.flatten()for i in range(64):if i < 32:  # 前32个显示真实图像axes[i].imshow(real_images[i].squeeze(), cmap='gray')axes[i].set_title('Real', fontsize=8)else:  # 后32个显示生成图像axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')axes[i].set_title('Generated', fontsize=8)axes[i].axis('off')plt.tight_layout()plt.show()# 训练模型
for epoch in range(1, epochs + 1):train(epoch)# 训练完成后,显示生成的图像
test()

解释:

  1. 真实图像 (real_images):我们通过 next(iter(train_loader)) 获取一批真实图像,并将其转换为 NumPy 数组,以便 matplotlib 显示。
  2. 生成图像 (generated_images):通过模型生成的图像,使用 decode() 方法生成潜在空间的样本。
  3. 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有 RealGenerated 标注。

结果:

  • 前32个图像:显示真实图像,并标注为 Real
  • 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为 Generated

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894339.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《AI大模型开发笔记》DeepSeek技术创新点

一、DeepSeek横空出世 DeepSeek V3 以颠覆性技术架构创新强势破局&#xff01;革命性的上下文处理机制实现长文本推理成本断崖式下降&#xff0c;综合算力需求锐减90%&#xff0c;开启高效 AI 新纪元&#xff01; 最新开源的 DeepSeek V3模型不仅以顶尖基准测试成绩比肩业界 …

数仓实战项目,大数据数仓实战(离线数仓+实时数仓)

1.课程目标 2.电商行业与电商系统介绍 3.数仓项目整体技术架构介绍 4.数仓项目架构-kylin补充 5.数仓具体技术介绍与项目环境介绍 6.kettle的介绍与安装 7.kettle入门案例 这个连线是点击shift键&#xff0c;然后鼠标左键拖动 ctrls保存一下 csv输入配置 Excel输出配置 配置完 …

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC&#xff1f; 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping&#xff08;路由映射&#xff09; 3.2一般参数传递 3.3RequestParam&#xff08;参数重命名&#xff09; 3.4RequestBody&#xff08;传递JSON数据&#xff09; 3.5Pa…

【Linux】使用VirtualBox部署Linux虚拟机

1. 下载并安装 VirtualBox 访问 VirtualBox 官网&#xff0c;下载适合你操作系统的版本&#xff08;Windows&#xff09;。安装 VirtualBox&#xff0c;按照安装向导的提示完成安装。 2. 下载 Linux 发行版 ISO 文件 访问你选择的 Linux 发行版官方网站&#xff08;例如&…

Day07:缓存-数据淘汰策略

Redis的数据淘汰策略有哪些 ? &#xff08;key过期导致的&#xff09; 在redis中提供了两种数据过期删除策略 第一种是惰性删除&#xff0c;在设置该key过期时间后&#xff0c;我们不去管它&#xff0c;当需要该key时&#xff0c;我们再检查其是否过期&#xff0c;如果过期&…

[原创](Modern C++)现代C++的关键性概念: 正则表达式

常用网名: 猪头三 出生日期: 1981.XX.XX 企鹅交流: 643439947 个人网站: 80x86汇编小站 编程生涯: 2001年~至今[共24年] 职业生涯: 22年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、Delphi、XCode、Eclipse、C Bui…

sobel边缘检测算法

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 Sobel边缘检测算法是一种用于图像处理中的边缘检测方法&#xff0c;它能够突出图像中灰度变化剧烈的地方&#xff0c;也就是边缘。该算法通过计算图像在水平方向和垂直方向上的梯度来检测边缘&#xff0c;梯度值越大…

Google Chrome-便携增强版[解压即用]

Google Chrome-便携增强版 链接&#xff1a;https://pan.xunlei.com/s/VOI0OyrhUx3biEbFgJyLl-Z8A1?pwdf5qa# a 特点描述 √ 无升级、便携式、绿色免安装&#xff0c;即可以覆盖更新又能解压使用&#xff01; √ 此增强版&#xff0c;支持右键解压使用 √ 加入Chrome增强…

FLTK - FLTK1.4.1 - demo - bitmap

文章目录 FLTK - FLTK1.4.1 - demo - bitmap概述笔记END FLTK - FLTK1.4.1 - demo - bitmap 概述 // 功能 : 演示位图数据在按钮上的显示 // * 以按钮为范围或者以窗口为范围移动 // * 上下左右, 文字和图像的相对位置 // 失能按钮&#xff0c;使能按钮 // 知识点 // FLTK可…

分布式数据库架构与实践:原理、设计与优化

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 1. 引言 随着大数据和云计算的快速发展&#xff0c;传统单机数据库已难以满足大规模数据存储和高并发访问的需求。分布式数据库&…

设计模式Python版 桥接模式

文章目录 前言一、桥接模式二、桥接模式示例三、桥接模式与适配器模式的联用 前言 GOF设计模式分三大类&#xff1a; 创建型模式&#xff1a;关注对象的创建过程&#xff0c;包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式&…

携程Android开发面试题及参考答案

在项目中,给别人发的动态点赞功能是如何实现的? 数据库设计:首先要在数据库中为动态表添加一个点赞字段,用于记录点赞数量,同时可能需要一个点赞关系表,记录用户与动态之间的点赞关联,包括点赞时间等信息。界面交互:在 Android 界面上,为点赞按钮设置点击事件监听器。…

【C语言】main函数解析

文章目录 一、前言二、main函数解析三、代码示例四、应用场景 一、前言 在学习编程的过程中&#xff0c;我们很早就接触到了main函数。在Linux系统中&#xff0c;当你运行一个可执行文件&#xff08;例如 ./a.out&#xff09;时&#xff0c;如果需要传入参数&#xff0c;就需要…

CSS核心

CSS的引入方式 内部样式表是在 html 页面内部写一个 style 标签&#xff0c;在标签内部编写 CSS 代码控制整个 HTML 页面的样式。<style> 标签理论上可以放在 HTML 文档的任何地方&#xff0c;但一般会放在文档的 <head> 标签中。 <style> div { color: r…

传奇引擎游戏微端的作用

传奇引擎游戏微端是一种优化的游戏客户端分发与运行方式&#xff0c;其主要目的是通过减少玩家的下载压力和提升游戏启动速度&#xff0c;让玩家更快地进入游戏。微端在传奇私服以及其他网络游戏中广泛使用&#xff0c;尤其适用于容量较大的游戏客户端。下面从作用、实现原理和…

从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架(基础组件实现)

目录 基础组件实现 如何将图像和文字显示到OLED上 如何绘制图像 如何绘制文字 如何获取字体&#xff1f; 如何正确的访问字体 如何抽象字体 如何绘制字符串 绘制方案 文本绘制 更加方便的绘制 字体附录 ascii 6x8字体 ascii 8 x 16字体 基础组件实现 我们现在离手…

吴晓波 历代经济变革得失@简明“中国经济史” - 读书笔记

目录 《历代经济变革得失》读书笔记一、核心观点二、主要内容&#xff08;一&#xff09;导论&#xff08;二&#xff09;春秋战国时期&#xff08;三&#xff09;汉代&#xff08;四&#xff09;北宋&#xff08;五&#xff09;明清时期&#xff08;六&#xff09;近现代&…

Theorem

Theorem 打开题&#xff1a; from Crypto.Util.number import *from gmpy2 import *flag bxxxm bytes_to_long(flag) #flaglong_to_bytes(m)p getPrime(512) #随机生成一个512位的素数pq next_prime(p) #p之后的下一个…

变量的作用域和生命周期

一、根据变量的作用域不同&#xff0c;可分为 局部变量 和 全局变量 1. 作用域&#xff1a;变量起作用的范围&#xff08;变量定义之后&#xff0c;在哪里可以访问变量&#xff09;。 就近原则&#xff1a;当不同作用域里面有两个或者多个同名变量&#xff0c;那么遵循就近原…

力扣【669. 修剪二叉搜索树】Java题解

一开始在想为什么题目说存在唯一答案。然后发现是二叉搜索树就合理了。如下图&#xff1a;如果0节点小于low&#xff0c;那其左子树也都小于low&#xff0c;故可以排除&#xff1b;对于4&#xff0c;其右子树也是可以排除。 代码如下&#xff1a; class Solution {public Tre…