【深度学习基础】:VGG实战篇(图像风格迁移)

文章目录

  • 前言
  • style transfer原理
    • 原理解析
    • 损失函数
  • style transfer代码
    • 效果图
  • fast style transfer 代码
    • 效果图

前言

本篇来带大家看看VGG的实战篇,这次来带大家看看计算机视觉中一个有趣的小任务,图像风格迁移。

可运行代码位于: Style_transfer (可直接下载运行)

我们所进行的图像风格迁移是基于VGG实现的哦,如果对VGG网络还不太了解的,可以先去看看这边文章:【深度学习基础】:VGG原理篇-CSDN博客

image-20250502122534540

style transfer原理

原理解析

图像风格迁移是一种通过将一张图像的风格应用到另一张图像的内容上,从而生成一张新图像的技术。这种技术通常基于深度学习模型,特别是卷积神经网络(CNN)。

实现步骤

  • 首先,我们需要两张图像:内容图像和风格图像。内容图像是我们想要保持内容的图像,而风格图像是我们想要应用到内容图像上的风格的图像。接下来,我们通过使用预训练的卷积神经网络(如VGG19)来提取图像的特征。
  • 在提取特征之后,我们计算内容损失和风格损失。内容损失度量生成图像与内容图像在内容上的差异,通常通过比较两个图像在浅层特征上的差异来计算。风格损失度量生成图像与风格图像在风格上的差异,风格通常通过计算图像特征之间的相关性来表示,即风格图像的特征图之间的gram矩阵。(这个等会会详细讲述)

其实图像风格迁移的核心是一个优化问题,目标是生成一张图像,使得内容损失和风格损失都最小化。我们初始化一张随机生成的图像。然后,我们使用卷积神经网络提取这张图像的特征,并计算内容损失和风格损失。接着,我们计算总损失,并进行反向传播和优化,以更新生成图像的像素值。这个过程会重复进行多次,直到生成图像在内容和风格上都接近目标图像。

损失函数

好,其实大概的原理基本上都清楚了,但是我们来看,我们是该如何去保证图片的内容是我们所选的内容图片,但是风格确实我们所选的风格图片呢?

首先来看内容函数的损失函数,其是这个还是很好理解的,我们所使用的就是平方差损失函数,我们希望我们的生成图片能够和内容图片具有相同的内容,所以最小化两者的差距即可。但是这里有个细节,我们该选择那个层的输出作为我们的标准呢?其实看最上面的图也能看出,我们的网络越深,提取到的内容图片损失的细节就会越多,所以我们会更多的选择浅层的内容图片特征作为标准。

image-20250502125827494

def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor:"""计算内容损失,即生成特征图与标准特征图的规范化误差平方和"""b, c, h, w = original_feat.shapex = 2. * c * h * w  # 规范化系数return torch.sum((generated_feat - original_feat) ** 2) / x

然后我们在来看风格损失函数,其实风格损失函数的关键在于gram矩阵。那么我们回到最初,我们为什么能够提取出来图像的风格特征呢?其实就是依靠Gram矩阵了,其用于捕捉图像中不同通道之间的相关性,这些相关性可以反映出图像的纹理、颜色分布等风格特征。通俗说就是出现尖尖的形状的时候边上应该出现些什么。

image-20250502125816341

然后风格损失函数如下,同样的我们会选择深层的风格图像特征作为我们的标准,因为深层网络提取出来的更多的是抽象的语义信息,能够代表其本质的特征,抽象的特征。

深度学习项目二: 图像的风格迁移和图像的快速风格迁移 (含数据和所需源码)_牛客博客

def gram_matrix(feature):"""计算特征图的格莱姆矩阵,作为风格特征表示:param feature: 输入特征图:return: 格莱姆矩阵"""b, c, h, w = feature.size()feature = feature.view(b, c, h * w)  # 拉平空间维度G = torch.bmm(feature, feature.transpose(1, 2))return Gdef calculate_style_loss(style_feat, generated_feat) -> torch.Tensor:"""计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和"""b, c, h, w = style_feat.shapeG = gram_matrix(generated_feat)A = gram_matrix(style_feat)x = 4. * ((h * w) ** 2) * (c ** 2)  # 规范化系数return torch.sum((G - A) ** 2) / x

所以最终的损失函数就是风格损失和内容损失的和,其中会乘上不同的系数。

image-20250502125754226

style transfer代码

这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。

我们通过处理目标图片,其一开始就是一张白噪声图,然后通过网络损失不断修改其像素,使得目标图片内容与内容图片相近,风格与风格图片相近。

image-20250502133002891

import argparse
import os
import torch
import torch.optim as optim
from tqdm import tqdmfrom models import VGG
from utils import make_transform, load_image, save_image, calculate_style_loss, calculate_content_lossdef parse_args():parser = argparse.ArgumentParser(description='Style Transfer=')parser.add_argument('--content_image', type=str, default='./data/content1.jpg', help='Path to the content image')parser.add_argument('--style_image', type=str, default='./data/style2.jpg', help='Path to the style image')parser.add_argument('--output_dir', type=str, default='./output/iterative_style_transfer', help='Output directory')parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')parser.add_argument('--content_weight', type=float, default=1, help='Content weight')parser.add_argument('--style_weight', type=float, default=15, help='Style weight')parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')parser.add_argument('--steps_per_epoch', type=int, default=100, help='Number of steps per epoch')parser.add_argument('--learning_rate', type=float, default=0.03, help='Learning rate')return parser.parse_args()if __name__ == "__main__":args = parse_args()# 检查文件路径assert os.path.exists(args.content_image), f"content image is not exist: {args.content_image}"assert os.path.exists(args.style_image), f"style image is not exist: {args.style_image}"# 内容特征层及loss加权系数content_layers = {'5': 0.5, '10': 0.5}# 风格特征层及loss加权系数style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}# ----------------训练即推理过程----------------device = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = make_transform(args.image_size, normalize=True)content_img = load_image(args.content_image, transform).to(device)style_img = load_image(args.style_image, transform).to(device)# 用小随机噪声初始化图像(避免标准正态分布过大震荡)generated_img = torch.rand_like(content_img).mul(0.1).requires_grad_().to(device)save_image(generated_img, args.output_dir, 'noise_init.jpg')vgg_model = VGG(content_layers, style_layers).to(device).eval()# 关闭梯度追踪,节省显存with torch.no_grad():content_features, _ = vgg_model(content_img)_, style_features = vgg_model(style_img)optimizer = optim.Adam([generated_img], lr=args.learning_rate)for epoch in range(args.epochs):p_bar = tqdm(range(args.steps_per_epoch), desc=f'Epoch {epoch + 1}/{args.epochs}')for step in p_bar:generated_content, generated_style = vgg_model(generated_img)content_loss = sum(args.content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content)for name, gen_content in generated_content.items())style_loss = sum(args.style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)for name, gen_style in generated_style.items())total_loss = style_loss + content_lossoptimizer.zero_grad()total_loss.backward()# 梯度裁剪(可选但稳健)torch.nn.utils.clip_grad_norm_([generated_img], max_norm=1.0)optimizer.step()p_bar.set_postfix(style_loss=style_loss.item(), content_loss=content_loss.item(), total_loss=total_loss.item())# 保存中间结果save_image(generated_img, args.output_dir, f'generated_epoch_{epoch + 1}.jpg')

效果图

89_8

fast style transfer 代码

刚才我们看了基于迭代的图像风格迁移,那么能不能有个方法,我们训练一个专用于一种风格的图像迁移的网络,其网络训练好之后,我们就不再需要老是输入风格图像,这个训练好的网络可以处理所有输入的图像输出其已经训练好的风格。所以fast style transfer 来了。其不仅可以处理图片,同样可以处理视频哦!

image-20250502132529144

这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。

import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterablefrom PIL import Imageimport torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as npfrom models import VGG, TransNet
from datasets import ImageDataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \denormalizedef parse_args():parser = argparse.ArgumentParser(description='Fast Style Transfer')parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='train',help='Operation mode: train, image, video')parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')parser.add_argument('--style_image', type=str, default='./data/style3.jpg', help='Style image path (training mode only)')# The directory data/train2017/default_class contains several images. Due to the format required by ImageFolder, we need to use a class folder to contain the images, even though the class label is not used.parser.add_argument('--content_dataset', type=str, default='data/train2014', help='Content image dataset path (training mode only)')parser.add_argument('--content_weight', type=float, default=1., help='Content weight (training mode only)')parser.add_argument('--style_weight', type=float, default=15., help='Style weight (training mode only)')parser.add_argument('--model_save_path', type=str, default=f'./checkpoint/style2.pth', help='Model save path (training mode)')parser.add_argument('--pretrained_model_path', type=str, help='Pretrained model load path (training mode only)')parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (training mode)')parser.add_argument('--save_interval', type=int, default=120, help='Save interval (seconds) (training mode)')parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (training mode)')parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='Output directory (training mode)')parser.add_argument('--model_path', type=str, default='./checkpoint/style1.pth', help='Model load path (image and video mode)')parser.add_argument('--input_images_dir', type=str, help='Input images root directory (image mode only)')parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg',help='Output images root directory (image mode only)')parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='Input video path (video mode only)')parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',help='Output video path (video mode only)')parser.add_argument('--batch_size', type=int, default=4, help='Batch size')return parser.parse_args()def train(model,vgg,lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,device, transform,image_style,content_dataset_root,save_path, output_dir,save_interval=timedelta(seconds=120)):optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # 对目标网络进行优化dataset = ImageDataset(content_dataset_root, transform=transform)dataloader = DataLoader(dataset, batch_size, shuffle=True)_, style_features = vgg(image_style)p_bar = tqdm(range(epochs))last_save_time = datetime.now() - save_intervalfor epoch in p_bar:running_content_loss, running_style_loss = 0.0, 0.0for i, content_img in enumerate(dataloader):content_img = content_img.to(device)image_generated = model(content_img)  # 只使用内容图像进行风格迁移generated_content, generated_style = vgg(image_generated)style_loss = sum(style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style) forname, gen_style in generated_style.items())content_features, _ = vgg(content_img)  # 计算内容图的内容特征content_loss = sum(content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content) forname, gen_content in generated_content.items())total_loss = style_loss + content_lossoptimizer.zero_grad()total_loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)  # 梯度裁剪optimizer.step()running_content_loss += content_loss.item()running_style_loss += style_loss.item()p_bar.set_postfix(progress=f'{(i + 1) / len(dataloader) * 100:.3f}%',style_loss=f"{style_loss.item():.3f}",content_loss=f"{content_loss.item():.3f}",last_save_time=last_save_time)if datetime.now() - last_save_time > save_interval:last_save_time = datetime.now()#writer.add_images('image_generated', denormalize(image_generated), epoch * len(dataloader) + i)save_model(model, save_path)  # 'fast_style_transfer.pth'save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]:images = torch.stack([transform(image) for image in images]).to(device)model.to(device)batch_generated = model(images)batch_generated = denormalize(batch_generated).detach().cpu()batch_generated = [transforms.ToPILImage()(image) for image in batch_generated]return batch_generateddef process_video(video_path, output_path, transform, model, device, batch_size=4):# 打开视频文件cap = cv2.VideoCapture(video_path)output_dir, filename = os.path.split(output_path)if output_dir and not os.path.exists(output_dir):os.makedirs(output_dir)# 获取视频属性frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))fps = cap.get(cv2.CAP_PROP_FPS)total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))# 定义视频编码器和创建 VideoWriter 对象fourcc = cv2.VideoWriter_fourcc(*'mp4v')out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))# 初始化 tqdm 进度条pbar = tqdm(total=total_frames, desc="Processing Video")# 读取视频并批量处理frames = []while True:ret, frame = cap.read()if not ret:breakframes.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))if len(frames) == batch_size:batch_generated = process_images(frames, transform, model, device)for gen_frame in batch_generated:gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)cv2.imshow("output", gen)gen = cv2.resize(gen, (frame_width, frame_height))out.write(gen)if cv2.waitKey(1) & 0xFF == ord('s'):breakframes.clear()pbar.update(batch_size)if frames:batch_generated = process_images(frames, transform, model, device)for gen_frame in batch_generated:out.write(cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR))pbar.update(len(frames))print(f'video successfully saved to: {output_path}')pbar.close()cap.release()out.release()if __name__ == '__main__':args = parse_args()# print(args)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# ----------------路径参数----------------# 内容特征层及loss加权系数content_layers = {'5': 0.5, '10': 0.5}  # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性# 风格特征层及loss加权系数style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}  # 使用vgg不同深度的风格特征,生成风格更加层次丰富transform = make_transform(size=args.image_size, normalize=True)  # 图像变换image_style = load_image(args.style_image, transform=transform).to(device)  # 风格图像vgg = VGG(content_layers, style_layers).to(device)  # 特征提取网络,只用来提取特征,不进行训练model = TransNet(input_size=args.image_size).to(device)  # 内容生成网络,用于生成风格图片,进行训练if args.mode != 'train' and getattr(args, 'model_path'):if not os.path.exists(args.model_path):raise FileNotFoundError(f'{args.model_path}不存在!')model.load_state_dict(torch.load(args.model_path))elif args.mode == 'train' and getattr(args, 'pretrained_model_path') and os.path.exists(args.pretrained_model_path):if not os.path.exists(args.pretrained_model_path):raise FileNotFoundError(f'{args.pretrained_model_path}不存在!')model.load_state_dict(torch.load(args.pretrained_model_path))if args.mode == 'train':# 训练模式# 使用大规模内容图像数据训练快速图像风格迁移网络,比如COCO2017数据集train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,style_layers,content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,args.output_dir,save_interval=timedelta(seconds=args.save_interval))elif args.mode == 'image':# 使用训练好的风格迁移模型演示批量处理图片if not os.path.exists(args.output_images_dir):os.makedirs(args.output_images_dir)for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'):try:filepath = os.path.join(args.input_images_dir, filename)images_generated = process_images([Image.open(filepath)], transform, model, device)images_generated[0].save(os.path.join(args.output_images_dir, filename))except Exception as e:passelif args.mode == 'video':# 视频处理模式process_video(args.video_input, args.video_output, transform, model, device,batch_size=args.batch_size)else:raise ValueError("未知的运行模式")

效果图

tutieshi_640x480_3s

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78321.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python爬虫基础:requests库详解与案例

1.Requests模块的使用 requests模块的介绍与安装 作用:发送网络请求,返回响应数据。 中文文档:https://requests.readthedocs.io/projects/cn/zh_CN/latest/ 对于爬虫任务,使用 requests模块基本能够解决绝大部分的数据抓取的…

Spring 容器相关的核心注解​

以下是 Spring 容器中用于 ​​Bean 管理、依赖注入、配置控制​​ 的关键注解,按功能分类说明: ​​1. Bean 声明与注册​​ 注解作用示例​​Component​​通用注解,标记一个类为 Spring Bean(自动扫描注册) Compo…

C与指针5——字符串合集

常用函数 1、拷贝、长度、比较 size_t strlen();\\返回无符号整形 char* strcpy();char* strncpy();\\拷贝 int strcmp();int strncmp();\\比较 char* strcat();char* strncat();\\连接2、查找 char* strchr(const char * st,int ch);\\找字符第一次出现的位置 char* strrch…

论软件需求管理

目录 摘要(300~330字) 正文(2000~2500字,2200字为宜) 背景介绍(500字做左右) 论点论据(1500字做左右) 收尾(200字左右) 注:本篇论…

[特殊字符] 如何在比赛前调整到最佳状态:科学与策略结合的优化指

🧠 概述 在竞技体育中,赛前状态的调整对比赛结果起着决定性作用。所谓“最佳状态”,不仅指生理上的巅峰表现,更包括心理、认知、营养和恢复等多方面的协同优化。本文结合运动科学、心理学和营养学的研究成果,探讨赛前…

一种实波束前视扫描雷达目标二维定位方法——论文阅读

一种实波束前视扫描雷达目标二维定位方法 1. 专利的研究目标与实际问题意义2. 专利提出的新方法、模型与公式2.1 运动平台几何建模与回波信号构建2.1.1 距离历史建模2.1.2 回波信号模型2.2 距离向运动补偿技术2.2.1 匹配滤波与距离压缩2.3 加权最小二乘目标函数2.3.1 方位向信号…

基于 Spring Boot 瑞吉外卖系统开发(八)

基于 Spring Boot 瑞吉外卖系统开发(八) 自动填充公共字段 MyBatis-Plus公共字段自动填充,也就是在插入或者更新的时候为指定字段赋予指定的值,使用它的好处就是可以统一对这些字段进行处理,降低了冗余代码的数量。本…

【前端】从零开始的搭建结构(技术栈:Node.js + Express + MongoDB + React)book-management

项目路径总结 后端结构 server/ ├── controllers/ # 业务逻辑 │ ├── authController.js │ ├── bookController.js │ ├── genreController.js │ └── userController.js ├── middleware/ # 中间件 │ ├── authMiddleware…

【RAG】向量?知识库的底层原理:向量数据库の技术鉴赏 | HNSW(导航小世界)、LSH、K-means

一、向量化表示的核心概念 1.1 特征空间与向量表示 多维特征表示:通过多个特征维度(如体型、毛发长度、鼻子长短等)描述对象,每个对象对应高维空间中的一个坐标点,来表示狗这个对象,这样可以区分出不同种…

如何用CSS实现HTML元素的旋转效果

原文:如何用CSS实现HTML元素的旋转效果 | w3cschool笔记 (本文为科普文章,请勿标记为付费) 在网页制作中,为 HTML 元素设置旋转效果可使其更灵动,提升用户体验。本文将深入浅出地介绍如何利用 CSS 实现 H…

Spark集群搭建之Yarn模式

配置集群 1.上传并解压spark-3.1.2-bin-hadoop3.2.tgz,重命名解压之后的目录为spark-yarn。 2. 修改一下spark的环境变量,/etc/profile.d/my_env.sh 。 # spark 环境变量 export SPARK_HOME/opt/module/spark-yarn export PATH$PATH:$SPARK_HOME/bin:$SP…

xLua笔记

Generate Code干了什么 肉眼可见的,在Asset文件夹生成了XLua/Gen文件夹,里面有一些脚本。然后对加了[CSharpCallLua]的变量寻找引用,发现它被XLua/Gen/DelegatesGensBridge引用了。也可以在这里查哪些类型加了[CSharpCallLua]。 public over…

【tcp连接windows redis】

tcp连接windows redis 修改redis.conf 修改redis.conf bind * -::*表示禁用保护模式,允许外部网络连接 protected-mode no

【序列贪心】摆动序列 / 最长递增子序列 / 递增的三元子序列 / 最长连续递增序列

⭐️个人主页:小羊 ⭐️所属专栏:贪心算法 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 摆动序列最长递增子序列递增的三元子序列最长连续递增序列 摆动序列 摆动序列 贪心策略:统计出所有的极大值和极小…

STM32F103C8T6使用MLX90614模块

首先说明: 1.SMBus和I2C的区别 我曾尝试用江科大的I2C底层去直接读取该模块,但是无法成功,之后AI生成的的代码也无法成功。 思来想去最大的可能就是SMBus这个协议的问题,根据百度得到的结果如下: SMBus和I2C的区别 链…

tp5 php获取农历年月日干支甲午

# 切换为国内镜像源 composer config -g repo.packagist composer https://mirrors.aliyun.com/composer/# 再次尝试安装 composer require overtrue/chinese-calendar核心写法一个农历转公历,一个公历转农历 农历闰月可能被错误标记(例如 闰四月 应表示…

Ubuntu搭建Conda+Python开发环境

目录 一、环境说明 1、测试环境为ubuntu24.04.1 2、更新系统环境 3、安装wget工具 4、下载miniconda安装脚本 二、安装步骤 1、安装miniconda 2、source conda 3、验证版本 4、配置pip源 三、conda用法 1、常用指令 一、环境说明 1、测试环境为ubuntu24.04.1 2、更…

Vscode+git笔记

1.U是untracked m是modify modified修改了的。 2.check out 查看观察 3 status changed 暂存区 4.fetch v 取来拿来 5.orangion 起源代表远程分支 git checkout就是可以理解为进入的意思。

模拟SIP终端向Freeswitch注册用户

1、简介 使用go语言编写一个程序,模拟SIP-T58终端在Freeswitch上注册用户 2、思路 以客户端向服务端Freeswitch发起REGISTER请求,告知服务器当前的联系地址构造SIP REGISTER请求 创建UDP连接,连接到Freeswitch的5060端口发送初始的REGISTER请…

DeepSeek实战--LLM微调

1.为什么是微调 ? 微调LLM(Fine-tuning Large Language Models) 是指基于预训练好的大型语言模型(如GPT、LLaMA、PaLM等),通过特定领域或任务的数据进一步训练,使其适应具体需求的过程。它是将…