网站建设与管理教学视频教程做地方网站数据哪里来
网站建设与管理教学视频教程,做地方网站数据哪里来,网站开发语言占有率,做微信活动是做网站还是做小程序好1.CGAN原理
生成器#xff0c;输入的是c和z#xff0c;z是随机噪声#xff0c;c是条件#xff0c;对应MNIST数据集#xff0c;要求规定生成数字是几。 输出是生成的虚假图片。
判别器的输入是 1.生成器输出的虚假图片x; 2.对应图片的标签c 来自真实数据集#xff0c;且…1.CGAN原理
生成器输入的是c和zz是随机噪声c是条件对应MNIST数据集要求规定生成数字是几。 输出是生成的虚假图片。
判别器的输入是 1.生成器输出的虚假图片x; 2.对应图片的标签c 来自真实数据集且标签是对的就是1 如果是生成器生成的虚假照片就直接是1都不需要看是否与标签对应
上面第二张图的意思就是当图片是来自真实数据集再来看是否与标签对应
2.CGAN损失函数 上面这个值生成器越小越好即判别器认为真实图片是真实图片的概率越低越好认为虚假图片是真实图片的概率越高越好 判别器越大越好即判别器认为真实图片是真实图片的概率越大越好认为虚假图片是真实图片的概率越小越好
criterion(output, label)
在判别器中 1output是预测来自真实数据集的图片和标签是否是真实且符合标签的概率label是1 2output是预测虚假图片是否是虚假图片的概率label是0 在生成器中 output是判别器预测虚假图片是否是真实图片的概率label是1 以上三种都是交叉熵越小越好
3.生成器和判别器的源码
class Generator(nn.Module):def __init__(self, num_channel1, nz100, nc10, ngf64):super(Generator, self).__init__()self.main nn.Sequential(# 输入维度 110 x 1 x 1nn.ConvTranspose2d(nz nc, ngf * 8, 4, 1, 0, biasFalse),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 特征维度 (ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, biasFalse),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 特征维度 (ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, biasFalse),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 特征维度 (ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, biasFalse),nn.BatchNorm2d(ngf),nn.ReLU(True),# 特征维度 (ngf) x 32 x 32nn.ConvTranspose2d(ngf, num_channel, 4, 2, 1, biasFalse),nn.Tanh()# 特征维度. (num_channel) x 64 x 64)self.apply(weights_init)def forward(self, input_z, onehot_label):input_ torch.cat((input_z, onehot_label), dim1)n, c input_.size()input_ input_.view(n, c, 1, 1)return self.main(input_)class Discriminator(nn.Module):def __init__(self, num_channel1, nc10, ndf64):super(Discriminator, self).__init__()self.main nn.Sequential(# 输入维度 (num_c3# channelnc) x 64 x 64 1*64*64的图像和10维的类别 10维类别先转换成10*64*64 然后合并就是11*64*64# 输入通道 输出通道 卷积核的大小 步长 填充#原始输入张量b 11 64 64nn.Conv2d(num_channel nc, ndf, 4, 2, 1, biasFalse), #b 64 32 32nn.LeakyReLU(0.2, inplaceTrue),# 特征维度 (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, biasFalse), #b 64*2 16 16nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplaceTrue),# 特征维度 (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, biasFalse), #b 64*4 8 8nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplaceTrue),# 特征维度 (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, biasFalse), #b 64*8 4 4nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplaceTrue),# 特征维度 (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, biasFalse), #b 1 1 1 其实就是一个数值区间在正无穷到负无穷之间nn.Sigmoid())self.apply(weights_init)def forward(self, images, onehot_label):device cuda if torch.cuda.is_available() else cpuh, w images.shape[2:]n, nc onehot_label.shape[:2]label onehot_label.view(n, nc, 1, 1) * torch.ones([n, nc, h, w]).to(device)input_ torch.cat([images, label], 1)return self.main(input_)
4.训练过程
MODEL_G_PATH ./
LOG_G_PATH Log_G.txt
LOG_D_PATH Log_D.txt
IMAGE_SIZE 64
BATCH_SIZE 128
WORKER 1
LR 0.0002
NZ 100
NUM_CLASS 10
EPOCH 10data_loader loadMNIST(img_sizeIMAGE_SIZE, batch_sizeBATCH_SIZE) #原始图片宽高是28*28的给改变成64*64
device torch.device(cuda:0 if torch.cuda.is_available() else cpu)
netG Generator().to(device)
netD Discriminator().to(device)
criterion nn.BCELoss()
real_label 1.
fake_label 0.
optimizerD optim.Adam(netD.parameters(), lrLR, betas(0.5, 0.999))
optimizerG optim.Adam(netG.parameters(), lrLR, betas(0.5, 0.999))g_writer LossWriter(save_pathLOG_G_PATH)
d_writer LossWriter(save_pathLOG_D_PATH)fix_noise torch.randn(BATCH_SIZE, NZ, devicedevice)
fix_input_c (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
fix_input_c onehot(fix_input_c, NUM_CLASS)img_list []
G_losses []
D_losses []
iters 0print(开始训练)
for epoch in range(EPOCH):print(正在保存网络并评估...)save_network(MODEL_G_PATH, netG, epoch)with torch.no_grad():fake_imgs netG(fix_noise, fix_input_c).detach().cpu()images recover_image(fake_imgs)full_image np.full((5 * 64, 5 * 64, 3), 0, dtypeuint8)for i in range(25):row i // 5col i % 5full_image[row * 64:(row 1) * 64, col * 64:(col 1) * 64, :] images[i]plt.imshow(full_image)#plt.show()plt.imsave({}.png.format(epoch), full_image)for data in data_loader:##################################################判别器交叉熵越小越好# 1. 更新判别器D: 最大化 log(D(x)) log(1 - D(G(z)))# 等同于最小化 - log(D(x)) - log(1 - D(G(z)))#################################################netD.zero_grad()real_imgs, input_c data #这里的input_c其实就是数据集每一批中的每个图片对应的标签input_c input_c.to(device)input_c onehot(input_c, NUM_CLASS).to(device)# 1.1 来自数据集的样本#这里这一步就是想训练判别器能够识别出是否真实图片以及图片与对应的标签是否对应real_imgs real_imgs.to(device)b_size real_imgs.size(0)label torch.full((b_size,), real_label, dtypetorch.float, devicedevice)#上面的torch.full是生成一维的 b_size这么多的填充值为1.的张量# real_label 1.# fake_label 0.# 使用鉴别器对数据集样本做判断output netD(real_imgs, input_c).view(-1) #view() 方法被用来将模型输出的张量进行扁平化操作即将张量中的所有元素都展开成一个一维向量# 计算交叉熵损失 -log(D(x))errD_real criterion(output, label)# 对判别器进行梯度回传errD_real.backward()D_x output.mean().item() #对同一批预测结果的交叉熵取平均值## 1.2 生成随机向量 这一步想要训练判别器是否能够识别出是虚假图片noise torch.randn(b_size, NZ, devicedevice)# 生成随机标签input_c (torch.rand(b_size, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)input_c onehot(input_c, NUM_CLASS)#fix_noise torch.randn(BATCH_SIZE, NZ, devicedevice)#fix_input_c (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)#fix_input_c onehot(fix_input_c, NUM_CLASS)# 来自生成器生成的样本fake netG(noise, input_c)label.fill_(fake_label)# real_label 1.# fake_label 0.# 使用鉴别器对生成器生成样本做判断output netD(fake.detach(), input_c).view(-1) #view() 方法被用来将模型输出的张量进行扁平化操作即将张量中的所有元素都展开成一个一维向量# 计算交叉熵损失 -log(1 - D(G(z)))errD_fake criterion(output, label)# 对判别器进行梯度回传errD_fake.backward()D_G_z1 output.mean().item()# 对判别器计算总梯度,-log(D(x))-log(1 - D(G(z)))errD errD_real errD_fake# 更新判别器optimizerD.step()################################################## 2. 更新生成器G: 最小化 log(D(x)) log(1 - D(G(z)))# 等同于最小化log(1 - D(G(z)))即最小化-log(D(G(z)))# 也就等同于最小化-log(D(G(z)))*1log(1-D(G(z)))*0# 令生成器样本标签值为1上式就满足了交叉熵的定义#################################################netG.zero_grad()# 对于生成器训练令生成器生成的样本为真label.fill_(real_label)# real_label 1.# fake_label 0.output netD(fake, input_c).view(-1)# 对生成器计算损失errG criterion(output, label)# 因为这里判别器的角度label真实应该是0但是站在生成器的角度label真实应该是1即生成器希望生成的虚假图片让判别器识别的时候会误以为1才比较好即误以为是真实的图片# 所以生成器交叉熵也是越小越好# 对生成器进行梯度回传errG.backward()D_G_z2 output.mean().item()# 更新生成器optimizerG.step()# 输出损失状态if iters % 5 0:print([%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f% (epoch, EPOCH, iters % len(data_loader), len(data_loader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))d_writer.add(losserrD.item(), iiters)g_writer.add(losserrG.item(), iiters)# 保存损失记录G_losses.append(errG.item())D_losses.append(errD.item())iters 1
5.关于交叉熵
熵代表确定性熵越小越好说明确定性越好 在这里因为参照的是真实标签它的熵是0 而交叉熵-熵相对熵 故相对熵在预测情况相对真实情况的时候相对熵交叉熵相对熵越小说明预测情况越接近真实情况 同理交叉熵越小说明预测情况越接近真实情况。
在二分类01任务中经过卷积、正则化、激活函数ReLU等操作之后假如生成了一个(B,1,1,1)的张量,每个值在无穷小无穷大之间经过sigmoid函数会变成一个(B,1,1,1)的张量数值h在01之间如果这个h0.5说明模型预测的是1如果h0.5说明模型预测的是0但是这是模型预测的标签值y*,而还有个真实标签值y。假如现在h0.6,那么说明模型预测的标签y*是1,真实标签却是0
交叉熵 -y(lgh) -(1-y)(lg(1-h)) 即当y1时交叉熵是-lgh 这个情况下h越大越好 当y0时交叉熵是-lg(1-h) 这个情况下h越小越好
6.训练过程运行结果 7.测试结果 测试代码 NZ 100
NUM_CLASS 10
BATCH_SIZE 10
DEVICE cpu# fix_input_c (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(DEVICE)netG Generator()
netG restore_network(./, 49, netG)
fix_noise torch.randn(BATCH_SIZE, NZ, deviceDEVICE)
fix_input_c torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
device cuda if torch.cuda.is_available() else cpu
fix_input_c onehot(fix_input_c, NUM_CLASS)
fix_input_c fix_input_c.to(device)
fix_noise fix_noise.to(device)
netG netG.to(device)
#fake_imgs netG(fix_noise, fix_input_c).detach().cpu()# images recover_image(fake_imgs)
# full_image np.full((1 * 64, 10 * 64, 3), 0, dtypeuint8)
# for i in range(10):
# row i // 10
# col i % 10
# full_image[row * 64:(row 1) * 64, col * 64:(col 1) * 64, :] images[i]#fix_noise torch.randn(BATCH_SIZE, NZ, deviceDEVICE)
full_image np.full((10 * 64, 10 * 64, 3), 0, dtypeuint8)
for num in range(10):input_c torch.tensor(np.ones(10, dtypeint64) * num)input_c onehot(input_c, NUM_CLASS)fix_noise fix_noise.to(device)input_c input_c.to(device)fake_imgs netG(fix_noise, input_c).detach().cpu()images recover_image(fake_imgs)for i in range(10):row numcol i % 10full_image[row * 64:(row 1) * 64, col * 64:(col 1) * 64, :] images[i]plt.imshow(full_image)
plt.show()
plt.imsave(hah.png, full_image)
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/88318.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!