【cv】cycleGAN代码解析:train.py

【cv】cycleGAN代码解析:train.py

Posted on 2025-09-25 16:37  SaTsuki26681534  阅读(0)  评论(0)    收藏  举报
import time                                   # 计时:统计每轮/每次迭代耗时
from options.train_options import TrainOptions # 训练期命令行参数解析器(继承 BaseOptions 并添加训练相关项)
from data import create_dataset                # 工厂函数:按 opt.dataset_mode 创建数据集实例
from models import create_model                # 工厂函数:按 opt.model 创建对应模型(如 CycleGAN、pix2pix)
from util.visualizer import Visualizer         # 可视化与日志工具:显示/保存图像、打印/绘制损失曲线
if __name__ == "__main__":                     # 仅在作为脚本运行时执行,避免被 import 时跑训练opt = TrainOptions().parse()               # 解析命令行参数,得到训练配置对象 opt# 从train_options.py里得到命令行参数的对象dataset = create_dataset(opt)              # 依据选项创建数据集(并可能内部创建 DataLoader)# 这里的dataset对象里包含数据集对象以及dataLoaderdataset_size = len(dataset)                # 获取数据集中图像数量(通常是样本数或步数的近似)print(f"The number of training images = {dataset_size}")  # 打印样本数量,便于确认数据是否加载正确# 创建模型并进行一定的初始化操作model = create_model(opt)                  # 创建指定的模型实例(如 CycleGANModel)model.setup(opt)                           # 常规设置:加载/打印网络,创建学习率调度器(若训练/续训)visualizer = Visualizer(opt)               # 创建可视化器:负责 visdom/HTML 图像展示与损失曲线绘制total_iters = 0                            # 全局累计的“迭代步数”(以样本/批为单位累加)for epoch in range(                        # 外层按 epoch 循环opt.epoch_count,                       # 起始 epoch(支持从中间轮次续训)opt.n_epochs + opt.n_epochs_decay + 1  # 训练轮数 = 预热阶段 + 线性衰减阶段;+1 使得上界包含在内):# 构造for循环,开始每个epoch的操作epoch_start_time = time.time()         # 记录该轮开始时间(统计整轮耗时)iter_data_time = time.time()           # 记录上一次取数据的时间(用于统计 data loading 时间)# (先定义好iter_data_time这个变量)epoch_iter = 0                         # 当前 epoch 内已处理的样本数(或近似步数),每轮重置# 记录当前epoch里经历的iter数visualizer.reset()                     # 重置可视化器:确保至少每个 epoch 会把结果写入 HTML

enumerate机制和可迭代对象

enumerate

enumerate是 Python 的内置函数,核心作用是为可迭代对象的元素添加索引,方便在迭代时同时获取 “索引” 和 “元素值”

enumerate(iterable)会返回一个枚举对象(enumerate object),它本质是一个迭代器(iterator)。每次迭代时,这个迭代器会返回一个元组(索引, 元素),其中:

第一个元素是当前迭代的索引(默认从 0 开始,可通过start参数指定起始值,如enumerate(lst, start=1));
第二个元素是可迭代对象iterable中的对应元素。

因此,在for i, item in enumerate(iterable):中,i接收索引,item接收元素,这是对元组(i, item)的解包操作

可迭代对象

enumerate函数的传入参数必须是可迭代对象

在 Python 中,可迭代对象需要满足:实现了__iter__()方法,该方法返回一个迭代器(iterator);而迭代器需要实现__next__()方法(用于返回下一个元素)和__iter__()方法(返回自身)

参考文献:https://zhuanlan.zhihu.com/p/7364648529

        for i, data in enumerate(dataset):     # 内层按批次迭代数据集(DataLoader 可迭代)# 当这里迭代dataset对象时,其实是在迭代里面的dataLoader对象,以Iter为单位# 即,每次取出一个iter的数据# data对象里包含着dataLoader本次取出的数据的相关信息,在后面的set_input方法里会把这些数据加载进去iter_start_time = time.time()      # 记录本次迭代计算开始时间if total_iters % opt.print_freq == 0:      # 每隔 print_freq 次统计一次“取数耗时”# parser.add_argument('--print_freq', type=int, default=100, # help='frequency of showing training results on console')# print_freq是打印训练结果的频率t_data = iter_start_time - iter_data_timetotal_iters += opt.batch_size       # 全局步数累加(以 batch_size 作为步长)epoch_iter += opt.batch_size        # 当前轮的步数累加model.set_input(data)               # 解包 dataloader 返回的 data,并搬到正确的 device# 这里模型已经得到了dataLoader里的数据了model.optimize_parameters()         # 前向、计算损失、反传、更新网络参数(一次标准训练步)# 这一句就是训练的核心操作

回顾一下这个函数

    def optimize_parameters(self):"""Calculate losses, gradients, and update network weights; called in every training iteration"""# 1) 前向:生成假图与重建图self.forward()# 2) 优化生成器(冻结两个判别器的梯度)self.set_requires_grad([self.netD_A, self.netD_B], False)self.optimizer_G.zero_grad()self.backward_G()self.optimizer_G.step()# 3) 优化判别器(解冻判别器)self.set_requires_grad([self.netD_A, self.netD_B], True)self.optimizer_D.zero_grad()self.backward_D_A()self.backward_D_B()self.optimizer_D.step()
if total_iters % opt.display_freq == 0:  # 到了展示频率:显示图像并(可选)保存到 HTMLsave_result = total_iters % opt.update_html_freq == 0  # 是否本次也写 HTML(更低频地写盘)model.compute_visuals()         # 生成额外可视化结果(由模型实现,如重建/中间图)visualizer.display_current_results(model.get_current_visuals(),  # 从模型取出需要展示的可视化张量epoch,                        # 当前 epoch 序号total_iters,                  # 全局迭代计数(用于命名与记录)save_result                   # 是否保存 HTML(否则只在 visdom 上显示))if total_iters % opt.print_freq == 0:    # 到了打印频率:打印损失并记录到磁盘losses = model.get_current_losses()  # 从模型取回当前各项损失(有序字典)t_comp = (time.time() - iter_start_time) / opt.batch_size  # 计算每张/每样本平均计算耗时visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)  # 控制台打印visualizer.plot_current_losses(total_iters, losses)                         # 动态曲线if total_iters % opt.save_latest_freq == 0:  # 到了“保存最新模型”的频率:保存快照print(f"saving the latest model (epoch {epoch}, total_iters {total_iters})")save_suffix = f"iter_{total_iters}" if opt.save_by_iter else "latest"  # 可按迭代编号或统一 latestmodel.save_networks(save_suffix)  # 以 <suffix> 作为文件名后缀存盘(见 BaseModel.save_networks)iter_data_time = time.time()         # 更新“上一次取数时间”,用于下一个 batch 的 t_data 统计model.update_learning_rate()             # 每个 epoch 结束时根据策略更新学习率(调度器 step)if epoch % opt.save_epoch_freq == 0:     # 到了“按 epoch 频率保存”的时刻:保存 latest 与该轮编号print(f"saving the model at the end of epoch {epoch}, iters {total_iters}")model.save_networks("latest")        # 保存成 latest(覆盖)model.save_networks(epoch)           # 再保存一个以 epoch 编号命名的权重(留存历史)print(                                   # 打印本轮结束信息与整轮耗时(四舍五入到秒)f"End of epoch {epoch} / {opt.n_epochs + opt.n_epochs_decay} \t "f"Time Taken: {time.time() - epoch_start_time:.0f} sec")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/917237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

做移动网站优化网站建设过程中要怎么打开速度

这次我们将学着怎么从一个深度图里面导出边界。我们对3种不同种类的点很感兴趣:物体的边框的点&#xff0c;阴影边框点&#xff0c;和面纱点(在障碍物边界和阴影边界)&#xff0c;这是一个很典型的现象在通过雷达获取的3D深度。 下面是代码 /* \author Bastian Steder */#incl…

注册 网站开发 公司重庆招聘网

作者 | 轩辕之风O来源 | 编程技术宇宙相信大家这两天应该被这么一条新闻刷屏了&#xff1a;这个漏洞到底是怎么回事&#xff1f;核弹级&#xff0c;真的有那么厉害吗&#xff1f;怎么利用这个漏洞呢&#xff1f;我看了很多技术分析文章&#xff0c;都太过专业&#xff0c;很多非…

创建网站怎么创电子商务网站有哪些类型

项目介绍&#xff1a; 使用javaspringbootmysql开发的法律咨询网&#xff08;文书&#xff09;&#xff0c;系统包含管理员、用户角色&#xff0c;功能如下&#xff1a; 管理员&#xff1a;登录系统&#xff1b;用户管理&#xff1b;文章管理&#xff08;法律知识&#xff09…

网站建设公司业务在哪里来百度网站的总结

I老师就职于双非二本院校&#xff0c;希望通过出国研修以提升科研背景&#xff0c;在公派访学和申请导师出资的博士后之间&#xff0c;其选择了后者。最终我们落实了美国耶鲁大学的职位&#xff0c;头衔为Associate Research Scientist&#xff08;副研究科学家&#xff09;&am…

企业网站 联系我们电商培训机构有哪些?哪家比较好

常用示例 入门 Hello CMake CMake 是一个用于配置跨平台源代码项目应该如何配置的工具建立在给定的平台上。 ├── CMakeLists.txt # 希望运行的 CMake命令 ├── main.cpp # 带有main 的源文件 ├── include # 头文件目录 │ └── header.h └── src # 源代码目录 ├…

深入解析:李宏毅2023机器学习作业 HW01实操

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

基于Java+SpringBoot+SSM,Flask福聚苑社区团购体系(源码+LW+调试文档+讲解等)/福聚苑社区/团购系统/社区团购/福聚苑/团购/社区/环境/福聚苑小区/在线团购/社区购物

基于Java+SpringBoot+SSM,Flask福聚苑社区团购体系(源码+LW+调试文档+讲解等)/福聚苑社区/团购系统/社区团购/福聚苑/团购/社区/环境/福聚苑小区/在线团购/社区购物pre { white-space: pre !important; word-wrap: nor…

按需引入echarts

--// echarts-config.js // ECharts按需引入配置文件 import * as echarts from echarts/core; import {BarChart,LineChart,PieChart,ScatterChart,RadarChart } from echarts/charts; import {TitleComponent,Toolti…

软件构造的用户交互设计 4章

交互设计的原则 1.尽量保持一致 2.满足普遍可用性 3.提供信息反馈 4.设计对话框以产生结束信息 5.预防并成立错误 6.允许撤销操作 7.支持内部控制点 8.减轻短时记忆负担 交互设计的基本过程 标识和建立用户需求 提出满…

自定义制作docker容器自动自愈容器镜像

包括:完整的 autoheal.sh(支持每分钟检查一次、连续 5 次 unhealthy 才重启) Dockerfile docker-compose.yml 详细文档,包含参数说明、用法1️⃣ autoheal.sh #!/usr/bin/env sh set -e set -o pipefailDOCKER_SOC…

阀门公司网站建设广州动漫制作公司

1 mpl_toolkits.mplot3d 功能介绍 mpl_toolkits.mplot3d 是 Matplotlib 库中的一个子模块&#xff0c;用于绘制和可视化三维图形&#xff0c;包括三维散点图、曲面图、线图等。它提供了丰富的功能来创建和定制三维图形。以下是 mpl_toolkits.mplot3d 的主要功能和功能简介&am…

如何利用海外 NetNut 网络代理与 AICoding 实战获取 iPhone 17 新品用户评论数据?

如何利用海外 NetNut 网络代理与 AICoding 实战获取 iPhone 17 新品用户评论数据?如何利用海外 NetNut 网络代理与 AICoding 实战获取 iPhone 17 新品用户评论数据? 一、引言 在数据驱动时代,开发者与研究者越来越依…

第一次编码器测试

共1055圈 平均2047.974408 平均每张丢失距离 0.00001132 mm可以忽略 不丢帧

04-FreeRTOS的概述及编程规范

概述 本文对FreeRTOS源码进行概述,包括其核心文件作用,及其编程规范,有助于阅读rtos的内核源码,更好的帮助理解。 一、FreeRTOS 源码核心结构概述 FreeRTOS 是轻量级实时操作系统,核心功能围绕 “任务调度” 和 “…

10_ select/poll/epoll实现服务端的io多路复用

一、io多路复用 在现有模型中,似乎每一个线程都做了同样的事情,1、监听客户端消息;2、业务消息处理。 “一消息一线程”的缺点究其根本,在于让每个线程都做了同样重复、且消耗资源巨大的事情——单独持有fd、监听客…

模拟实战配置实验

vlan之间的互通 要实现 VLAN 10(192.168.150.0/24)、VLAN 100(192.168.100.0/24)、VLAN 200(192.168.200.0/24) 之间的互联互通,核心原理是:二层交换机仅负责 VLAN 内流量转发,跨 VLAN 流量需通过三层设备(核…

微网站建设的现状设计吧

首先要理解double的存储方式&#xff0c;具体可查找相关的博客本文实现的是将8个字节(存储为16进制的字符串)转化为对应的double类型double MainWindow::qByteArraytodouble(QString qstr){QByteArray byte;StringToHex(qstr,byte);double result;memcpy(&result, byte.dat…

聚力赋能|竹云受邀出席2025华为全联接大会 - 详解

聚力赋能|竹云受邀出席2025华为全联接大会 - 详解2025-09-25 16:21 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; displ…

Linux安装Kafka(无Zookeeper模式)保姆级教程,云服务器安装部署,Windows内存不够允许看看

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

国标GB28181公网直播EasyGBS如何构建全域覆盖的应急管理与安全生产解决方案?

在当今社会,安全生产和应急管理已经成为各行各业不可或缺的重要部分。全面提高安全生产管理水平、构建责任全覆盖、监管全过程、监管全方位的综合治理体系已成为社会发展的必然趋势。国标GB28181网页直播平台EasyGBS作…