pytorch第66页

news/2025/10/22 14:25:15/文章来源:https://www.cnblogs.com/shetingting/p/19157888
点击查看代码
import torch
from torch import optim, nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report
from PIL import Image
import time
from matplotlib import pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#数据加载
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])# 加载CIFAR-10数据集
def load_data():train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)return train_loader, test_loader, test_dataset#定义MYVGG模型
class MYVGG(nn.Module):def __init__(self, num_classes=10):super(MYVGG, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),)self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x#训练函数
model = MYVGG().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, criterion, optimizer, epoch_num=50):model.train()train_loss = []train_acc = []for epoch in range(epoch_num):start_time = time.time()running_loss = 0.0current = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)current += (predicted == target).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100.0 * current / totaltrain_loss.append(epoch_loss)train_acc.append(epoch_acc)end_time = time.time()print(f"Epoch [{epoch+1}/{epoch_num}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%, Time: {end_time-start_time:.2f}s")return train_loss, train_acc#测试函数
def test(model, test_loader):model.eval()all_pred = []all_label = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)outputs = model(data)_, predicted = torch.max(outputs.data, 1)all_pred.extend(predicted.cpu().numpy())all_label.extend(target.cpu().numpy())all_pred = np.array(all_pred)all_label = np.array(all_label)accuracy = (all_pred == all_label).mean()accuracy = 100.0 * accuracyprint(f'测试准确率: {accuracy:.4f}%')print("分类效果评估:")target_names = [str(i) for i in range(10)]report = classification_report(all_label, all_pred, target_names=target_names)print(report)if __name__ == '__main__':print(f"24信计2班 佘婷婷 2024310143102")print(f"device:{device}")epoch_num = 20train_loader, test_loader, test_dataset = load_data()train_loss, train_acc = train(model, train_loader, criterion, optimizer, epoch_num)test(model, test_loader)#绘制结果plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epoch_num+1), train_loss)plt.title("Training Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(range(1, epoch_num+1), train_acc)plt.title("Training Accuracy")plt.xlabel("Epoch")plt.ylabel("Accuracy (%)")plt.tight_layout()plt.show()

image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/943340.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Navicat Premium 17 官方版下载安装教程|支持MySQL、PostgreSQL、MongoDB等数据库

Navicat Premium是一款功能强大的数据库管理和开发工具,专为PC电脑端设计。它支持多种数据库类型,包括MySQL、SQLite、SQL Server、Oracle、PostgreSQL、MariaDB和MongoDB,让用户能够轻松管理这些不同类型的数据库。…

有什么指标可以判断手机是否降频

1)有什么指标可以判断手机是否降频2)关于降低动画浮点数精度的问这是第449篇UWA技术知识分享的推送,精选了UWA社区的热门话题,涵盖了UWA问答、社区帖子等技术知识点,助力大家更全面地掌握和学习。 UWA社区主页:c…

实用指南:Linux内核kallsyms符号压缩与解压机制

实用指南:Linux内核kallsyms符号压缩与解压机制pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", …

从埋点到用户行为分析:ClkLog 如何帮助企业读懂用户

“理解用户”已成为企业竞争力的关键。越来越多的企业开始关注用户运营,从PV、UV、停留时间等指标入手,但这些数据只能反映业务趋势,无法回答更核心的问题:用户是谁?为什么留下?又为什么流失? 一、为什么企业都…

深入解析:领码方案 | 掌控研发管理成熟度:从理论透视到AI驱动的实战进阶

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

函数的高级

函数的高级-函数的默认参数 在C++中,函数的形参列表中的形参是可以有默认值的 语法:返回值类型 函数名 (参数 = 默认值){ } 如果函数的声明有默认参数值,函数实现就不能有默认参数值 如果某个位置参数有默认值,那…

C#实现OPC客户端

C#实现OPC客户端,结合OPC DA与OPC UA两种协议 一、环境配置与依赖库 1. 基础环境开发工具:Visual Studio 2019+(.NET Framework 4.6+ 或 .NET Core 3.1+)核心库:OPC DA:Interop.OpcDa.dll(需OPC Core Component…

Gitee:数字化转型浪潮中的项目管理利器

Gitee:数字化转型浪潮中的项目管理利器 在数字化转型的浪潮席卷全球之际,企业效率提升已成为核心竞争力。项目管理工具作为这一转型过程中的关键支撑,正迎来前所未有的发展机遇。国际知名调研机构Gartner预测,到20…

zlog2

1."df.isnul()返回一个布尔类型的 DataFrame,其中缺失值被标记为True,非缺失值被标记为 False。 "sum()方法被应用于这个布尔DataFrame,计算每列中缺失值的数量。 2。df.info():提供的信息更全面。df.inf…

C++进阶篇:001

C++进阶篇:001$(".postTitle2").removeClass("postTitle2").addClass("singleposttitle");C++进阶篇:001.linux入门 一、vim编辑器的使用 1.如何打开vim编辑器1. 先touch一个文件,然…

卷积神经网络的读后感

深度探索图像的语言:卷积神经网络读后感 读完关于卷积神经网络的介绍,我仿佛打开了一扇通往图像世界的新视角。在此之前,我一直认为图像的本质只是像素点组成的二维矩阵,而CNN则像一位经验丰富的翻译官,教会我如何…

Calibre 8.11技术拆解:AI集成与二次开发的实战指南 - 教程

Calibre 8.11技术拆解:AI集成与二次开发的实战指南 - 教程2025-10-22 14:14 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !importan…

5G企业应用的七大场景与商业机遇

本文详细介绍了5G在企业领域的七大关键应用场景,包括固定无线接入、医疗健康、传感器系统、网络边缘计算、远程设备控制、汽车行业和智慧城市,分析了5G技术如何推动企业数字化转型和业务创新。尽管5G早期面临价格溢价…

2025 水泥墩源头厂家最新推荐排行榜:光伏 / 围挡 / 交通 / 防撞水泥墩多品类优选,实力品牌权威榜单

引言 水泥墩作为市政基建、光伏电站、交通防护等领域的核心基础建材,其质量直接关系到工程安全与使用寿命。2025 年国内水泥制品行业市场规模预计达 14850 亿元,华东地区占比近 30%,交通、光伏等领域需求持续攀升,…

类的多态(Num020) - 实践

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025 年国内活塞杆厂家最新推荐排行榜:聚焦精密 / 不锈钢 / 油缸 / 气缸 / 45# 镀铬类产品,助力企业精准挑选可靠合作方

引言 当前工业自动化进程持续加速,活塞杆作为液压油缸、气缸等关键部件的核心组件,其质量直接关乎设备运行精度与使用寿命。但当下市场中,活塞杆制造商良莠不齐,部分企业产品存在工艺粗糙、精度不足、耐腐蚀性能差…

20232305 2025-2026-1 《网络与系统攻防技术》实验二实验报告

1.实验内容 (1)学习使用netcat监听端口,反弹链接到主机并获得shell; (2)使用netcat在liunx主机上增加一个定时任务,并学习使用socat; (3)使用MSF meterpreter生成可执行文件(后门),利用ncat或socat传送到…

就在Visual Studio Code中配置好C/C++

就在Visual Studio Code中配置好C/C++这篇随笔主要是闲暇时间写的,写这篇随笔的原因有以下两点: 1.唯一的一个粉丝说我好久没更新了。 2.我的直系学弟他说他用Visual Studio Code一直搞不好C,终端也实现不了。 好的…

高效数据结构 - 循环队列

循环队列在游戏开发中通常叫做CircularBuffer、RingBuffer,常用来做数据缓存,生产者/消费者模型等。 在UE中有内置这样的数据结构,而Unity的.Net库中恰恰没有。为什么说这样的结构高效,以双下标循环队列为例。配个…

数据类型,二元运算符,自动类型提升规则,关系运算,取余模运算

数据类型,二元运算符,自动类型提升规则,关系运算,取余模运算数据类型,二元运算符,自动类型提升规则,关系运算,取余模运算 package com.kun.operator;public class Demo1 {public static void main(String[] ar…