import os # 标准库:操作系统相关(本文件中未直接使用)
import torch # PyTorch 主库
from pathlib import Path # 处理路径
from collections import OrderedDict # 有序字典:用于按固定顺序组织可视化/损失项
from abc import ABC, abstractmethod # 抽象基类支持
from . import networks # 网络与训练调度相关工具(调度器、网络构造等)class BaseModel(ABC):"""This class is an abstract base class (ABC) for models.To create a subclass, you need to implement the following five functions:-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).-- <set_input>: unpack data from dataset and apply preprocessing.-- <forward>: produce intermediate results.-- <optimize_parameters>: calculate losses, gradients, and update network weights.-- <modify_commandline_options>: (optionally) add model-specific options and set default options."""# ↑ 英文文档字符串:说明该类是所有模型的抽象基类;子类必须实现的 5 个关键方法def __init__(self, opt):"""Initialize the BaseModel class.Parameters:opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptionsWhen creating your custom class, you need to implement your own initialization.In this function, you should first call <BaseModel.__init__>(self, opt)Then, you need to define four lists:-- self.loss_names (str list): specify the training losses that you want to plot and save.-- self.model_names (str list): define networks used in our training.-- self.visual_names (str list): specify the images that you want to display and save.-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example."""self.opt = opt # 保存全局/命令行配置self.gpu_ids = opt.gpu_ids # 设备 id 列表self.isTrain = opt.isTrain # 训练/测试标志# 根据 gpu_ids 是否为空来选择设备;若为空则用 CPU,否则用第一个 GPUself.device = torch.device("cuda:{}".format(self.gpu_ids[0])) if self.gpu_ids else torch.device("cpu") # get device name: CPU or GPU# 检查点目录:<checkpoints_dir>/<experiment_name>self.save_dir = Path(opt.checkpoints_dir) / opt.name # save all the checkpoints to save_dir# 为了提升 cudnn 搜索最优卷积算法的速度:当不是 scale_width 预处理时开启 benchmarkif (opt.preprocess != "scale_width"
# 当预处理方式不是"scale_width"时,开启该模式,以加速后续的卷积等操作。): # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.torch.backends.cudnn.benchmark = True
# cudnn是 NVIDIA 推出的针对深度学习的 GPU 加速库,专门优化卷积、池化等常用操作。
# torch.backends.cudnn.benchmark = True表示开启cudnn 的基准测试模式:
# 开启后,程序会在首次运行时对当前硬件上可用的卷积算法(如不同的卷积实现方式)进行一次 “基准测试”(耗时很短),
# 找到当前输入尺寸下最优的算法。
# 之后的卷积操作会固定使用这个最优算法,避免每次运行时重新选择算法的开销,从而提升整体性能# 下面这些列表由子类在 __init__ 中填充,用于日志记录、保存、展示及优化器管理self.loss_names = [] # 需要记录/可视化/保存的损失项名(不含前缀 'loss_')self.model_names = [] # 需要管理/保存/加载的网络名称后缀(比如 ['G','D'])self.visual_names = [] # 需要可视化/保存成网页的图像张量名称self.optimizers = [] # 优化器列表(通常与 model_names 对应)self.image_paths = [] # 当前 batch 对应的图像路径列表(供日志/可视化使用)self.metric = 0 # 学习率策略 'plateau' 的监控指标@staticmethoddef modify_commandline_options(parser, is_train):"""Add new model-specific options, and rewrite default values for existing options.Parameters:parser -- original option parseris_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.Returns:the modified parser."""# 子类可重写该静态方法,往 argparse 中添加模型特定参数或覆盖默认值;默认不做修改return parser@abstractmethoddef set_input(self, input):"""Unpack input data from the dataloader and perform necessary pre-processing steps.Parameters:input (dict): includes the data itself and its metadata information."""# 抽象方法:从数据加载器中取出一个 batch,并完成必要的预处理/搬运到 self.devicepass@abstractmethoddef forward(self):"""Run forward pass; called by both functions <optimize_parameters> and <test>."""# 抽象方法:前向计算,用于训练与测试阶段pass@abstractmethoddef optimize_parameters(self):"""Calculate losses, gradients, and update network weights; called in every training iteration"""# 抽象方法:一次训练迭代内的完整优化步骤(计算损失→反传→更新参数)passdef setup(self, opt):"""Load and print networks; create schedulersParameters:opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions"""# 训练阶段:基于当前优化器创建学习率调度器(可能是 step/plateau/cosine 等)if self.isTrain:self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]# 测试阶段或断点续训:根据 epoch / iter 载入已保存的网络权重if not self.isTrain or opt.continue_train:load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch # 加载优先按指定 iter,否则按 epochself.load_networks(load_suffix)# 打印网络参数规模与(可选)结构self.print_networks(opt.verbose)# 若环境支持(PyTorch 2.0+),可选地对网络应用 torch.compile 以获得运行时优化if hasattr(torch, "compile"):self.compile_networks()def eval(self):"""Make models eval mode during test time"""# 测试时将各网络切换到 eval()(影响 BN/Dropout 等行为)for name in self.model_names:if isinstance(name, str):net = getattr(self, "net" + name)net.eval()def test(self):"""Forward function used in test time.This function wraps <forward> function in no_grad() so we don't save intermediate steps for backpropIt also calls <compute_visuals> to produce additional visualization results"""# 测试前向:禁用梯度,避免保存中间结果,随后计算可视化用的额外输出with torch.no_grad():self.forward()self.compute_visuals()def compute_visuals(self):"""Calculate additional output images for visdom and HTML visualization"""# 钩子:由子类实现,生成额外的可视化图像(例如:中间特征/重建结果)passdef get_image_paths(self):"""Return image paths that are used to load current data"""# 返回当前 batch 的原始图像路径(通常由 set_input() 填充)return self.image_pathsdef update_learning_rate(self):"""Update learning rates for all the networks; called at the end of every epoch"""# 记录更新前的学习率(读取第一个优化器的第一个 param_group)old_lr = self.optimizers[0].param_groups[0]["lr"]# 按策略推进调度器:'plateau' 需要传入监控指标;其它策略直接 step()for scheduler in self.schedulers:if self.opt.lr_policy == "plateau":scheduler.step(self.metric)else:scheduler.step()# 打印学习率变化lr = self.optimizers[0].param_groups[0]["lr"]print("learning rate %.7f -> %.7f" % (old_lr, lr))def get_current_visuals(self):"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""# 将 visual_names 中列出的张量取出来,按名称组织成有序字典(便于日志和网页展示)visual_ret = OrderedDict()for name in self.visual_names:if isinstance(name, str):visual_ret[name] = getattr(self, name)return visual_retdef get_current_losses(self):"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""# 将 loss_names 中列出的损失张量(属性名为 'loss_'+name)读取并转成 floaterrors_ret = OrderedDict()for name in self.loss_names:if isinstance(name, str):errors_ret[name] = float(getattr(self, "loss_" + name)) # float(...) works for both scalar tensor and float numberreturn errors_retdef save_networks(self, epoch):"""Save all the networks to the disk.Parameters:epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)"""# 依次保存 model_names 中的所有网络到 <save_dir>/<epoch>_net_<name>.pthfor name in self.model_names:if isinstance(name, str):save_filename = f"{epoch}_net_{name}.pth"save_path = self.save_dir / save_filenamenet = getattr(self, "net" + name)if len(self.gpu_ids) > 0 and torch.cuda.is_available():# DataParallel/DistributedDataParallel 的 net 可能包在 .module 里;# 保存时先移到 CPU 再保存,随后将网络移回 GPU(首个 id)torch.save(net.module.cpu().state_dict(), save_path)net.cuda(self.gpu_ids[0])else:torch.save(net.cpu().state_dict(), save_path)def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""# 递归地修补老版本(0.4 之前)InstanceNorm 的 state_dict 键/缓冲区不兼容问题key = keys[i]if i + 1 == len(keys): # at the end, pointing to a parameter/buffer# 如果是 InstanceNorm 且键是 running_mean / running_var,而对应属性为 None,则从 state_dict 移除该键if module.__class__.__name__.startswith("InstanceNorm") and (key == "running_mean" or key == "running_var"):if getattr(module, key) is None:state_dict.pop(".".join(keys))# 同理,移除 num_batches_tracked(老版不包含)if module.__class__.__name__.startswith("InstanceNorm") and (key == "num_batches_tracked"):state_dict.pop(".".join(keys))else:# 递归深入下一级模块self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)def load_networks(self, epoch):"""Load all the networks from the disk.Parameters:epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)"""# 依次加载各网络:从 <save_dir>/<epoch>_net_<name>.pth 读入 state_dict 并载入到模型for name in self.model_names:if isinstance(name, str):load_filename = f"{epoch}_net_{name}.pth"load_path = self.save_dir / load_filenamenet = getattr(self, "net" + name)# 若包裹在 DataParallel 中,先取出实际模块if isinstance(net, torch.nn.DataParallel):net = net.moduleprint(f"loading the model from {load_path}")# 注意:map_location 使用设备字符串;weights_only=True 仅加载权重张量(更安全)state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)# 某些保存格式会带 _metadata,用不到则删除避免干扰if hasattr(state_dict, "_metadata"):del state_dict._metadata# 老版本 InstanceNorm 的兼容性修补(在遍历中会修改字典,故先拷贝键列表)for key in list(state_dict.keys()): # need to copy keys here because we mutate in loopself.__patch_instance_norm_state_dict(state_dict, net, key.split("."))net.load_state_dict(state_dict)def compile_networks(self, **compile_kwargs):"""Apply torch.compile to all networks for optimization.Parameters:**compile_kwargs -- keyword arguments to pass to torch.compile(e.g., mode='reduce-overhead', backend='inductor')"""# 对每个网络应用 torch.compile(如 PyTorch 2.0 的 AOT 编译),可显著优化推理/训练性能for name in self.model_names:if isinstance(name, str):net = getattr(self, "net" + name)compiled_net = torch.compile(net, **compile_kwargs)setattr(self, "net" + name, compiled_net) # 将编译后的网络回写到实例属性print(f"[Network {name}] compiled with torch.compile")setattr(self, "net" + name, compiled_net) # 再次回写(功能上与上一行重复,虽无害但属冗余)def print_networks(self, verbose):"""Print the total number of parameters in the network and (if verbose) network architectureParameters:verbose (bool) -- if verbose: print the network architecture"""# 打印参数量统计;若 verbose=True 也打印完整网络结构print("---------- Networks initialized -------------")for name in self.model_names:if isinstance(name, str):net = getattr(self, "net" + name)num_params = 0for param in net.parameters():num_params += param.numel() # 累加所有参数张量的元素个数if verbose:print(net)print("[Network %s] Total number of parameters : %.3f M" % (name, num_params / 1e6))print("-----------------------------------------------")def set_requires_grad(self, nets, requires_grad=False):"""Set requies_grad=Fasle for all the networks to avoid unnecessary computationsParameters:nets (network list) -- a list of networksrequires_grad (bool) -- whether the networks require gradients or not"""# 将一个网络或网络列表统一设置 requires_grad 标志(常用于冻结判别器/特征提取器以节省计算)if not isinstance(nets, list):nets = [nets]for net in nets:if net is not None:for param in net.parameters():param.requires_grad = requires_grad
总结:
BaseModel.__init__()
└── 初始化成员变量(无内部方法调用)BaseModel.setup()
├── self.load_networks() # 加载网络权重
├── self.print_networks() # 打印网络信息
└── self.compile_networks() # (条件触发)编译网络(若torch.compile可用)BaseModel.test()
├── self.forward() # 调用前向传播(抽象方法,子类实现)
└── self.compute_visuals() # 计算可视化结果(默认空实现)BaseModel.update_learning_rate()
└── scheduler.step() # 调用优化器调度器的step方法(依赖networks模块的scheduler)BaseModel.save_networks()
└── torch.save() # 保存网络状态字典BaseModel.load_networks()
├── torch.load() # 加载网络状态字典
├── self.__patch_instance_norm_state_dict() # 修复InstanceNorm兼容性问题
└── net.load_state_dict() # 加载状态字典到网络BaseModel.compile_networks()
└── torch.compile() # 编译网络(PyTorch 2.0+特性)BaseModel.print_networks()
└── 遍历网络参数(无内部方法调用,打印信息)BaseModel.set_requires_grad()
└── 直接操作网络参数(无内部方法调用)
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/915962.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!相关文章
企业网站设计与实现论文移动网站系统
听说这是目录哦 FinalShell连接VMware🌤️解决重连失效FinalShell的使用 免密登录⛈️能量站😚 FinalShell连接VMware🌤️
保持虚拟机的开机状态,打开FinalShell,如果虚拟机关机或者挂起,连接就会断开。 …
做网站时图片要切片有什么作用可以做砍价链接的网站
车牌识别系统
YOLOv5和LPRNet的车牌识别系统结合了深度学习技术的先进车牌识别解决方案。该系统整合了YOLOv5目标检测框架和LPRNet文本识别模型
1. YOLOv5目标检测框架
YOLO是一种先进的目标检测算法,以其实时性能和高精度闻名。YOLOv5是在前几代基础上进行优化的…
南昌网站建设规划方案传媒公司网站源码php
引人入胜的开篇:想要搞清楚LSTM中的每个公式的每个细节为什么是这样子设计吗?想知道simple RNN是如何一步步的走向了LSTM吗?觉得LSTM的工作机制看不透?恭喜你打开了正确的文章! 前方核弹级高能预警!本文信息…
微信版网站开发用安卓做网站
幸福树,一种寓意美好的观赏型植物,它生长非常迅速,稍不注意就长的非常茂盛。而要想保证幸福树的美貌,跟人的头发一样,我们要给它适当的修剪,那幸福树怎么修剪呢?为了大家能养出美丽的幸福树来&a…
HarmonyOS后台任务调度:JobScheduler与WorkManager实战指南
本文将深入探讨HarmonyOS 5(API 12)中的后台任务调度机制,重点讲解JobScheduler和WorkManager的使用方法、适用场景及最佳实践,帮助开发者实现高效、智能的后台任务管理。1. 后台任务调度概述
HarmonyOS提供了两种…
学校站群框架如何开发插件实现Word图片的批量上传与编辑?
pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …
化妆品公司网站源码wordpress ip锁定插件
在React的类组件中,从组件创建到组件被挂载到页面中,这个过程react存在一系列的生命周期函数,最主要的生命周期函数是componentDidMount、componentDidUpdate、componentWillUnmount
生命周期图例如下 1. componentDidMount组件挂载
如果你…
怎样改网站英文域名保定定兴网站建设
来源:新战略机器人为什么需要协作机器人?协作机器人的兴起意味着传统机器人必然有某种程度的不足,或者无法适应新的市场需求。总结一下,主要有几点:传统机器人部署成本高其实相对来讲,工业机器人本身的价格…
广西工程造价信息网佛山seo优化排名推广
1、先登录服务器创建新目录aaa
2、云盘都快照备份下。后续操作完核实无误了,您根据您需求删除快照就行, 然后登录服务器内执行:
fdisk -l
sblk
blkid
ll /aaa
3、执行:(以下命令是进行数据盘做ext4文件系统并挂载…
HarmonyOS事件订阅与通知:后台事件处理
本文将深入探讨HarmonyOS 5(API 12)中的事件订阅与通知机制,重点讲解如何在后台处理事件,实现应用的实时响应和跨设备协同。内容涵盖核心API、实现步骤、实战示例及性能优化建议。1. 事件订阅与通知机制概述
Harmo…
HarmonyOS后台任务管理:短时与长时任务实战指南
本文将深入探讨HarmonyOS 5(API 12)中的后台任务管理机制,详细讲解短时任务和长时任务的适用场景、实现方法、性能优化及最佳实践,帮助开发者构建高效节能的后台任务系统。1. 后台任务概述与分类
HarmonyOS提供了完…
Kali Linux 2025.3 发布 (Vagrant Nexmon) - 领先的渗透测试发行版
Kali Linux 2025.3 发布 (Vagrant & Nexmon) - 领先的渗透测试发行版Kali Linux 2025.3 发布 (Vagrant & Nexmon) - 领先的渗透测试发行版
The most advanced Penetration Testing Distribution
请访问原文链接…
C语言多线程同步详解:从互斥锁到条件变量
在多线程编程中,线程同步是确保多个线程正确协作的关键技术。当多个线程访问共享资源时,如果没有适当的同步机制,可能会导致数据竞争、死锁等问题。本文将详细介绍C语言中常用的线程同步技术。
为什么需要线程同步?…
收废铁的做网站有优点吗完整网站设计
一、卸载 1. sudo apt-get autoclean 如果你的硬盘空间不大的话,可以定期运行这个程序,将已经删除了的软件包的.deb安装文件从硬盘中删除掉。如果你仍然需要硬盘空间的话,可以试试apt-get clean,这会把你已安装的软件包的安装包也…
微网站的好处服务器架设国外做违法网站
文章目录 给飞行中的飞机换引擎安全意识十原则开发层面产品层面运维层面给飞行中的飞机换引擎
所谓给飞行中的飞机(或飞驰的汽车)换引擎,说的是我们需要对一个正在飞速发展的系统进行大幅度的架构改造,比如把 All-in-one 的架构改造成微服务架构,尽可能减少或者消除停服的…
企业网站建设前言宁海县做企业网站
数据挖掘主要侧重解决四类问题:分类、聚类、关联、预测。数据挖掘非常清晰的界定了它所能解决的几类问题。这是一个高度的归纳,数据挖掘的应用就是把这几类问题演绎的一个过程。 数据挖掘最重要的要素是分析人员的相关业务知识和思维模式。一般来说&…
确实网站的建设目标一个网站突然打不开
https://www.jb51.net/article/106525.htm
本文实例讲述了JS实现的五级联动菜单效果。分享给大家供大家参考,具体如下:
js实现多级联动的方法很多,这里给出一种5级联动的例子,其实可以扩展成N级联动,在做项目的时候碰到了这样一…