GPEN模型微调入门:自定义数据集训练步骤详解教程
1. 镜像环境说明
本镜像基于GPEN人像修复增强模型构建,预装了完整的深度学习开发环境,集成了推理及评估所需的所有依赖,开箱即用。用户无需手动配置复杂的运行时依赖,可直接进入模型微调与训练阶段。
| 组件 | 版本 |
|---|---|
| 核心框架 | PyTorch 2.5.0 |
| CUDA 版本 | 12.4 |
| Python 版本 | 3.11 |
| 推理代码位置 | /root/GPEN |
主要依赖库:
facexlib: 用于人脸检测与对齐basicsr: 基础超分框架支持opencv-python,numpy<2.0,datasets==2.21.0,pyarrow==12.0.1sortedcontainers,addict,yapf
所有依赖均已通过 Conda 环境管理工具打包至torch25虚拟环境中,确保版本兼容性和运行稳定性。
2. 快速上手
2.1 激活环境
在使用 GPEN 模型前,请先激活预设的 Python 环境:
conda activate torch25该环境已包含所有必要的深度学习库和工具链,避免因版本冲突导致运行失败。
2.2 模型推理 (Inference)
进入项目主目录并执行推理脚本:
cd /root/GPEN场景 1:运行默认测试图
python inference_gpen.py此命令将加载内置测试图像(Solvay_conference_1927.jpg),输出结果为output_Solvay_conference_1927.png。
场景 2:修复自定义图片
python inference_gpen.py --input ./my_photo.jpg输入文件路径由--input参数指定,输出自动保存为output_my_photo.jpg。
场景 3:自定义输入输出文件名
python inference_gpen.py -i test.jpg -o custom_name.png支持通过-i和-o分别设置输入与输出路径,便于集成到自动化流程中。
注意:所有推理结果将默认保存在项目根目录下,建议定期备份或重定向输出路径以避免覆盖。
3. 已包含权重文件
为保障离线可用性与快速启动能力,镜像内已预下载官方发布的预训练权重文件,存储于 ModelScope 缓存路径:
~/.cache/modelscope/hub/iic/cv_gpen_image-portrait-enhancement该目录包含以下关键组件:
- 生成器权重(Generator):负责高保真人脸细节重建
- 人脸检测模型:基于 RetinaFace 实现精准面部定位
- 关键点对齐模块:提升多角度人像处理鲁棒性
若首次运行未触发自动下载,请检查网络连接或手动验证缓存完整性。
4. 自定义数据集准备与格式规范
4.1 数据配对原则
GPEN 采用监督式训练策略,要求每条样本包含一对图像:
- 高质量图像(HR):清晰、无压缩失真、分辨率不低于目标尺寸
- 低质量图像(LR):对应 HR 图像经人工降质处理后的版本
推荐使用 FFHQ 或 CelebA-HQ 等公开高清人脸数据集作为原始 HR 数据源。
4.2 低质量图像生成方法
由于真实低质图像难以获取且缺乏精确配对关系,通常采用合成方式生成 LR 图像。推荐以下两种主流方案:
方法一:使用 RealESRGAN 进行退化增强
from basicsr.data.degradations import random_add_gaussian_noise, random_mixed_kernels import cv2 import numpy as np def degrade_image(hr_path, lr_save_path): img = cv2.imread(hr_path) # 添加模糊核 kernel = random_mixed_kernels( ['iso', 'aniso'], [0.7, 0.3], 4, 2, 0.5, [-0.5, 0.5], [-1, 1], noise_range=None ) img = cv2.filter2D(img, -1, kernel) # 添加噪声 img = random_add_gaussian_noise(img, sigma_range=[1, 30]) # 下采样模拟低分辨率 h, w = img.shape[:2] img = cv2.resize(img, (w//2, h//2), interpolation=cv2.INTER_LINEAR) img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR) cv2.imwrite(lr_save_path, img) # 示例调用 degrade_image('./hr_images/face_001.png', './lr_images/face_001.png')方法二:使用 BSRGAN 工具链批量生成
BSRGAN 提供完整的图像退化管道,支持多种模糊核、JPEG 压缩、颜色扰动等操作,适合大规模数据构建。
# 安装 BSRGAN pip install bsrgan # 批量生成示例(伪代码) for hr_img in hr_dataset: lr_img = bsrgan.degrade(hr_img, scale=1, quality_factor=30) save_pair(hr_img, lr_img)4.3 数据组织结构
建议按照如下目录结构组织训练数据:
datasets/ ├── train/ │ ├── hr/ │ │ └── img_001.png │ │ └── ... │ └── lr/ │ └── img_001.png │ └── ... └── val/ ├── hr/ └── lr/并在配置文件中明确指定dataroot_gt和dataroot_lq路径。
5. 微调训练全流程详解
5.1 训练脚本入口
进入代码目录后,使用train_gpen.py启动训练任务:
cd /root/GPEN python train_gpen.py --config configs/gpen_bilinear_512.py5.2 配置文件修改要点
以gpen_bilinear_512.py为例,需根据实际需求调整以下参数:
# 数据路径配置 'dataroot_gt': '/root/datasets/train/hr', # 高清图像路径 'dataroot_lq': '/root/datasets/train/lr', # 低清图像路径 'val_dataroot_gt': '/root/datasets/val/hr', 'val_dataroot_lq': '/root/datasets/val/lr', # 模型参数 'lq_size': 512, # 输入尺寸 'net_type': 'GPEN-Bilinear', # 可选:GPEN-Bilinear / GPEN-Deformable # 优化器设置 'lr_generator': 1e-4, # 生成器学习率 'lr_discriminator': 5e-5, # 判别器学习率 'total_iter': 100000, # 总迭代次数 # 日志与保存 'print_freq': 100, # 每N步打印loss 'save_checkpoint_freq': 5000, # 每N步保存一次模型 'path': { 'pretrain_network_g': None, # 若继续训练,填入预训练权重路径 }提示:若从头开始训练,可留空
pretrain_network_g;若进行微调,建议加载官方权重以加速收敛。
5.3 启动训练任务
CUDA_VISIBLE_DEVICES=0 python train_gpen.py --config configs/gpen_bilinear_512.py训练过程中日志将实时输出至终端,并记录在./experiments目录下的时间戳子文件夹中。
5.4 训练过程监控
系统会自动生成以下内容用于监控:
- Loss 曲线:保存在
experiments/exp_name/logs - 可视化中间结果:每
save_checkpoint_freq步保存一组对比图(LR vs Output vs GT) - Checkpoint 模型:
.pth格式权重文件,可用于后续推理或继续训练
6. 推理与评估最佳实践
6.1 使用微调后模型进行推理
将训练好的权重复制到推理目录,并更新inference_gpen.py中的模型路径:
model_path = './experiments/gpen_512_net_g.pth'然后执行:
python inference_gpen.py --input ./test_low_quality.jpg --output ./restored_face.png6.2 客观指标评估
使用evaluate.py脚本计算 PSNR 和 LPIPS 指标:
python evaluate.py \ --gt_folder /root/datasets/val/hr \ --sr_folder /root/datasets/val/output \ --metric psnr,lpips- PSNR:反映像素级重建精度
- LPIPS:感知相似度,越低表示视觉效果越接近原图
7. 常见问题与解决方案
7.1 OOM(显存不足)问题
现象:训练时报错CUDA out of memory
解决方法:
- 降低
batch_size至 1 或 2 - 使用
--fp16开启混合精度训练(需确认 CUDA 支持) - 减小输入分辨率(如从 512→256)
7.2 图像边缘伪影严重
原因:边界填充方式不当或退化模式不匹配
对策:
- 在数据预处理阶段增加随机裁剪扰动
- 调整生成器中的归一化层类型(InstanceNorm → BatchNorm)
- 引入边缘感知损失函数(Edge Loss)
7.3 模型过拟合
表现:训练 Loss 持续下降但验证集效果变差
缓解措施:
- 增加数据增强强度(颜色抖动、随机擦除)
- 提前停止(Early Stopping)
- 使用更小的学习率微调最后几万步
8. 总结
本文详细介绍了如何基于预置 GPEN 镜像完成从环境配置、数据准备、模型微调到推理评估的完整流程。核心要点包括:
- 环境即用性:镜像已集成 PyTorch 2.5 + CUDA 12.4 全套依赖,省去繁琐安装过程。
- 数据配对是关键:必须构建高质量的 HR-LR 成对数据集,推荐使用 RealESRGAN 或 BSRGAN 合成退化样本。
- 配置灵活可调:通过修改
.py配置文件即可控制训练行为,支持学习率、尺寸、迭代数等全面定制。 - 训练稳定高效:结合日志与可视化监控,能够及时发现并解决常见问题如 OOM、伪影、过拟合等。
掌握上述流程后,开发者可快速将 GPEN 应用于特定场景的人像修复任务,如老照片复原、视频画质增强、移动端美颜等。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。