网站开发工程师和软件工程手机网站 底部菜单
web/
2025/10/8 16:56:20/
文章来源:
网站开发工程师和软件工程,手机网站 底部菜单,杭州网站推广找哪家,网站后台模板 jquery许多名画造假者费尽毕生的心血#xff0c;试图模仿出艺术名家的风格。如今#xff0c;CycleGAN就可以初步实现这个神奇的功能。这个功能就是风格迁移#xff0c;比如下图#xff0c;照片可以被赋予莫奈#xff0c;梵高等人的绘画风格这属于是无配对数据(unpaired)产生的图…许多名画造假者费尽毕生的心血试图模仿出艺术名家的风格。如今CycleGAN就可以初步实现这个神奇的功能。这个功能就是风格迁移比如下图照片可以被赋予莫奈梵高等人的绘画风格这属于是无配对数据(unpaired)产生的图片也就是说你有一些名人名家的作品也有一些你想转换风格的真实图片这两种图片是没有任何交集的。在之前的文章(用AI增强人类想象力)中提到的Pix2Pix方法的关键是提供了在这两个域中有相同数据的训练样本。CycleGAN的创新点在于能够在源域和目标域之间无须建立训练数据间一对一的映射就可实现这种迁移想要做到这点有两个比较重要的点第一个就是双判别器。如上图a所示两个分布X,Y生成器GF分别是X到Y和Y到X的映射两个判别器Dx,Dy可以对转换后的图片进行判别。第二个点就是cycle-consistency loss用数据集中其他的图来检验生成器这是防止G和F过拟合比如想把一个小狗照片转化成梵高风格如果没有cycle-consistency loss生成器可能会生成一张梵高真实画作来骗过Dx而无视输入的小狗。需要注意的是广为流传的下图有个容易让人理解错误的地方那就是下图中的input和output那几张图两匹马应该除了花纹其他一致的除此之外结构还是挺清晰的对抗损失生成器和判别器的loss函数和GAN是一样的判别器D尽力检测出生成器G产生的假图片生成器尽力生成图片骗过判别器具体数理推导可以看我专栏之前的文章李刚GAN 对抗生成网络入门辅助理解zhuanlan.zhihu.com对抗loss由两部分组成以及Cycle Consistency 损失作者说理论上对抗训练可以学习映射输出G和F它们分别作为目标域Y和X产生相同的分布。然而具有足够大的容量网络可以将相同的输入图像集合映射到目标域中的任何图像的随机排列。因此单独的对抗性loss不能保证可以映射单个输入。需要另外来一个loss,保证G和F不仅能满足各自的判别器还能应用于其他图片。也就是说G和F可能合伙偷懒骗人给G一个图G偷偷把小狗变成梵高自画像F再把梵高自画像变成输入。Cycle Consistency loss的到来制止了这种投机取巧的行为他用梵高其他的画作测试FG用另外真实照片测试GF看看能否变回到原来的样子这样保证了GF在整个XY分布区间的普适性。整体所以整个loss就是下面的式子就像训练两个auoto-encoder一样作者在后文比对了单独拿出不同部分的效果比如只用Cycle Consistency loss只用对抗GAN 前向cycle-consistency loss (F(G(x)) ≈ x), GAN 后向 cycle-consistency loss (G(F(y)) ≈ y)以及cycleGAN的效果。代码实现首先是一些参数ngf 32 # Number of filters in first layer of generatorndf 64 # Number of filters in first layer of discriminatorbatch_size 1 # batch_sizepool_size 50 # pool_sizeimg_width 256 # Imput image will of width 256img_height 256 # Input image will be of height 256img_depth 3 # RGB format构造生成器Generator(EncoderTransformerDecoder)假设所有图片都是256*256的彩图需要先用卷积神经网络提取特征在这里input_gen是输入图像num_features是我们从卷积层中提取出的输出特征的数量(滤波器的数量)window_widthwindow_height代表滤波器尺寸。 stride_widthstrideheight是滤波器如何在整个图上移动的参数。输出的O_C1是尺寸[25625632]的矩阵。也可以在后边自行添加Relu等函数。o_c1 general_conv2d(input_gen,num_featuresngf,window_width7,window_height7,stride_width1,stride_height1)#定义卷积层函数def general_conv2d(inputconv, o_d64, f_h7, f_w7, s_h1, s_w1):with tf.variable_scope(name):conv tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height],padding, activation_fnNone, weights_initializertf.truncated_normal_initializer(stddevstddev),biases_initializertf.constant_initializer(0.0))后面是相似的卷积步骤最后一层输出o_enc_A是(6464256)的矩阵o_c2 general_conv2d(o_c1, num_features64*2, window_width3, window_height3, stride_width2, stride_height2)# o_c2.shape (128, 128, 128)o_enc_A general_conv2d(o_c2, num_features64*4, window_width3, window_height3, stride_width2, stride_height2)# o_enc_A.shape (64, 64, 256)Transformer可以将这些层视为图像的不同附近特征的组合然后基于这些特征来决定如何将图像的特征向量转换到另一个分布。作者使用了6层resnet块其中输入的残差被添加到输出中。这样做是为了确保先前层的输入的属性也可用于以后的层因此它们的输出不会偏离原始输入否则原始图像的特性将不被保留在输出中。任务的主要目的之一是保留原始输入的特性如对象的大小和形状因此残差网络非常适合这些类型的变换。关于resnet详见 ResNet原理及其在TF-Slim中的实现o_r1 build_resnet_block(o_enc_A, num_features64*4)o_r2 build_resnet_block(o_r1, num_features64*4)o_r3 build_resnet_block(o_r2, num_features64*4)o_r4 build_resnet_block(o_r3, num_features64*4)o_r5 build_resnet_block(o_r4, num_features64*4)o_enc_B build_resnet_block(o_r5, num_features64*4)#定义resnetdef resnet_blocks(input_res, num_features):out_res_1 general_conv2d(input_res, num_features,window_width3,window_heigth3,stride_width1,stride_heigth1)out_res_2 general_conv2d(out_res_1, num_features,window_width3,window_heigth3,stride_width1,stride_heigth1)return (out_res_2 input_res)下面是decoder用反卷积把这些特征变回成图片o_d1 general_deconv2d(o_enc_B, num_featuresngf*2 window_width3, window_height3, stride_width2, stride_height2)o_d2 general_deconv2d(o_d1, num_featuresngf, window_width3, window_height3, stride_width2, stride_height2)gen_B general_conv2d(o_d2, num_features3, window_width7, window_height7, stride_width1, stride_height1)#定义反卷积层def general_deconv2d(inputconv, outshape, o_d64, f_h7, f_w7, s_h1, s_w1, stddev0.02, paddingVALID, namedeconv2d, do_normTrue, do_reluTrue, relufactor0):with tf.variable_scope(name):conv tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fnNone, weights_initializertf.truncated_normal_initializer(stddevstddev),biases_initializertf.constant_initializer(0.0))if do_norm:conv instance_norm(conv)# conv tf.contrib.layers.batch_norm(conv, decay0.9, updates_collectionsNone, epsilon1e-5, scaleTrue, scopebatch_norm)if do_relu:if(relufactor 0):conv tf.nn.relu(conv,relu)else:conv lrelu(conv, relufactor, lrelu)return conv判别器的构成在这里救不赘述了无非就是用CNN把生成的图片变成一些特征图再用全连接变成最后的decision(真或假)定义loss function判别器lossloss_1是对于真图的判定越接近1越好loss_2是对于假图的判定越接近0越好loss是两个loss相加D_A_loss_1 tf.reduce_mean(tf.squared_difference(dec_A,1))D_B_loss_1 tf.reduce_mean(tf.squared_difference(dec_B,1))D_A_loss_2 tf.reduce_mean(tf.square(dec_gen_A))D_B_loss_2 tf.reduce_mean(tf.square(dec_gen_B))D_A_loss (D_A_loss_1 D_A_loss_2)/2D_B_loss (D_B_loss_1 D_B_loss_2)/2生成器loss:g_loss_B_1 tf.reduce_mean(tf.squared_difference(dec_gen_A,1))g_loss_A_1 tf.reduce_mean(tf.squared_difference(dec_gen_A,1))Cycle Consistency loss: 保证原始图像和循环图像之间的差异应该尽可能小注意10*cyc_loss是赋予Cycle Consistency loss更大的权值作者并没有讨论这个参数是怎么确定下来的cyc_loss tf.reduce_mean(tf.abs(input_A-cyc_A)) tf.reduce_mean(tf.abs(input_B-cyc_B))g_loss_A g_loss_A_1 10*cyc_lossg_loss_B g_loss_B_1 10*cyc_loss模型训练for epoch in range(0,100):# Define the learning rate schedule. The learning rate is kept# constant upto 100 epochs and then slowly decayedif(epoch 100) :curr_lr 0.0002else:curr_lr 0.0002 - 0.0002*(epoch-100)/100# Running the training loop for all batchesfor ptr in range(0,num_images):# Train generator G_A-B_, gen_B_temp sess.run([g_A_trainer, gen_B],feed_dict{input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})# We need gen_B_temp because to calculate the error in training D_B_ sess.run([d_B_trainer],feed_dict{input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})# Same for G_B-A and D_A as follow_, gen_A_temp sess.run([g_B_trainer, gen_A],feed_dict{input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})_ sess.run([d_A_trainer],feed_dict{input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/89160.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!