PyTorch下的5种不同神经网络-ResNet

1.导入模块

导入所需的Python库,包括图像处理、深度学习模型和数据加载

import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms

2.定义自定义图像数据集:

创建一个自定义的图像数据集类,用于加载和处理图像数据

class CustomImageDataset(Dataset):def __init__(self, main_dir, transform=None):self.main_dir = main_dirself.transform = transformself.files = []self.labels = []self.label_to_index = {}for index, label in enumerate(os.listdir(main_dir)):self.label_to_index[label] = indexlabel_dir = os.path.join(main_dir, label)if os.path.isdir(label_dir):for file in os.listdir(label_dir):self.files.append(os.path.join(label_dir, file))self.labels.append(label)def __len__(self):return len(self.files)def __getitem__(self, idx):image = Image.open(self.files[idx])label = self.labels[idx]if self.transform:image = self.transform(image)return image, self.label_to_index[label]

3.定义数据转换

定义一个数据转换过程,包括图像大小调整、转换为张量以及标准化

transform = transforms.Compose([transforms.Resize((224, 224)),  # ResNet的输入图像大小transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 标准化])

4.创建数据集

使用自定义数据集类和定义的数据转换来创建数据集

dataset = CustomImageDataset(main_dir="F:\\A-GX\\A-SJJ\\flower_photos\\flower_photos", transform=transform)

5.创建数据加载器

使用数据集创建一个数据加载器,用于批量加载和处理数据。

data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

6.加载预训练的ResNet模型

从PyTorch库中加载预训练的ResNet18模型

resnet_model = models.resnet18(pretrained=True)

7.修改最后几层以适应新的分类任务

修改ResNet模型的最后几层,以便它能够处理新的分类任务

num_ftrs = resnet_model.fc.in_featuresresnet_model.fc = nn.Linear(num_ftrs, len(dataset.label_to_index))

8.定义损失函数和优化器

定义用于训练模型的损失函数和优化器

criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)

9.模型并行化

如果有多GPU,则使用nn.DataParallel来并行化模型

if torch.cuda.device_count() > 1:resnet_model = nn.DataParallel(resnet_model)

10.将模型发送到GPU

模型发送到GPU进行训练

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")resnet_model.to(device)

11.训练模型

使用数据加载器和定义的参数训练模型

num_epochs = 10for epoch in range(num_epochs):resnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = resnet_model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/857297.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Interpreting Machine Learning Models with SHAP: A Comprehensive Guide

Interpreting Machine Learning Models with SHAP: A Comprehensive Guide DateAuthorVersionNote2024.06.20Dog TaoV1.0Finish the document. 文章目录 Interpreting Machine Learning Models with SHAP: A Comprehensive GuideWhat is SHAPUnderstanding Base ValueDefiniti…

虚拟机拖拽文档造成缓存过大

查看文件夹大小:du -h --max-depth1 缓存位置:~/.cache/vmware/drag_and_drop 删除:rm -fr ~/.cache/vmware/drag_and_drop 释放了3GB

自然语言NLP的基础处理

NLP基本处理从句子的情感分析、实体与实体直接的关系,句子结构来分析 情感分析 1.句子的情感分析找出句子表达的是正面、负面还是中性的情感。 情感分析的影响因素: 词语顺序:词语的顺序可以影响句子的整体情感。例如,“我喜欢…

网络安全:Web 安全 面试题.(文件上传漏洞)

网络安全:Web 安全 面试题.(文件上传漏洞) 网络安全面试是指在招聘过程中,面试官会针对应聘者的网络安全相关知识和技能进行评估和考察。这种面试通常包括以下几个方面: (1)基础知识:包括网络基础知识、操…

CVPR上新 | 从新视角合成、视频编解码器、人体姿态估计,到文本布局分析,微软亚洲研究院精选论文

编者按:欢迎阅读“科研上新”栏目!“科研上新”汇聚了微软亚洲研究院最新的创新成果与科研动态。在这里,你可以快速浏览研究院的亮点资讯,保持对前沿领域的敏锐嗅觉,同时也能找到先进实用的开源工具。 本周&#xff0…

python如何判断图片是否为空

如下所示: import cv2im cv2.imread(2.jpg) if im is None:print("图像为空") # cv2.imshow("ss", im) # cv2.waitKey(0)

编码规则UTF-8 和 UTF-16的区别

UTF-8 和 UTF-16 的设计背景与历史 为了更好地理解 UTF-8 和 UTF-16 的设计选择和背景,以下是两种编码方案的历史、设计动机和它们在计算机科学中的应用。 Unicode 的背景 在 Unicode 之前,不同的字符集和编码方案使得跨平台和国际化的文本处理变得复…

2024年AI+游戏赛道的公司和工具归类总结

随着人工智能技术的飞速发展,AI在游戏开发领域的应用越来越广泛。以下是对2024年AI+游戏赛道的公司和工具的归类总结,涵盖了从角色和场景设计到音频制作,再到动作捕捉和动画生成等多个方面。 2D与3D创作 2D创作工具:专注于角色和场景的平面设计,提供AI辅助的图案生成和风…

C++之thread_local变量

目录 1.C 的存储类型 1.1.存储周期(Storage duration) 1.2.存储类型说明符(Storage class specifiers) 1.3.存储类型说明符与存储周期的关系 2.thread_local简介 3.thread_local 应用 3.1.thread_local 与全局变量 3.2.th…

粘包拆包服务器

服务器&#xff1a; 创建个控制台应用 创建Server.cs internal class Server{TcpListener listen;public Server(IPAddress ip,int port) {listen new TcpListener(ip, port);}public void Start(){listen.Start(100);StartConnect(); }Dictionary<string,TcpClient>…

【2024德国工作】外国人在德国找工作是什么体验?

挺难的&#xff0c;德语应该是所有中国人的难点。大部分中国人进德国公司要么是做中国业务相关&#xff0c;要么是做技术领域的工程师。先讲讲人在中国怎么找德国的工作&#xff0c;顺便延申下&#xff0c;德国工作的真实体验&#xff0c;最后聊聊在今年的德国工作签证申请条件…

秀米排版的一些技巧

1.正文一般16字号 、默认字体、格式首行缩进 2.最后署名&#xff08;居中&#xff09; 文丨1234 图丨1234 排版丨1234 指导老师 | 1234 审核 |1234 信息学院研究生会宣传中心 宣 3.不必要的文字要删除 以及不必要的排版的画面 简简单单就ok 4.然后图片文字按顺序 5.最开始有个框…

Android AlarmManager 设定过去的时间会触发事件

Android AlarmManager 设定过去的时间会触发事件 在使用 AlarmManager 做每日定时任务时&#xff0c;发现如果设定的时间小于当前的系统时间&#xff0c;那么设定后会立刻收到一次定时任务回调。 我们设想的是设定的时间应该是明日的这个时间&#xff0c;但是如果打印出设定的…

【八股系列】说一下mobx和redux有什么区别?(React)

&#x1f389; 博客主页&#xff1a;【剑九 六千里-CSDN博客】 &#x1f3a8; 上一篇文章&#xff1a;【介绍React高阶组件&#xff0c;适用于什么场景&#xff1f;】 &#x1f3a0; 系列专栏&#xff1a;【面试题-八股系列】 &#x1f496; 感谢大家点赞&#x1f44d;收藏⭐评…

现代数字信号处理及其应用-常见结论

现代数字信号处理及其应用-常见结论 本文的结论均摘抄自 何子述、夏威等编著&#xff0c;《现代数字信号处理及其应用》&#xff0c;清华出版社出版。 解析信号信号预包络&#xff1b;基带信号信号复包络。BT法&#xff08;自相关谱估计法&#xff09;&#xff1a;间接法&…

双例集合(二)——双例集合的实现类之HashMap容器类

双例集合的常用实现类有HashMap和TreeMap两个&#xff0c;通过这两个类我们可以实现Map接口定义的容器&#xff0c;一般情况下使用HashMap容器类较多。 HashMap容器类是Map接口最常用的实现类&#xff0c;它的底层采用Hash算法来实现&#xff0c;这也就满足了键key不能重复的要…

Python:调用zabbix api,删除部分被监控主机

调用zabbix api&#xff0c;删除部分被监控主机。 简介代码部分配置文件config.jsonnamefile.txt 简介 当新主机上线时&#xff0c;我们可以通过自动注册功能&#xff0c;在zabbix中批量添加这些新主机。那当有主机需要下线时&#xff0c;我们又该如何在zabbix中批量删除这些主…

揭秘!速卖通、敦煌网、国际站出单背后的黑科技:自养号测评技术

在竞争激烈的跨境电商平台上&#xff0c;如亚马逊、速卖通、Lazada、Shopee、敦煌网、Temu、Shein、美客多和阿里国际等&#xff0c;稳定出单成为每位卖家共同追求的目标。为了实现这一目标&#xff0c;卖家需要从产品选择、运营策略和客户服务等多个维度进行全面考量&#xff…

华为重磅官宣:超9亿台、5000个头部应用已加入鸿蒙生态!人形机器人现身 专注AI芯片!英伟达挑战者Cerebras要上市了

内容提要 华为表示&#xff0c;盘古大模型5.0加持&#xff0c;小艺能力全新升级。小艺智能体与导航条融为一体&#xff0c;无处不在&#xff0c;随时召唤。只需将文字、图片、文档“投喂”小艺&#xff0c;即可便捷高效处理文字、识别图像、分析文档。 正文 据华为终端官方微…

采用string 及random库随机生成长度为32的字符串

要使用Python的string和random库来生成一个长度为32的随机字符串&#xff0c;其中包含大小写字母和数字&#xff0c;你可以按照以下方式编写代码&#xff1a; import string import random def generate_random_string(length32): """生成一个指定长度的随…