超分辨率(2)--基于EDSR网络实现图像超分辨率重建

目录

一.项目介绍

二.项目流程详解

2.1.构建网络模型

2.2.数据集处理

2.3.训练模块

2.4.测试模块

三.测试网络


一.项目介绍

EDSR全称Enhanced Deep Residual Networks,是SRResnet的升级版,其对网络结构进行了优化(去除了BN层),省下来的空间可以用于提升模型的size来增强表现力。

为什么要去除BN层:

Batch Norm是深度学习中非常重要的技术,不仅可以使训练更深的网络变容易,加速收敛,还有一定正则化的效果,可以防止模型过拟合。

但对于图像超分辨率来说,网络输出的图像在色彩、对比度、亮度上要求和输入一致,改变的仅仅是分辨率和一些细节,而Batch Norm,对图像来说类似于一种对比度的拉伸,任何图像经过Batch Norm后,其色彩的分布都会被归一化,也就是说,它破坏了图像原本的对比度信息,所以Batch Norm的加入反而影响了网络输出的质量。

网络结构及对比:

移除BN层后,模型更加轻量,BN层所消耗的存储空间等同于上一层CNN层所消耗的,作者指出相比于SRResNet,EDSR去掉BN层之后节约了40%的存储资源。

同时在BN腾出来的空间下插入更多的类似于残差块等CNN-based子网络来增加模型的表现力。

论文地址:

[1707.02921] Enhanced Deep Residual Networks for Single Image Super-Resolution (arxiv.org)icon-default.png?t=N7T8https://arxiv.org/abs/1707.02921源码地址:

developer0hye/EDAR: PyTorch implementation of Deep Convolution Networks based on EDSR for Compression(Jpeg) Artifacts Reduction (github.com)icon-default.png?t=N7T8https://github.com/developer0hye/EDAR

二.项目流程详解

2.1.构建网络模型

def default_conv(in_channels, out_channels, kernel_size, bias=True):return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)class MeanShift(nn.Conv2d):def __init__(self, rgb_mean, rgb_std, sign=-1):super(MeanShift, self).__init__(3, 3, kernel_size=1)std = torch.Tensor(rgb_std)self.weight.data = torch.eye(3).view(3, 3, 1, 1)self.weight.data.div_(std.view(3, 1, 1, 1))self.bias.data = sign * torch.Tensor(rgb_mean)self.bias.data.div_(std)self.requires_grad = Falseclass ResBlock(nn.Module):def __init__(self, conv, n_feat, kernel_size,bias=True, act=nn.ReLU(True)):super(ResBlock, self).__init__()m = []for i in range(2):m.append(conv(n_feat, n_feat, kernel_size, bias=bias))if i == 0: m.append(act)# m是设置好的conv层# 设置网络内部层次结构为bodyself.body = nn.Sequential(*m)def forward(self, x):# 获取当前的结果res = self.body(x)# 当前得到的网络和最初的网络融合res += xreturn res

class EDAR(nn.Module):def __init__(self, conv=common.default_conv):super(EDAR, self).__init__()# 参数设置n_resblock = 8  # resnet长度n_feats = 64kernel_size = 3  # 卷积核大小#DIV 2K meanrgb_mean = (0.4488, 0.4371, 0.4040)rgb_std = (1.0, 1.0, 1.0)self.sub_mean = common.MeanShift(rgb_mean, rgb_std)# define head module# 经过卷积,特征图数由3->n_featsm_head = [conv(3, n_feats, kernel_size)]# define body module# Residual Block设置m_body = [common.ResBlock(conv, n_feats, kernel_size) for _ in range(n_resblock)]m_body.append(conv(n_feats, n_feats, kernel_size))# define tail module# 经过卷积,特征图数由n_feats->3m_tail = [conv(n_feats, 3, kernel_size)]self.add_mean = common.MeanShift(rgb_mean, rgb_std, 1)# 设置网络的三个层次self.head = nn.Sequential(*m_head)self.body = nn.Sequential(*m_body)self.tail = nn.Sequential(*m_tail)

前向传播过程:

    def forward(self, x):x = self.sub_mean(x)x = self.head(x)res = self.body(x)res += xx = self.tail(res)x = self.add_mean(x)# 将输入input张量每个元素的范围限制到区间 [min,max],返回结果到一个新张量。# 及输出一个新张量值x,并限制他的值在0~1之间return torch.clamp(x,0.0,1.0)

2.2.数据集处理

import os
import io
import random
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = Trueclass Dataset(object):def __init__(self, images_dir, patch_size=48, jpeg_quality=40, transforms=None):self.images = os.walk(images_dir).__next__()[2]self.images_path = []for img_file in self.images:if img_file.endswith((".ppm")):try:#print(os.path.join(images_dir, img_file))label = Image.open(os.path.join(images_dir, img_file))self.images_path.append(os.path.join(images_dir, img_file))except:print(f"Image {os.path.join(images_dir, img_file)} didn't get loaded")self.patch_size = patch_sizeself.jpeg_quality = jpeg_qualityself.transforms = transformsself.random_rotate = [0, 90, 180, 270]def __getitem__(self, idx):label = Image.open(self.images_path[idx]).convert('RGB')label = label.rotate(self.random_rotate[random.randrange(0,4)])# randomly crop patch from training setcrop_x = random.randint(0, label.width - self.patch_size)crop_y = random.randint(0, label.height - self.patch_size)# 使用crop函数对图片进行裁剪label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))# additive jpeg noisebuffer = io.BytesIO()label.save(buffer, format='jpeg', quality=random.randrange(self.jpeg_quality+1))input = Image.open(buffer).convert('RGB')if self.transforms is not None:input = self.transforms(input)label = self.transforms(label)#print("Image transformed")return input, labeldef __len__(self):return len(self.images_path)

2.3.训练模块

import argparse
import osfrom dataset import Dataset
from edar import EDARimport torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import transforms
from torchvision.models.vgg import vgg16from utils import AverageMeter
from tqdm import tqdmif __name__ == '__main__':'''It enables benchmark mode in cudnn.benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances.'''cudnn.benchmark = Truedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 参数设置parser = argparse.ArgumentParser()# required为true的参数则是必须要设置的参数parser.add_argument('--images_dir', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--patch_size', type=int, default=48)parser.add_argument('--batch_size', type=int, default=16)parser.add_argument('--num_epochs', type=int, default=400)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--threads', type=int, default=1)parser.add_argument('--seed', type=int, default=123)parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')opt = parser.parse_args()# 如果输出文件夹不存在,则自动创建一个文件夹if not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)torch.manual_seed(opt.seed)transforms_train = transforms.Compose([transforms.ToTensor()])# 模型设置model = EDAR().to(device)print("Model loaded")if opt.resume:if os.path.isfile(opt.resume):state_dict = model.state_dict()for n, p in torch.load(opt.resume, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)# 损失函数设置criterion = nn.L1Loss()# 优化器设置optimizer = optim.Adam(model.parameters(), lr=opt.lr)print("Data processing started")# 数据集设置dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality,transforms=transforms_train)dataloader = DataLoader(dataset=dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.threads,pin_memory=True,drop_last=True)print("Data loading completed")#vgg = vgg16(pretrained=True).cuda()#loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
#     for param in loss_network.parameters():
#         param.requires_grad = False# 开始训练for epoch in range(opt.num_epochs):epoch_losses = AverageMeter()print("Length of the dataset is", len(dataset))with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:_tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))# 按照dataloader的格式取出datafor data in dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)#print(inputs.size(), labels.size())outs = model(inputs)# 损失值计算,参数是预测值和实际值loss = criterion(outs, labels)#perception_loss = criterion(loss_network(outs), loss_network(labels))#loss = loss + perception_loss*0.06epoch_losses.update(loss.item(), len(inputs))# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()_tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))_tqdm.update(len(inputs))torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format("EDAR_", epoch)))

2.4.测试模块

import argparse
import os
import io
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import PIL.Image as pil_image
import globfrom edar import EDARcudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")if __name__ == '__main__':# 参数设置parser = argparse.ArgumentParser()parser.add_argument('--weights_path', type=str, required=True)parser.add_argument('--image_path', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--input_dir', type=str, required=False)opt, unknown = parser.parse_known_args()model = EDAR()state_dict = model.state_dict()# 参数获取for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model = model.to(device)print(device)model.eval()if opt.input_dir:filenames = [os.path.join(opt.input_dir, file) for file in os.listdir(opt.input_dir) if file.endswith(("ppm", "jpeg", "png", "jpg"))]print(filenames)else:filenames = opt.image_pathif not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)# 处理单个测试图片时使用:filename = filenamesprint("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)#buffer = io.BytesIO()#input.save(buffer, format='jpeg', quality=opt.jpeg_quality)#input = pil_image.open(buffer)#input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)#print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))#print("Output saved")'''处理多个测试图片时使用:for filename in filenames:print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)# buffer = io.BytesIO()# input.save(buffer, format='jpeg', quality=opt.jpeg_quality)# input = pil_image.open(buffer)# input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)# print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))# print("Output saved")'''

三.测试网络

参数设置:

输入图片:

输出图片:

输入图片:

输出图片:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746353.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

避免阻塞主线程 —— Web Worker 示例项目

前期回顾 迄今为止易用 —— 的 “盲水印“ 实现方案-CSDN博客https://blog.csdn.net/m0_57904695/article/details/136720192?spm1001.2014.3001.5501 目录 CSDN 彩色之外 📝 前言 🚩 技术栈 🛠️ 功能 🤖 如何运行 ♻️ …

《JAVA与模式》之工厂方法模式

系列文章目录 文章目录 系列文章目录前言一、工厂方法模式二、工厂方法模式的活动序列图三、工厂方法模式和简单工厂模式前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码…

【个人记录】CentOS7安装MySQL 5.7和libmysqlclient.so.20

记录 之前使用MariaDB 发现使用的libmysqlclient.so是18版本的,一些程序需要20版本的库,查了一下需要安装5.7以上版本的才有libmysqlclient.so.20,这里简单记录一下怎么安装。 安装MySQL 5.7 Yum源 yum install -y https://repo.mysql.com…

vuex怎么防止数据刷新丢失?

在Vue.js应用程序中,Vuex是一个用于管理应用程序状态的状态管理库。默认情况下,Vuex的状态存储在内存中,并且在页面刷新时会丢失。 为了防止数据刷新丢失,你可以考虑以下几种方法: 这些方法可以帮助你防止Vuex数据刷新…

浅析C++的指针与引用

浅析C的指针与引用 文章目录 浅析C的指针与引用一、对比引用与指针二、引用左值引用右值引用引用折叠 三、指针与引用的性能差距总结 一、对比引用与指针 总论: 引用指针必须初始化可以不初始化不能为空可以为空不能更换目标可以更换目标 引用必须初始化&#xff…

如何用SMU数字源表测试apd管的暗电流

01 APD工作原理 APD雪崩光电二极管的工作原理是基于光电效应和雪崩效应,当光子被吸收时,会产生电子空穴对,空穴向P区移动,电子向N区移动,由于电场的作用,电子与空穴相遇时会产生二次电子,形成雪…

串行通信——IIC总结

一.什么是IIC? IIC(Inter-Integrated Circuit)也称I2C,中文叫集成电路总线。是一个多主从的串行总线,由飞利浦公司发明的通讯总线,属于半双工同步传输类总线,仅由两条线就能完成多机通讯&#…

Android 辅助功能 -抢红包(二)

Android 辅助功能 -抢红包(二) 本篇文章继续讲述辅助功能实现抢红包的方案. 上篇文章主要讲了下辅助功能的基本使用,本文涉及的一些基础内容就不再赘述了. 有疑问的可以查看上篇文章: Android 辅助功能 -抢红包 1: 添加微信监听 修改xml文件,android:packageNames中新增微…

【解读】区块链和分布式记账技术标准体系建设指南

大家好,这里是苏泽。一个从业Java后端的区块链技术爱好者。 今天带大家来解读这份三部门印发的行业建设指南《区块链和分布式记账技术标准体系建设指南》 原文件可查看P020240112840724196854.pdf (www.gov.cn) 以下是个人解读,如有纰漏请指正&#xff…

Nginx 报错 504 Gateway Time-out 的解决方法

报错信息 504 Gateway Time-out 原因是程序执行时间过长,导致请求超时。 解决方法 首先,尽可能地优化程序代码的执行时间。 其次,修改配置文件。 修改 php.ini 配置文件。 max_execution_time 600 复制 修改 nginx.conf 配置文件。…

KY9 成绩排序

描述&#xff1a; 用一维数组存储学号和成绩&#xff0c;然后&#xff0c;按成绩排序输出。 输入描述&#xff1a; 输入第一行包括一个整数N(1<N<100)&#xff0c;代表学生的个数。 接下来的N行每行包括两个整数p和q&#xff0c;分别代表每个学生的学号和成绩。 输出描述…

【系统架构师】-第16章-嵌入式系统架构设计理论与实践

1、嵌入式系统发展 第一阶段&#xff1a;单片微型计算机 (SCM) 阶段&#xff0c;即单片机时代&#xff0c;五操作系统 第二阶段&#xff1a;微控制器 (MUC) 阶段&#xff0c;有简单操作系统 第三阶段&#xff1a;片上系统 (SoC)&#xff0c;兼容各种微处理器 第四阶段&…

常见滤波方式的区别的优势

一. 限幅滤波法 给定一个最大偏差X&#xff0c;如果本次值与上次差值小于X&#xff0c;则本次有效&#xff0c;否则无效&#xff0c;使用上次值代替。 #incldue <stdio.h>#define X 2 int lastvalue; //限幅滤波法 int filter(void) {int nowValue ;nowValue getValue…

软件测试 —— 测试用例设计报告

写出测试网站的测试用例&#xff0c;测试网站具体内容可看团购网站系统需求说明书1.2.doc 一、流程1&#xff1a;注册→登录 图1&#xff1a;注册->登录流程图 1、 使用场景设计法设计测试用例 1&#xff09; 找出基本流和备选流 基本流注册用户-成功登录系统备选流1注册…

Jenkins cron定时构建触发器

from&#xff1a; https://www.jenkins.io/doc/book/pipeline/syntax/#cron-syntax 以下内容为根据Jenkins官方文档cron表达式部分翻译过来&#xff0c;使用机翻加个人理解补充内容&#xff0c;包括举例。 目录 介绍举例&#xff1a;设置方法方法一&#xff1a;方法二&#xf…

3.2_1 虚拟内存的基本概念

3.2_1 虚拟内存的基本概念 虚拟存储技术也是存储空间扩充的一种技术&#xff0c;它比交换、覆盖技术要更先进一些。 &#xff08;一&#xff09;传统存储管理方式的特征、缺点 对于这种传统的存储管理方案&#xff0c;很多暂时用不到的数据也会长期占用内存&#xff0c;导致内存…

R 语言patchwork包拼图间隙

在R语言中&#xff0c;patchwork包是一个非常强大的工具&#xff0c;允许你轻松地将多个图表拼接在一起。如果你希望调整拼图间的间隙&#xff08;即图表之间的空白区域&#xff09;&#xff0c;可以通过使用plot_layout()函数来实现&#xff0c;其中可以指定guides参数和spaci…

【数据结构和算法初阶(C语言)】栈的概念和实现(后进先出---后来者居上的神奇线性结构带来的惊喜体验)

目录 1.栈 1.1栈的概念及结构 2.栈的实现 3.栈结构对数据的处理方式 3.1对栈进行初始化 3.2 从栈顶添加元素 3.3 打印栈元素 3.4移除栈顶元素 3.5获取栈顶元素 3.6获取栈中的有效个数 3.7 判断链表是否为空 3.9 销毁栈空间 4.结语及整个源码 1.栈 1.1栈的概念及结构 栈&am…

Codeforces Round 933 (Div. 3)

比赛地址传送门 A. Rudolf and the Ticket 题目大意&#xff1a; 给定两个数组和一个 k&#xff0c;要求从两个数组中各选一个数求和不大于 k&#xff0c;有多少种方案 思路&#xff1a; 维护一个数组 f[i] 代表小于等于 i 的数字的数量&#xff0c;遍历另一个数组&#xff0…

遇到:java.lang.reflect.InaccessibleObjectException: Unable to make错误应该如何解决

遇到 "java.lang.reflect.InaccessibleObjectException: Unable to make" 错误是因为你的代码尝试访问了一个不可访问的对象或方法。这通常会发生在使用反射机制时&#xff0c;尝试访问私有或受限制的成员时。要解决这个问题&#xff0c;你可以考虑以下几个步骤&…