ema_mnist_blog

使用ModelEmaV2优化MNIST分类模型

在深度学习模型的训练过程中,参数波动可能会导致模型在测试集上的性能不稳定。为了解决这个问题,可以使用指数移动平均(EMA)技术来平滑参数的更新,从而获得更稳定的模型。本文将介绍如何在MNIST数据集上使用ModelEmaV2来优化分类模型,并分析其效果。

实验背景

MNIST数据集是一个经典的手写数字识别数据集,包含60,000张训练图像和10,000张测试图像。我们的目标是训练一个简单的神经网络模型来分类这些手写数字,并使用EMA技术来优化模型参数。

模型定义与EMA实现

首先,我们定义一个简单的全连接神经网络模型,并实现ModelEmaV2来进行EMA参数更新。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy# 定义简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(28*28, 10)def forward(self, x):x = x.view(-1, 28*28)x = self.fc(x)return x# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义 EMA 模型
class ModelEmaV2(nn.Module):def __init__(self, model, decay=0.99, device='cpu'):super(ModelEmaV2, self).__init__()self.ema_model = copy.deepcopy(model).to(device)self.ema_model.eval()self.decay = decayself.device = devicedef update(self, model):with torch.no_grad():model_params = dict(model.named_parameters())ema_params = dict(self.ema_model.named_parameters())for k in model_params.keys():ema_params[k].mul_(self.decay).add_(model_params[k], alpha=1 - self.decay)def forward(self, x):return self.ema_model(x)

数据加载与预处理

我们使用torchvision库来加载和预处理MNIST数据集。

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

训练与评估

我们进行4个epoch的训练,并在每个epoch结束后评估模型和EMA模型的准确率。

# 训练和评估
num_epochs = 4
results = []for epoch in range(num_epochs):model.train()for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()# 更新 EMA 模型ema_model.update(model)# 计算每个epoch的准确率model.eval()ema_model.eval()correct = 0total = 0ema_correct = 0ema_total = 0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct += (predicted == targets).sum().item()# 测试 EMA 模型ema_outputs = ema_model(inputs)_, ema_predicted = torch.max(ema_outputs.data, 1)ema_total += targets.size(0)ema_correct += (ema_predicted == targets).sum().item()normal_accuracy = 100 * correct / totalema_accuracy = 100 * ema_correct / ema_totallag = normal_accuracy - ema_accuracyresults.append({'epoch': epoch + 1,'normal_accuracy': normal_accuracy,'ema_accuracy': ema_accuracy,'lag': lag})results

实验结果分析

实验结果如下表所示:

EpochNormal Model AccuracyEMA Model AccuracyLag
191.0990.970.12
292.5492.460.08
393.5393.500.03
494.0394.13-0.10

从结果可以看出,在训练的前几轮,EMA模型的准确率稍微滞后于正常模型,但随着训练的进行,两者的准确率逐渐接近,甚至在第四轮时,EMA模型的准确率略高于正常模型。

结论

通过实验可以看出,EMA技术在一定程度上平滑了模型参数的波动,使得模型在测试集上的表现更加稳定。尽管在训练的初期EMA模型的准确率稍有滞后,但随着训练的进行,EMA模型的表现逐渐赶上并超过了正常模型。这表明EMA技术对于提高模型的稳定性和性能具有重要作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19400.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手拉手springboot整合kafka发送消息

环境介绍技术栈springbootmybatis-plusmysqlrocketmq软件版本mysql8IDEAIntelliJ IDEA 2022.2.1JDK17Spring Boot3.1.7kafka2.13-3.7.0 创建topic时,若不指定topic的分区(Partition主题分区数)数量使,则默认为1个分区(partition) springboot加入依赖kafk…

探索无限可能性——微软 Visio 2021 改变您的思维方式

在当今信息化时代,信息流动和数据处理已经成为各行各业的关键。微软 Visio 2021 作为领先的流程图和图表软件,帮助用户以直观、动态的方式呈现信息和数据,从而提高工作效率,优化业务流程。本文将介绍 Visio 2021 的特色功能及其在…

华为OD机试 - 游戏分组 - 递归(Java 2024 C卷 100分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…

精准检测,安全无忧:安全阀检测实践指南

安全阀作为一种重要的安全装置,在各类工业系统和设备中发挥着举足轻重的作用。 它通过自动控制内部压力,有效防止因压力过高而引发的设备损坏和事故风险,因此,对安全阀进行定期检测,确保其性能完好、工作可靠&#xf…

使用pytorch构建ResNet50模型训练猫狗数据集

数据集 1.导包 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms, models import numpy as np import matplotlib.pyplot as plt import os from tqdm.auto import t…

流媒体服务器SMS-语音对讲(一)

1.简介 在国标语音对讲对接中,会发现不同的厂商或不同型号的设备,对讲流程都不一样,本文主要介绍流媒体与设备之间的交互情况。 SMS流媒体服务代码库地址:https://gitee.com/inyeme/simple-media-server 2.流媒体与设备交互的可能…

JS中延迟加载的方式有哪些

延迟加载(Lazy loading)是一种性能优化策略,它通过将资源的加载推迟到真正需要使用的时候,来减少页面初始加载的时间和资源消耗。以下是几种常见的延迟加载方式: 1. 图片延迟加载:将页面中的图片的src属性…

Maven pom文件profile的properties在yaml配置文件替换失效问题

Maven profile的properties在yaml配置文件替换失效问题 Maven profile的properties在yaml配置文件替换失效问题原来错误的配置修改后的配置 Maven profile的properties在yaml配置文件替换失效问题 原因:spring-boot项目需要使用进行分割,如yaml配置文件…

Golang:使用embed引入静态文件

Go 语言从 1.16 版本开始引入了一个新的标准库 embed,可以在二进制文件中引入静态文件 指令:/go:embed 通过一个简单的小实例,来演示将静态文件引入到golang的二进制打包产物中 项目结构 $ tree . ├── main.go └── static└── he…

max6675热电偶温度采集

思路来源 参考价格 概述 MAX6675具有冷端补偿和将来自K型热电偶的信号数字化。数据以12位分辨率输出,SPI™兼容, 只读格式。该转换器将温度分解为0.25C,允许读数高达1024C,并显示热电偶8LSB在0C至 700C 引脚连接 温度采样电路 …

中间件复习之-消息队列

消息队列在分布式架构的作用 消息队列:在消息的传输过程中保存消息的容器,生产者和消费者不直接通讯,依靠队列保证消息的可靠性,避免了系统间的相互影响。 主要作用: 业务解耦异步调用流量削峰 业务解耦 将模块间的…

python中正则表达式学习

文章目录 介绍基本语法常用函数捕获组和命名组非捕获组贪婪匹配和非贪婪匹配多行模式和点匹配所有模式示例总结 介绍 Python 中的正则表达式(regular expressions, 简称 regex)由 re 模块提供。正则表达式是一种用于匹配字符串的强大工具,常…

MySQL之创建高性能的索引(八)

创建高性能的索引 覆盖索引 通常大家都会根据查询的WHERE条件来创建合适的索引,不过这只是索引优化的一个方面。设计优秀的索引应该考虑到整个查询,而不单单是WHERE条件部分。索引确实是一种查找数据的高效方式,但是MySQL也可以使用索引来直…

向量数据库引领 AI 创新——Zilliz 亮相 2024 亚马逊云科技中国峰会

2024年5月29日,亚马逊云科技中国峰会在上海召开,此次峰会聚集了来自全球各地的科技领袖、行业专家和创新企业,探讨云计算、大数据、人工智能等前沿技术的发展趋势和应用场景。作为领先的向量数据库技术公司,Zilliz 在本次峰会上展…

【漏洞复现】电信网关配置管理系统 rewrite.php 文件上传漏洞

0x01 产品简介 中国电信集团有限公司(英文名称"China Telecom”、简称“"中国电信”)成立于2000年9月,是中国特大型国有通信企业、上海世博会全球合作伙伴。电信网关配置管理系统是一个用于管理和配置电信网络中网关设备的软件系统。它可以帮助网络管理员…

在线IP检测如何做?代理IP需要检查什么?

当我们的数字足迹无处不在,隐私保护显得愈发重要。而代理IP就像是我们的隐身斗篷,让我们在各项网络业务中更加顺畅。 我们常常看到别人购买了代理IP服务后,通在线检测网站检查IP,相当于一个”售前检验““售后质检”的作用。但是…

2024-5-31 石群电路-19

2024-5-31,星期五,10:53,天气:阴雨,心情:晴。今天就要回学校啦,当大家看到这篇推文的时候我已经要收拾收拾去赶返校的火车啦,和女朋友短暂分别,不过小别胜新婚吗&#xf…

css动画效果(边框流光闪烁阴影效果)

1.整体效果 https://mmbiz.qpic.cn/sz_mmbiz_gif/EGZdlrTDJa7odDQYuaatklJUMc5anU10PWLAt14rNnNUD6oHJG9U63fc0yibiapuDViatVk62ma3K63oqQ3U1VtMQ/640?wx_fmtgif&fromappmsg&wxfrom13 CSS边框流光闪烁阴影动画效果是一种令人印象深刻的技术,它通过动态的光…

笔记-docker基于ubuntu22.04安装Jitsi Meet

背景 利用JitsiMeet打造一个可以在线会议的环境,根据躺的坑,做个记录 参考 JitsMeet部署安装说明 开始操作 环境 docker run -it --name ubuntu22.04 ubuntu:22.04 /bin/bash问题 1、安装 openjdk-11 apt install openjdk-11-jdk配置环境变量&…

es初始化

一.初始化es public void initES() {/*LOGGER.info("host" host);LOGGER.info("port" port);LOGGER.info("scheme" scheme);LOGGER.info("userName" userName);LOGGER.info("password" password);*/// 客户端连接创建…