VAE In JAX【个人记录向】

news/2025/9/19 18:30:50/文章来源:https://www.cnblogs.com/tlx-blog/p/19101480

和上一篇 SAC In JAX 一样,我们用 JAX 实现 VAE,配置一样,只需要安装符合版本的 torchvision 即可,实现中提供了 tensorboard 记录以及最后的可视化展示,测试集即为最经典的 MNIST,代码如下:

import jax
import flax
import math
import optax
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from datetime import datetime
from torchvision import datasets
from flax.training import train_state
from matplotlib import pyplot as plt
from flax.training.train_state import TrainState
from stable_baselines3.common.logger import configuredef get_data():train_dataset = datasets.MNIST(root='./data', train=True, download=True)test_dataset = datasets.MNIST(root='./data', train=False)img_train = train_dataset.data.numpy().reshape(-1, 784)train_x_min = np.min(img_train, axis=0)train_x_max = np.max(img_train, axis=0)train_x = (img_train - train_x_min) / (train_x_max - train_x_min + 1e-7)train_y = train_dataset.targets.numpy()img_test = test_dataset.data.numpy().reshape(-1, 784)test_x = (img_test - train_x_min) / (train_x_max - train_x_min + 1e-7)test_y = test_dataset.targets.numpy()N = train_x.shape[0]M = test_x.shape[0]return jnp.asarray(train_x), jnp.asarray(train_y), jnp.asarray(test_x), jnp.asarray(test_y), N, Mclass VAE_encoder(nn.Module):hidden_dim: intlatent_dim: int@nn.compactdef __call__(self, x):x = nn.Dense(self.hidden_dim)(x)encode = nn.relu(x)mu = nn.Dense(self.latent_dim)(encode)log_sig = nn.Dense(self.latent_dim)(encode)return mu, log_sigclass VAE_decoder(nn.Module):output_dim: inthidden_dim: int@nn.compactdef __call__(self, latent):x = nn.Dense(self.hidden_dim)(latent)x = nn.relu(x)x = nn.Dense(self.output_dim)(x)x = nn.sigmoid(x)return xclass VAE:def __init__(self, input_dim, encoder_lr, decoder_lr, epochs, batch_size, logger, key):self.input_dim = input_dimself.encoder_lr, self.encoder_lr = encoder_lr, decoder_lrself.epochs = epochsself.batch_size = batch_sizeself.logger = loggerself.key = keyself.hidden_dim = 400self.latent_dim = 48self.encoder = VAE_encoder(self.hidden_dim, self.latent_dim)self.decoder = VAE_decoder(self.input_dim, self.hidden_dim)self.key, encoder_key, decoder_key = jax.random.split(self.key, 3)encoder_params = self.encoder.init(encoder_key, jnp.ones((self.batch_size, self.input_dim)))['params']decoder_params = self.decoder.init(decoder_key, jnp.ones((self.batch_size, self.latent_dim)))['params']encoder_optx = optax.adam(encoder_lr)decoder_optx = optax.adam(decoder_lr)self.encoder_state = TrainState.create(apply_fn=self.encoder.apply, params=encoder_params, tx=encoder_optx)self.decoder_state = TrainState.create(apply_fn=self.decoder.apply, params=decoder_params, tx=decoder_optx)@staticmethod@jax.jitdef forward(x, encoder_state, decoder_state, now_key):mu, log_std = encoder_state.apply_fn({"params": encoder_state.params}, x)now_key, eps_key = jax.random.split(now_key, 2)eps = jax.random.normal(eps_key, shape=mu.shape)latent = mu + eps * jnp.exp(log_std * 0.5)x_ = decoder_state.apply_fn({"params": decoder_state.params}, latent)return x_, now_key@staticmethod@jax.jitdef train_step(data, encoder_state, decoder_state, key):def loss_fn(encoder_param, decoder_param, encoder_state, decoder_state, now_key):mu, log_std = encoder_state.apply_fn({"params": encoder_param}, data)now_key, eps_key = jax.random.split(now_key, 2)eps = jax.random.normal(eps_key, shape=mu.shape)latent = mu + eps * jnp.exp(log_std * 0.5)x_ = decoder_state.apply_fn({"params": decoder_param}, latent)construction_loss = jnp.sum(jnp.sum(jnp.square(x_ - data), axis=1))commitment_loss = -0.5 * jnp.sum(1 + log_std - mu ** 2 - jnp.exp(log_std))loss = construction_loss + commitment_lossreturn loss, (construction_loss, commitment_loss, now_key)(loss, (construction, commitment_loss, key)), grads = jax.value_and_grad(loss_fn, has_aux=True, argnums=(0, 1))(encoder_state.params, decoder_state.params, encoder_state, decoder_state, key)encoder_state = encoder_state.apply_gradients(grads=grads[0])decoder_state = decoder_state.apply_gradients(grads=grads[1])return encoder_state, decoder_state, construction, commitment_loss, keydef train(self, train_x, N):for epoch in range(self.epochs):self.key, permutation_key = jax.random.split(self.key, 2)shuffled_indices = jax.random.permutation(permutation_key, N)now_data = train_x[shuffled_indices, :]# now_data = train_xtot_construction_loss, tot_commitment_loss = 0, 0for i in range(0, N, self.batch_size):batch_x = now_data[i: i + self.batch_size]self.encoder_state, self.decoder_state, construction_loss, commitment_loss, self.key = VAE.train_step(batch_x, self.encoder_state, self.decoder_state, self.key)tot_construction_loss += construction_losstot_commitment_loss += commitment_lossnow = datetime.now()time_str = now.strftime("%Y-%m-%d %H:%M:%S")print(f"Epoch {epoch + 1}, Construction_loss: {tot_construction_loss / N:.4f}, Commitment_loss: {tot_commitment_loss / N:.4f}! Time: {time_str}.")self.logger.record("Construction_loss", float(tot_construction_loss) / N)self.logger.record("Commitment_loss", float(tot_commitment_loss) / N)self.logger.dump(step=epoch + 1)def plot(test_x, test_y, VAE_model):original_images = []for digit in range(10):original_image = [test_x[i] for i in range(len(test_x)) if test_y[i] == digit][111]original_images.append(jnp.asarray(original_image.reshape(784)))input = jnp.stack(original_images, axis=0)input = jnp.asarray(input, dtype=jnp.float32)output, VAE_model.key = VAE_model.forward(input, VAE_model.encoder_state, VAE_model.decoder_state, VAE_model.key)# output, _, __ = VAE_model(input_tensor)generated_images = list(np.array(jax.lax.stop_gradient(output)))fig, axes = plt.subplots(2, 10, figsize=(15, 3))for i in range(10):axes[0, i].imshow(original_images[i].reshape(28, 28), cmap='gray')axes[0, i].set_title(f"Original {i}")axes[0, i].axis('off')axes[1, i].imshow(generated_images[i].reshape(28, 28), cmap='gray')axes[1, i].set_title(f"Generated {i}")axes[1, i].axis('off')plt.tight_layout()plt.savefig('result.png')plt.show()def main():start_time = datetime.now().strftime('%Y%m%d_%H%M%S')log_path = f"logs/VAEtest_{start_time}/"logger = configure(log_path, ["tensorboard"])train_x, train_y, test_x, test_y, N, M = get_data()key = jax.random.PRNGKey(41)VAEmodel = VAE(input_dim=784, encoder_lr=0.001, decoder_lr=0.001, epochs=30, batch_size=128, logger=logger, key=key)VAEmodel.train(train_x, N)plot(test_x, test_y, VAEmodel)if __name__ == '__main__':main()

实验结果

重建误差稳定下降:

image

在测试集中随机抽取(并非随机)每个数字各一个样例,展示重建后的图片,其实效果不太符合预期,应该是要调参,但是我懒得调了:

image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/908002.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BLE蓝牙配网双模式实操:STA+SoftAP技术原理与避坑指南

想让设备同时支持蓝牙快速配网与AP热点备份?STA+SoftAP双模式是关键!本文深度解析技术原理,结合真实项目案例总结实操避坑点,助你一文搞懂双模式配网逻辑,开发少走90%弯路。 本文特别分享蓝牙配网方案: 以Air800…

【小白也能懂】PyTorch 里的 0.5 到底是干啥的?——一次把 Normalize 讲透! - 教程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

第58天:RCE代码amp;命令执行amp;过滤绕过amp;异或无字符amp;无回显方案amp;黑白盒挖掘

第58天:RCE代码&命令执行&过滤绕过&异或无字符&无回显方案&黑白盒挖掘 案例演示RCE & 代码执行 & 命令执行 RCE-利用&绕过&异或&回显 白盒-CTF-RCE代码命令执行 黑盒-运行-R…

057-Web攻防-SSRFDemo源码Gopher项目等

057-Web攻防-SSRF&Demo源码&Gopher项目等 知识点知识点: 1、SSRF-原理-外部资源加载 2、SSRF-利用-伪协议&无回显 3、SSRF-挖掘-业务功能&URL参数案例演示1、SSRF-原理&挖掘&利用&修复 2…

060-WEB攻防-PHP反序列化POP链构造魔术方法流程漏洞触发条件属性修改

060-WEB攻防-PHP反序列化&POP链构造&魔术方法流程&漏洞触发条件&属性修改 知识点: 1、PHP-反序列化-应用&识别&函数 2、PHP-反序列化-魔术方法&触发规则 3、PHP-反序列化-联合漏洞&P…

059-Web攻防-XXE安全DTD实体复现源码等

059-Web攻防-XXE安全&DTD实体&复现源码等 知识点 XML&XXE-传输-原理&探针&利用&玩法 XML&XXE-黑盒-JS&黑盒测试&类型修改 XML&XXE-白盒-CMS&PHPSHE&无回显什么是XML?…

061-WEB攻防-PHP反序列化原生类TIPSCVE绕过漏洞属性类型特征

061-WEB攻防-PHP反序列化&原生类TIPS&CVE绕过漏洞&属性类型特征知识点 1、PHP-反序列化-属性类型&显示特征 2、PHP-反序列化-CVE绕过&字符串逃逸 3、PHP-反序列化-原生类生成&利用&配合1、…

051-Web攻防-文件安全目录安全测试源码等

051-Web攻防-文件安全&目录安全&测试源码等 知识点1、文件安全-前后台功能点-下载&读取&删除 2、目录安全-前后台功能点-目录遍历&目录穿越演示案例:➢文件安全-下载&删除-案例黑白盒 ➢目录…

Dilworth定理及其在算法题中的应用

1. Dilworth定理 Dilworth定理由数学家Robert P. Dilworth于1950年提出,它描述了偏序集中链和反链之间的关系。偏序集:一个集合 equipped with a partial order(即一个自反、反对称、传递的关系)。 链:偏序集的一…

error: xxxxx does not have a commit checked out

$ git commit -m "test" *error: AW30N does not have a commit checked outfatal: updating files failed解决方法: 删除多余的文件夹AW30N

049-WEB攻防-文件上传存储安全OSS对象分站解析安全解码还原目录执行

049-WEB攻防-文件上传&存储安全&OSS对象&分站&解析安全&解码还原&目录执行-cnblog#文件-解析方案-执行权限&解码还原 1、执行权限文件上传后存储目录不给执行权限 原理:开启禁止目录执行…

云原生周刊:MetalBear 融资、Chaos Mesh 漏洞、Dapr 1.16 与 AI 平台新趋势

云原生热点 MetalBear 获得 1250 万美元种子轮融资,推动 Kubernetes 开发解决方案 以色列初创公司 MetalBear 宣布完成 1250 万美元种子轮融资,由 TLV Partners 领投,TQ Ventures、MTF、Netz Capital 及多位知名天使…

AI一周资讯 250913-250919

原文: https://mp.weixin.qq.com/s/bnJ-kyOojPi6rqgx0NOXxg 阿里版Cursor正式收费!Qoder全球推出付费订阅,小白用了都说“最懂我” 2025年9月15日,阿里AI编程平台Qoder(被称为“阿里版Cursor”)面向全球用户正式推…

045-WEB攻防-PHP应用SQL二次注入堆叠执行DNS带外功能点黑白盒条件-cnblog

045-WEB攻防-PHP应用&SQL二次注入&堆叠执行&DNS带外&功能点&黑白盒条件-cnblog PHP-MYSQL-二次注入-DEMO&74CMS1、DEMO-用户注册登录修改密码1.注册新用户时,将注入的内容包含在注册的用户名…

linux 命令语句

rt 快csp初赛了,发现s组第一题一般是linux 命令语句,碎屑。 c cat:连接和显示文件内容 cd:切换工作目录 chmod:修改文件或目录的权限 chown:修改文件或目录的所有者 cp:复制文件或目录 d df/du:显示磁盘使用情…

用 Kotlin 实现英文数字验证码识别

在本教程中,我们将使用 Kotlin 和 Tesseract OCR 库实现对英文数字验证码的识别。Tesseract 是一个开源的 OCR 引擎,能够从图像中提取文本内容。结合 Kotlin 的简洁语法,我们可以高效地完成这个任务。环境准备 (1)…

UM2003A 一款 200 ~ 960MHz ASK/OOK +18dBm 发射功率的单发射

UM2003A 一款 200 ~ 960MHz ASK/OOK +18dBm 发射功率的单发射Si24R03 是一款高度集成的低功耗 SOC 芯片,其集成了基于 RISC-V 核的低功耗 MCU 和工作在 2.4GHz ISM 频段的无线收发器模块。 MCU 模块具有低功耗、L…

达芬奇(DaVinci Reslove)字体文件 bugb标签

今天有小伙伴 突然 反馈 字幕轨导 execl的插件 字幕没有导出来。我拿到 srt文件是 文件1这样格式,我导入字幕轨后 ,再导出来 格式变了 多了 一个 <…

语音芯片怎样挑选?语音芯片关键选型要点?

语音芯片怎样挑选?语音芯片关键选型要点? 选择语音芯片需根据具体应用场景和性能需求进行综合评估,以下是关键选型要点: 一、核心性能参数 1、采样率与信噪比 高采样率(如16位ADC)可减少声音失真,信噪比≥75dB能…

KingbaseES Schema权限及空间限额

一、权限授予操作 1. 基础权限赋予 1.1 创建测试环境-- 1.创建测试用户 test=# CREATE USER schema_user WITH PASSWORD Schema@123; CREATE ROLE-- 2.创建测试Schema test=# CREATE SCHEMA test_schema AUTHORIZATION…