重生之从零开始的神经网络算法学习之路 —— 第八篇 大型数据集与复杂模型的 GPU 训练实践

news/2025/9/25 21:26:07/文章来源:https://www.cnblogs.com/cmxcxd1314/p/19094873

重生之从零开始的神经网络算法学习之路——第八篇 大型数据集与复杂模型的GPU训练实践

引言

在前一篇中,我们实现了基础的SRCNN超分辨率模型并掌握了后台训练技巧。本篇将进一步拓展实验规模:引入更大规模的数据集、实现更复杂的网络结构,并优化GPU训练策略,以应对更具挑战性的图像重建任务。通过这些实践,我们将深入理解大规模深度学习实验的关键技术和工程细节。

项目目录结构

一个规范的项目结构有助于代码管理和团队协作,以下是我们超分辨率项目的完整目录结构:

esrgan-super-resolution/
│
├── src/                      # 源代码目录
│   ├── __init__.py
│   ├── models/               # 模型定义
│   │   ├── __init__.py
│   │   ├── esrgan.py         # ESRGAN生成器实现
│   │   └── discriminator.py  # 判别器实现
│   │
│   ├── data/                 # 数据处理相关
│   │   ├── __init__.py
│   │   ├── datasets.py       # 数据集类定义
│   │   ├── downloader.py     # 数据集下载工具
│   │   └── transforms.py     # 数据增强与转换
│   │
│   ├── losses/               # 损失函数
│   │   ├── __init__.py
│   │   ├── content_loss.py   # 内容损失
│   │   └── gan_loss.py       # GAN损失
│   │
│   ├── utils/                # 工具函数
│   │   ├── __init__.py
│   │   ├── metrics.py        # 评估指标(PSNR等)
│   │   ├── logger.py         # 日志工具
│   │   └── helpers.py        # 辅助函数
│   │
│   └── training/             # 训练相关
│       ├── __init__.py
│       ├── trainer.py        # 训练器类
│       └── validator.py      # 验证器类
│
├── configs/                  # 配置文件目录
│   ├── base_config.yaml      # 基础配置
│   └── esrgan_config.yaml    # ESRGAN专用配置
│
├── scripts/                  # 脚本目录
│   ├── train_esrgan.py       # 训练脚本
│   ├── evaluate.py           # 评估脚本
│   └── predict.py            # 预测脚本
│
├── data/                     # 数据目录
│   ├── raw/                  # 原始数据
│   │   ├── DIV2K/
│   │   └── Flickr2K/
│   └── processed/            # 处理后的数据
│
├── checkpoints/              # 模型检查点
│   ├── generator/
│   └── discriminator/
│
├── logs/                     # 日志文件
│   └── tensorboard/          # TensorBoard日志
│
├── results/                  # 结果输出
│   ├── comparisons/          # 图像对比结果
│   └── samples/              # 生成样本
│
├── docs/                     # 文档
│   ├── setup.md              # 环境搭建说明
│   └── usage.md              # 使用说明
│
├── main.py                   # 主程序入口
├── requirements.txt          # 依赖项
└── README.md                 # 项目说明

大型数据集的获取与处理

自动下载与解压实现

为了提升模型性能,我们使用DIV2K和Flickr2K两个大型数据集进行训练。以下是优化后的数据集自动下载与处理流程(对应src/data/downloader.py):

import os
import wget
import zipfile
import tarfile
from tqdm import tqdm# 数据集配置
DATASETS = {"DIV2K": {"train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip","valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"},"Flickr2K": {"url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"}
}def progress_bar(current, total, width=80):"""自定义进度条"""progress = int(width * current / total)bar = '=' * progress + '-' * (width - progress)print(f'[{bar}] {current/total*100:.1f}%', end='\r')def download_dataset(url, save_dir):"""下载数据集并显示进度条"""os.makedirs(save_dir, exist_ok=True)filename = url.split('/')[-1]file_path = os.path.join(save_dir, filename)if not os.path.exists(file_path):print(f"下载 {filename}...")wget.download(url, file_path, bar=progress_bar)print("\n下载完成")return file_pathdef extract_archive(file_path, extract_dir):"""解压数据集"""print(f"解压 {file_path} 到 {extract_dir}...")os.makedirs(extract_dir, exist_ok=True)if file_path.endswith('.zip'):with zipfile.ZipFile(file_path, 'r') as zip_ref:# 获取所有文件列表files = zip_ref.namelist()# 使用tqdm显示解压进度for file in tqdm(files, desc="解压中"):zip_ref.extract(file, extract_dir)elif file_path.endswith('.tar') or file_path.endswith('.tar.gz'):with tarfile.open(file_path, 'r') as tar_ref:members = tar_ref.getmembers()for member in tqdm(members, desc="解压中"):tar_ref.extract(member, extract_dir)def prepare_datasets(base_dir):"""准备所有数据集"""# 下载DIV2Kdiv2k_dir = os.path.join(base_dir, "DIV2K")for split, url in DATASETS["DIV2K"].items():file_path = download_dataset(url, div2k_dir)extract_archive(file_path, os.path.join(div2k_dir, split))# 下载Flickr2Kflickr_dir = os.path.join(base_dir, "Flickr2K")flickr_url = DATASETS["Flickr2K"]["url"]file_path = download_dataset(flickr_url, flickr_dir)extract_archive(file_path, flickr_dir)print("所有数据集准备完成")if __name__ == "__main__":# 可直接运行此脚本下载数据prepare_datasets(os.path.join(os.path.dirname(__file__), '../../data/raw'))

优化的数据加载器

对应src/data/datasets.py文件,实现高效处理大型数据集的加载器:

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transformsclass SuperResolutionDataset(Dataset):"""超分辨率数据集基础类"""def __init__(self, root_dir, scale_factor=4, patch_size=128, train=True):self.root_dir = root_dirself.scale_factor = scale_factorself.patch_size = patch_sizeself.train = train# 收集所有图像路径self.image_paths = []for dirpath, _, filenames in os.walk(root_dir):for fname in filenames:if fname.lower().endswith(('.png', '.jpg', '.jpeg')):self.image_paths.append(os.path.join(dirpath, fname))# 数据转换self.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])def __len__(self):return len(self.image_paths)def __getitem__(self, idx):# 读取图像img_path = self.image_paths[idx]hr_img = cv2.imread(img_path)hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)# 生成低分辨率图像h, w = hr_img.shape[:2]lr_size = (w // self.scale_factor, h // self.scale_factor)lr_img = cv2.resize(hr_img, lr_size, interpolation=cv2.INTER_CUBIC)# 训练时随机裁剪patchif self.train:# 随机裁剪高分辨率图像h, w = hr_img.shape[:2]x = np.random.randint(0, w - self.patch_size)y = np.random.randint(0, h - self.patch_size)hr_patch = hr_img[y:y+self.patch_size, x:x+self.patch_size]# 对应裁剪低分辨率图像lr_patch_size = self.patch_size // self.scale_factorlr_patch = lr_img[y//self.scale_factor : y//self.scale_factor + lr_patch_size,x//self.scale_factor : x//self.scale_factor + lr_patch_size]# 应用数据增强if np.random.random() > 0.5:hr_patch = cv2.flip(hr_patch, 1)lr_patch = cv2.flip(lr_patch, 1)return self.transform(lr_patch), self.transform(hr_patch)else:# 验证时使用完整图像return self.transform(lr_img), self.transform(hr_img)class CombinedDataset(ConcatDataset):"""组合多个数据集的包装类"""def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True):datasets = []for path in dataset_paths:datasets.append(SuperResolutionDataset(path, train=train,scale_factor=scale_factor,patch_size=patch_size))super().__init__(datasets)def create_optimized_dataloaders(batch_size, dataset_paths, scale_factor=4, patch_size=128,num_workers=8, pin_memory=True):"""创建优化的数据加载器"""# 训练数据集train_dataset = CombinedDataset(dataset_paths,scale_factor=scale_factor,patch_size=patch_size,train=True)# 验证数据集(使用DIV2K验证集)val_dataset = SuperResolutionDataset([p for p in dataset_paths if 'DIV2K' in p][0],train=False,scale_factor=scale_factor,patch_size=patch_size)# 使用预加载和多进程加速train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=pin_memory,prefetch_factor=2,  # 预加载下一批数据persistent_workers=True  # 保持工作进程存活)val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=pin_memory)return train_loader, val_loader

复杂模型实现:ESRGAN

对应src/models/esrgan.py文件,实现ESRGAN生成器:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ResidualDenseBlock(nn.Module):"""残差密集块,ESRGAN的核心组件"""def __init__(self, nf=64, gc=32, bias=True):super(ResidualDenseBlock, self).__init__()self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)# 初始化权重self._initialize_weights()def _initialize_weights(self):"""权重初始化"""for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x1 = self.lrelu(self.conv1(x))x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))# 残差连接return x5 * 0.2 + xclass RRDB(nn.Module):"""残差在残差密集块"""def __init__(self, nf, gc=32):super(RRDB, self).__init__()self.rdb1 = ResidualDenseBlock(nf, gc)self.rdb2 = ResidualDenseBlock(nf, gc)self.rdb3 = ResidualDenseBlock(nf, gc)def forward(self, x):out = self.rdb1(x)out = self.rdb2(out)out = self.rdb3(out)# 残差连接return out * 0.2 + xclass RRDBNet(nn.Module):"""ESRGAN 生成器的基础模块(RRDB 网络)"""def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):super(RRDBNet, self).__init__()self.scale = scale# 示例结构:卷积 + RRDB块 + 上采样 + 输出卷积self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)self.upsampler = self._make_upsampler(num_feat, scale)self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):blocks = []for _ in range(num_block):blocks.append(RRDB(num_feat, num_grow_ch))return nn.Sequential(*blocks)def _make_upsampler(self, num_feat, scale):# 实现上采样模块(如PixelShuffle)upsampler = []for _ in range(int(torch.log2(torch.tensor(scale)))):upsampler.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))upsampler.append(nn.PixelShuffle(2))return nn.Sequential(*upsampler)def forward(self, x):# 实现前向传播逻辑feat = self.conv_first(x)body_feat = self.conv_body(self.body(feat))feat = feat + body_featout = self.conv_last(self.upsampler(feat))return out# 定义ESRGAN生成器(继承RRDB网络,保持接口一致性)
class ESRGAN(RRDBNet):"""ESRGAN生成器类(与RRDB网络结构一致,用于统一接口)"""def __init__(self, scale_factor=4, **kwargs):super(ESRGAN, self).__init__(scale=scale_factor,** kwargs)

判别器实现

对应src/models/discriminator.py文件:

import torch
import torch.nn as nnclass Discriminator(nn.Module):"""ESRGAN判别器"""def __init__(self, num_in_ch=3, num_feat=64):super(Discriminator, self).__init__()self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)# 特征提取层self.features = nn.Sequential(nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False),nn.BatchNorm2d(num_feat),nn.LeakyReLU(0.2, True),nn.Conv2d(num_feat, num_feat*2, 4, 2, 1, bias=False),nn.BatchNorm2d(num_feat*2),nn.LeakyReLU(0.2, True),nn.Conv2d(num_feat*2, num_feat*4, 4, 2, 1, bias=False),nn.BatchNorm2d(num_feat*4),nn.LeakyReLU(0.2, True),nn.Conv2d(num_feat*4, num_feat*8, 4, 2, 1, bias=False),nn.BatchNorm2d(num_feat*8),nn.LeakyReLU(0.2, True),nn.Conv2d(num_feat*8, num_feat*8, 4, 2, 1, bias=False),nn.BatchNorm2d(num_feat*8),nn.LeakyReLU(0.2, True))self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv_last = nn.Conv2d(num_feat*8, 1, 1, 1, 0)def forward(self, x):x = self.lrelu(self.conv_first(x))x = self.features(x)x = self.avg_pool(x)x = self.conv_last(x)return x

生成对抗训练策略

损失函数实现

对应src/losses/content_loss.py

import torch
import torch.nn as nn
from torchvision import models, transformsclass ContentLoss(nn.Module):"""内容损失函数,使用VGG特征提取器"""def __init__(self, device):super(ContentLoss, self).__init__()# 使用预训练的VGG作为特征提取器vgg = models.vgg19(pretrained=True).features[:35].eval()for param in vgg.parameters():param.requires_grad = Falseself.vgg = vgg.to(device)self.criterion = nn.L1Loss()self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])def forward(self, sr, hr):# 归一化输入以匹配VGG训练条件sr_norm = self.normalize(sr)hr_norm = self.normalize(hr)# 提取特征sr_feat = self.vgg(sr_norm)hr_feat = self.vgg(hr_norm)return self.criterion(sr_feat, hr_feat)

对应src/losses/gan_loss.py

import torch
import torch.nn as nnclass GANLoss(nn.Module):"""GAN损失函数"""def __init__(self, gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0):super(GANLoss, self).__init__()self.gan_type = gan_typeself.real_label_val = real_label_valself.fake_label_val = fake_label_valif self.gan_type == 'vanilla':self.loss = nn.BCEWithLogitsLoss()elif self.gan_type == 'lsgan':self.loss = nn.MSELoss()else:raise NotImplementedError(f"GAN type {self.gan_type} is not implemented")def forward(self, pred, target_is_real):if target_is_real:target_val = self.real_label_valelse:target_val = self.fake_label_valtarget = torch.full_like(pred, fill_value=target_val, device=pred.device)return self.loss(pred, target)

GPU训练优化技巧

混合精度训练

对应src/training/trainer.py中的训练实现:

import torch
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import time
import numpy as np
from tqdm import tqdm
from src.utils.metrics import psnrclass ESRGANTrainer:def __init__(self, generator, discriminator, content_criterion, gan_criterion,g_optimizer, d_optimizer,device, log_interval=10):self.generator = generatorself.discriminator = discriminatorself.content_criterion = content_criterionself.gan_criterion = gan_criterionself.g_optimizer = g_optimizerself.d_optimizer = d_optimizerself.device = deviceself.log_interval = log_interval# 初始化混合精度训练self.scaler = GradScaler(enabled=True)def train_epoch(self, train_loader, epoch, grad_accum_steps=4):"""训练一个epoch"""self.generator.train()self.discriminator.train()total_gen_loss = 0.0total_dis_loss = 0.0total_psnr = 0.0pbar = tqdm(train_loader, desc=f"Epoch {epoch}")for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):lr_imgs = lr_imgs.to(self.device)hr_imgs = hr_imgs.to(self.device)# 训练判别器self.d_optimizer.zero_grad()with autocast():# 生成超分辨率图像sr_imgs = self.generator(lr_imgs)# 判别器对真实图像的预测real_pred = self.discriminator(hr_imgs)# 判别器对生成图像的预测fake_pred = self.discriminator(sr_imgs.detach())  # detach避免更新生成器# 计算判别器损失real_loss = self.gan_criterion(real_pred, True)fake_loss = self.gan_criterion(fake_pred, False)dis_loss = (real_loss + fake_loss) * 0.5# 反向传播self.scaler.scale(dis_loss).backward()# 梯度累积if (batch_idx + 1) % grad_accum_steps == 0:self.scaler.step(self.d_optimizer)self.scaler.update()self.d_optimizer.zero_grad()# 训练生成器self.g_optimizer.zero_grad()with autocast():# 生成器损失 = 内容损失 + GAN损失content_loss = self.content_criterion(sr_imgs, hr_imgs)fake_pred = self.discriminator(sr_imgs)gan_loss = self.gan_criterion(fake_pred, True)# 内容损失权重更高gen_loss = content_loss * 0.01 + gan_loss * 0.005# 计算PSNRbatch_psnr = psnr(hr_imgs, sr_imgs)self.scaler.scale(gen_loss).backward()# 梯度累积if (batch_idx + 1) % grad_accum_steps == 0:self.scaler.step(self.g_optimizer)self.scaler.update()self.g_optimizer.zero_grad()# 累计损失total_gen_loss += gen_loss.item()total_dis_loss += dis_loss.item()total_psnr += batch_psnr# 日志输出if batch_idx % self.log_interval == 0:avg_gen_loss = total_gen_loss / (batch_idx + 1)avg_dis_loss = total_dis_loss / (batch_idx + 1)avg_psnr = total_psnr / (batch_idx + 1)pbar.set_postfix({'gen_loss': f'{avg_gen_loss:.4f}','dis_loss': f'{avg_dis_loss:.4f}','psnr': f'{avg_psnr:.2f}'})return (total_gen_loss / len(train_loader), total_dis_loss / len(train_loader), total_psnr / len(train_loader))

主训练脚本

对应scripts/train_esrgan.py

import os
import argparse
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from src.models.esrgan import ESRGAN
from src.models.discriminator import Discriminator
from src.losses.content_loss import ContentLoss
from src.losses.gan_loss import GANLoss
from src.data.datasets import create_optimized_dataloaders
from src.training.trainer import ESRGANTrainer
from src.training.validator import validate
from src.utils.logger import init_tensorboard, log_to_tensorboard
from src.utils.helpers import save_checkpoint, load_checkpointdef parse_args():parser = argparse.ArgumentParser(description='Train ESRGAN model')parser.add_argument('--epochs', type=int, default=2000, help='Number of epochs')parser.add_argument('--batch_size', type=int, default=16, help='Batch size')parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate')parser.add_argument('--lr_decay', type=float, default=0.5, help='Learning rate decay')parser.add_argument('--scale_factor', type=int, default=4, help='Upscaling factor')parser.add_argument('--patch_size', type=int, default=192, help='Patch size for training')parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch')parser.add_argument('--checkpoint_interval', type=int, default=50, help='Checkpoint interval')parser.add_argument('--log_interval', type=int, default=20, help='Log interval')parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloader')parser.add_argument('--pin_memory', action='store_true', default=True, help='Pin memory for dataloader')parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device')parser.add_argument('--dataset_path', type=str, default='./data/raw', help='Dataset path')parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Checkpoint directory')parser.add_argument('--log_dir', type=str, default='./logs/tensorboard', help='Log directory')parser.add_argument('--resume', type=str, default=None, help='Resume from checkpoint')return parser.parse_args()def main():# 解析配置参数args = parse_args()# 设置设备device = torch.device(args.device)print(f"使用设备: {device}")# 初始化模型generator = ESRGAN(scale_factor=args.scale_factor).to(device)discriminator = Discriminator().to(device)# 定义损失和优化器content_criterion = ContentLoss(device)gan_criterion = GANLoss(gan_type='vanilla')g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr * 0.1, betas=(0.9, 0.999))# 学习率调度器 - 阶梯式衰减g_scheduler = lr_scheduler.StepLR(g_optimizer, step_size=10, gamma=args.lr_decay)d_scheduler = lr_scheduler.StepLR(d_optimizer, step_size=10, gamma=args.lr_decay)# 加载检查点(如果需要)if args.resume:args.start_epoch = load_checkpoint(args.resume, generator, discriminator, g_optimizer, d_optimizer)# 准备数据集路径dataset_paths = [os.path.join(args.dataset_path, "DIV2K"),os.path.join(args.dataset_path, "Flickr2K")]# 加载大型数据集train_loader, val_loader = create_optimized_dataloaders(batch_size=args.batch_size,dataset_paths=dataset_paths,scale_factor=args.scale_factor,patch_size=args.patch_size,num_workers=args.num_workers,pin_memory=args.pin_memory)# 初始化训练器trainer = ESRGANTrainer(generator, discriminator,content_criterion, gan_criterion,g_optimizer, d_optimizer,device, args.log_interval)# 初始化TensorBoardwriter = init_tensorboard(args.log_dir)# 训练循环for epoch in range(args.start_epoch, args.epochs):start_time = time.time()# 梯度累积参数grad_accum_steps = 4  # 累积4个batch的梯度# 训练gen_loss, dis_loss, train_psnr = trainer.train_epoch(train_loader, epoch, grad_accum_steps)# 更新学习率g_scheduler.step()d_scheduler.step()# 验证val_psnr, val_images = validate(generator, val_loader, device)# 日志记录print(f'Epoch {epoch}/{args.epochs}, 'f'Gen Loss: {gen_loss:.4f}, Dis Loss: {dis_loss:.4f}, 'f'Train PSNR: {train_psnr:.2f} dB, Val PSNR: {val_psnr:.2f} dB, 'f'Time: {time.time() - start_time:.2f}秒')# 写入TensorBoardlog_to_tensorboard(writer, epoch, {'gen_loss': gen_loss,'dis_loss': dis_loss,'psnr': train_psnr,'gen_lr': g_optimizer.param_groups[0]['lr']}, {'psnr': val_psnr}, val_images)# 保存检查点if (epoch + 1) % args.checkpoint_interval == 0:save_checkpoint(epoch, generator, discriminator, g_optimizer, d_optimizer, args.checkpoint_dir)writer.close()if __name__ == "__main__":main()

扩展运行脚本

对应项目根目录下的run_esrgan.sh

#!/bin/bash
# run_esrgan.sh# 设置工作目录
cd /home/vscode/workspace# 记录开始时间
start_time=$(date +%s)
echo "实验开始时间: $(date)"# 检查GPU状态
nvidia-smi# 创建输出目录
mkdir -p logs checkpoints results# 运行训练脚本,增加内存优化参数
nohup python3 -u scripts/train_esrgan.py \--epochs 2000 \--batch_size 16 \--lr 0.0002 \--scale_factor 4 \--patch_size 192 \--checkpoint_interval 50 \--log_interval 20 \--dataset_path ./data/raw \> training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt 2>&1 &# 记录进程ID和日志文件
echo "训练任务已在后台启动,PID: $!"
log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"
echo "日志文件: $log_file"# 监控GPU使用情况(每5分钟记录一次)
while true; doecho "GPU监控: $(date)" >> $log_filenvidia-smi >> $log_file 2>&1sleep 300  # 5分钟
done &

实验监控与分析

对应src/utils/logger.py

import os
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriterdef init_tensorboard(log_dir):"""初始化TensorBoard"""os.makedirs(log_dir, exist_ok=True)writer = SummaryWriter(log_dir=log_dir)return writerdef log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):"""将训练指标和图像写入TensorBoard"""# 日志指标writer.add_scalar('Loss/Generator', train_metrics['gen_loss'], epoch)writer.add_scalar('Loss/Discriminator', train_metrics['dis_loss'], epoch)writer.add_scalar('PSNR/Train', train_metrics['psnr'], epoch)writer.add_scalar('PSNR/Validation', val_metrics['psnr'], epoch)writer.add_scalar('LearningRate/Generator', train_metrics['gen_lr'], epoch)# 日志图像(每10个epoch)if epoch % 10 == 0 and images is not None:lr_img, sr_img, hr_img = imageswriter.add_image('Input/LowResolution', lr_img, epoch)writer.add_image('Output/SuperResolution', sr_img, epoch)writer.add_image('Target/HighResolution', hr_img, epoch)# 保存图像到文件save_comparison_plot(lr_img, sr_img, hr_img, epoch)def save_comparison_plot(lr, sr, hr, epoch, save_dir='results/comparisons'):"""保存图像对比结果"""os.makedirs(save_dir, exist_ok=True)# 转换为适合显示的格式lr = lr.permute(1, 2, 0).cpu().detach().numpy()sr = sr.permute(1, 2, 0).cpu().detach().numpy()hr = hr.permute(1, 2, 0).cpu().detach().numpy()# 反归一化lr = (lr * 0.5 + 0.5) * 255sr = (sr * 0.5 + 0.5) * 255hr = (hr * 0.5 + 0.5) * 255# 绘制对比图plt.figure(figsize=(15, 5))plt.subplot(131)plt.title('Low Resolution')plt.imshow(lr.astype('uint8'))plt.axis('off')plt.subplot(132)plt.title('Super Resolution')plt.imshow(sr.astype('uint8'))plt.axis('off')plt.subplot(133)plt.title('High Resolution')plt.imshow(hr.astype('uint8'))plt.axis('off')plt.tight_layout()plt.savefig(f'{save_dir}/comparison_epoch_{epoch}.png', dpi=300, bbox_inches='tight')plt.close()

总结与后续方向

通过本篇实验,我们实现了一个结构完整的超分辨率项目,包括:

  1. 规范的项目结构:将代码模块化,分离数据处理、模型定义、损失函数和训练逻辑
  2. 大型数据集管理:自动下载、解压和组合多个大型数据集,优化数据加载流程
  3. 复杂模型构建:实现了基于残差密集块的ESRGAN模型,相比SRCNN能生成更丰富的细节
  4. 高级训练策略:引入混合精度训练、梯度累积和阶梯式学习率调度,提升GPU利用率
  5. 完善监控体系:结合日志文件、GPU监控和TensorBoard可视化,全面跟踪实验过程

后续可探索的方向:

  • 尝试更大规模的模型(如RCAN、SwinIR)
  • 引入感知损失和GAN的改进变体(如Relativistic GAN)
  • 实现模型并行和数据并行,利用多GPU进行训练
  • 探索模型压缩和加速技术,实现实时超分辨率
  • 尝试视频超分辨率任务,考虑时间维度的一致性

下一篇我们将探索更前沿的视觉Transformer模型在超分辨率任务中的应用,进一步提升重建质量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/917552.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Avalonia:开发Android应用

我把成功开发Android应用的经过记录下来,在开发过程中,模拟器经常出问题,将Java Development Kit的位置和Android SDK的位置改动一下,就解决了模拟器报错的问题,这是在Github上看到的解决办法。 先建Models文件夹…

MIT s6.828环境搭建

前言:建议ubuntu镜像版本在22.04以下,亲测新版本会报错 本文默认读者ubuntu搭建完成,且可以联网 sudo apt update开始配置环境前先更新软件包列表sudo apt install -y binutils gcc git libpixman-l-dev python2 pk…

做微网站的第三方登录wordpress 目录布局

目录 一、引言 二、代码整体结构 三、宏定义与头文件 四、插入排序函数(Insertsort) 函数作用 代码要点分析 五、希尔排序函数(ShellSort) 函数作用 代码要点分析 六、打印数组函数(PrintSort&#x…

关键词搜索爱站网自己如何建立网站

一、设计模式分类 软件开发的23种模式,主要分类有创建型模式,结构型模式,行为型模式三种,相关分类如下: 设计模式是一种面向对象编程的思想,它是由Gamma等人在《设计模式:可复用面向对象软件的…

详细介绍:ES6核心基础

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

kubernetes事件监控工具--Kube-Event

在日常使用中,总会碰到容器“不经意间”重启的情况,但我完全不知道上次重启是什么时候。容器一旦重启,旧实例就会被销毁,如果旧容器日志没有被收集或转存,就彻底丢失了。这样一来,想通过历史日志排查问题原因就显…

wordpress 中英文站点佛山的网站建设公司

引言 C语⾔是结构化的程序设计语⾔,这⾥的结构指的是顺序结构、选择结构、循环结构。为什么有着三种结构呢,大家其实可以想象一下,生活中的绝大数事情都可以抽象着三种结构,而我们今天要给大家介绍的就是三大结构之一——选择结构…

做电子简历的网站悦西安

本文主要介绍Linux 字体颜色的调整,常用于shell脚本当中。我们举一个例子:echo-e"\033[44;37;5m ME \033[0m COOL" 以上命令设置背景成为蓝色,前景白色,闪烁光标,输出字符“ME”,然后重新设置屏幕…

企业档案管理系统:精准破局制造行业档案管理困境 - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

实用指南:【 GUI自动化测试】GUI自动化测试(一) 环境安装与测试

实用指南:【 GUI自动化测试】GUI自动化测试(一) 环境安装与测试pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "C…

怎么建一个卖东西的网站网站怎样关键词排名优化

在日常编码环节,很大比例的错误处理工作和参数的输入有关。当程序里的某些数据直接来自用户输入时,必须先校验这些输入值,再进行之后的处理,否则就会出现难以预料的错误。 需求: 写一个命令行小程序,它要求…

喵喵大王の新日记

2025 9.25 突然心血来潮了,于是开了新日记,但是实际上我也不一定更的多么频繁,毕竟上了大学还是有点忙的。才不是一直打三角洲懒得更新 这里应当有一篇新文章。啥时候写完想起来放上。本文来自博客园,作者:北烛青…

【JavaEE】MyBatis - Plus - 教程

【JavaEE】MyBatis - Plus - 教程pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco&q…

多GPU本地布署Wan2.2-T2V-A14B文本转视频模型 - yi

多GPU本地布署Wan2.2-T2V-A14B文本转视频模型一,#本机环境检查 执行nvidia-smi,查看右上角。验证显卡驱动已安装最高支持的版本。nvidia-smi#在调试时,为了实时观察GPU利用率,一般新开一个命令窗口,执行以下命令,…

NOI 模拟赛五

DPA. 纪念场切题。 记 \(f[i, j, x, 0/1, 0/1]\) 表示前 \(i\) 个车站都已经经过,\(i\rightarrow i+1\) 的边走过 \(j\) 次,总距离 \(\bmod m=x\) ,是否钦定起点,是否钦定终点(这 \(j\) 条边经过是有顺序)。 为了…

常州装修网站建设公司企业的建站方式

运行软件前提前安装好OPC运行组件: 为方便演示,提前准备好了一个DAServer服务器: 接下来开始配置: 该软件主要实现的功能如下: 配置过程也相对简单: 第一步: 编辑如下文件: 第二步…

企业微信手机片网站制作上海建筑工程招投标网

这是什么?这是有关警告,错误和注意事项的许多答案,这些警告,错误和注意事项在您对PHP进行编程时可能会遇到,并且不知道如何解决它们。这也是一个社区Wiki,因此邀请所有人参与添加并维护此列表。为什么是这样…

免费咨询律师24小时电话桂平seo快速优化软件

数学建模常用的算法分类 全国大学生数学建模竞赛中,常见的算法模型有以下30种: 最小二乘法数值分析方法图论算法线性规划整数规划动态规划贪心算法分支定界法蒙特卡洛方法随机游走算法遗传算法粒子群算法神经网络算法人工智能算法模糊数学时间序列分析马…

中小型网站建设与管理总结wordpress手机怎么用

性能对比:Memcached 与 Redis 的关键差异 在选择合适的缓存系统时,Memcached 和 Redis 是最常被提及的两种技术。它们都是内存存储系统,用于提高数据访问速度和应用性能。尽管它们在功能上有很多相似之处,但在性能、特性和应用场…

AI热点周报(09.14~09.20):Gemini集成到Chrome、Claude 强化记忆、Qwen3-Next快捷落地,AI走向集成化,工程化?

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …