315网站行业公司注册网站官网
news/
2025/10/4 3:48:58/
文章来源:
315网站行业,公司注册网站官网,保定seo关键词优化外包,挖矿网站怎么免费建设文章目录 模型介绍网络结构数据集可视化网络的其他细节模型推理 模型介绍
CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络#xff0c;实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。
该模型一个重要应用领域是域迁移(Do… 文章目录 模型介绍网络结构数据集可视化网络的其他细节模型推理 模型介绍
CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。
该模型一个重要应用领域是域迁移(Domain Adaptation)即图像风格迁移。在 CycleGAN 之前就已经有了域迁移模型比如 Pix2Pix昇思学习打卡-19-生成式/Pix2Pix实现图像转换 但是 Pix2Pix 要求训练数据必须是成对的而现实生活中要找到两个域画风中成对出现的图片是相当困难的因此 CycleGAN 诞生了它只需要两种域的数据而不需要他们有严格对应关系是一种新的无监督的图像迁移网络。
网络结构
CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成下面这个例子以苹果和橘子为例介绍讲解的很形象 下图中可以理解为苹果为橘子为将苹果生成橘子风格的生成器为将橘子生成的苹果风格的生成器和为其相应判别器。模型最终能够输出两个模型的权重分别将两种图像的风格进行彼此迁移生成新的图像。 该网络需要多个损失函数在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的可以这样理解 下图中苹果图片经过生成器得到伪橘子̂然后将伪橘子̂结果送进生成器 又产生苹果风格的结果 ̂ 最后将生成的苹果风格结果 ̂ 与原苹果图片 一起计算出循环一致损失反之亦然。循环损失捕捉了这样的直觉即如果我们从一个域转换到另一个域然后再转换回来我们应该到达我们开始的地方。 循环一致损失能够保证重建图像与输入图像紧密匹配。
数据集可视化
import numpy as np
import matplotlib.pyplot as pltmean 0.5 * 255
std 0.5 * 255plt.figure(figsize(12, 5), dpi60)
for i, data in enumerate(dataset.create_dict_iterator()):if i 5:show_images_a data[image_A].asnumpy()show_images_b data[image_B].asnumpy()plt.subplot(2, 5, i1)show_images_a (show_images_a[0] * std mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis(off)plt.subplot(2, 5, i6)show_images_b (show_images_b[0] * std mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis(off)else:break
plt.show()网络的其他细节 构建生成器时此模型使用ResNet 模型的结构 构建判别器判别器其实是一个二分类网络模型输出判定该图像为真实图的概率。 定义优化器和损失函数优化器使用Adam关于损失函数主要关注循环一致损失函数 前向计算使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。 计算梯度和反向传播其中梯度计算也是分开不同的模型来进行的 最后是模型训练模型训练训练分为两个主要部分训练判别器和训练生成器在前文的判别器损失函数中论文采用了最小二乘损失代替负对数似然目标。 训练判别器训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 −()[(()−1)2] 训练生成器如 CycleGAN 论文所述我们希望通过最小化 −()[((()−1)2]来训练生成器以产生更好的虚假图像。
模型推理
%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt ./CycleGAN_apple2orange/ckpt/g_a.ckpt
g_b_ckpt ./CycleGAN_apple2orange/ckpt/g_b.ckptload_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig plt.figure(figsize(11, 2.5), dpi100)
def eval_data(dir_path, net, a):def read_img():for dir in os.listdir(dir_path):path os.path.join(dir_path, dir)img Image.open(path).convert(RGB)yield img, dirdataset ds.GeneratorDataset(read_img, column_names[image, image_name])trans [vision.Resize((256, 256)), vision.Normalize(mean[0.5 * 255] * 3, std[0.5 * 255] * 3), vision.HWC2CHW()]dataset dataset.map(operationstrans, input_columns[image])dataset dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img data[image]fake net(img)fake (fake[0] * 0.5 * 255 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img (img[0] * 0.5 * 255 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i1a)plt.axis(off)plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i9a)plt.axis(off)plt.imshow(fake.asnumpy())eval_data(./CycleGAN_apple2orange/predict/apple, net_rg_a, 0)
eval_data(./CycleGAN_apple2orange/predict/orange, net_rg_b, 4)
plt.show()推理结果如下 此章节学习到此结束感谢昇思平台。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/926569.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!