深度学习-torch,全连接神经网路

3. 数据集加载案例

通过一些数据集的加载案例,真正了解数据类及数据加载器。

3.1 加载csv数据集

代码参考如下

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
​
​
class MyCsvDataset(Dataset):def __init__(self, filename):df = pd.read_csv(filename)# 删除文字列df = df.drop(["学号", "姓名"], axis=1)# 转换为tensordata = torch.tensor(df.values)# 最后一列以前的为data,最后一列为labelself.data = data[:, :-1]self.label = data[:, -1]self.len = len(self.data)
​def __len__(self):return self.len
​def __getitem__(self, index):idx = min(max(index, 0), self.len - 1)return self.data[idx], self.label[idx]
​
​
def test001():excel_path = r"./大数据答辩成绩表.csv"dataset = MyCsvDataset(excel_path)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)for i, (data, label) in enumerate(dataloader):print(i, data, label)
​
​
if __name__ == "__main__":test001()
​

练习:上述示例数据构建器改成TensorDataset

def build_dataset(filepath):df = pd.read_csv(filepath)df.drop(columns=['学号', '姓名'], inplace=True)data = df.iloc[..., :-1]labels = df.iloc[..., -1]
​x = torch.tensor(data.values, dtype=torch.float)y = torch.tensor(labels.values)
​dataset = TensorDataset(x, y)
​return dataset
​
​
def test001():filepath = r"./大数据答辩成绩表.csv"dataset = build_dataset(filepath)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)for i, (data, label) in enumerate(dataloader):print(i, data, label)

3.2 加载图片数据集

参考代码如下:只是用于文件读取测试

import torch
from torch.utils.data import Dataset, DataLoader
import os
​
# 导入opencv
import cv2
​
​
class MyImageDataset(Dataset):def __init__(self, folder):# 文件存储路径列表self.filepaths = []# 文件对应的目录序号列表self.labels = []# 指定图片大小self.imgsize = (112, 112)# 临时存储文件所在目录名dirnames = []
​# 递归遍历目录,root:根目录路径,dirs:子目录名称,files:子文件名称for root, dirs, files in os.walk(folder):# 如果dirs和files不同时有值,先遍历dirs,然后再以dirs的目录为路径遍历该dirs下的files# 这里需要在dirs不为空时保存目录名称列表if len(dirs) > 0:dirnames = dirs
​for file in files:# 文件路径filepath = os.path.join(root, file)self.filepaths.append(filepath)# 分割root中的dir目录名classname = os.path.split(root)[-1]# 根据目录名到临时目录列表中获取下标self.labels.append(dirnames.index(classname))self.len = len(self.filepaths)
​def __len__(self):return self.len
​def __getitem__(self, index):# 获取下标idx = min(max(index, 0), self.len - 1)# 根据下标获取文件路径filepath = self.filepaths[idx]# opencv读取图片img = cv2.imread(filepath)# 图片缩放,图片加载器要求同一批次的图片大小一致img = cv2.resize(img, self.imgsize)# 转换为tensorimg_tensor = torch.tensor(img)# 将图片HWC调整为CHWimg_tensor = torch.permute(img_tensor, (2, 0, 1))# 获取目录标签label = self.labels[idx]
​return img_tensor, label
​
​
def test02():path = os.path.join(os.path.dirname(__file__), 'dataset')# 转换为相对路径path = os.path.relpath(path)dataset = MyImageDataset(path)
​dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
​for img, label in dataloader:print(img.shape)print(label)
​
​
if __name__ == "__main__":test02()
​

练习:1.重写上述代码,如果不对图片进行缩放会产生什么结果?2.在遍历图片的代码中打印图片查看图片效果(打印一批次即可)

# 导入opencv
import cv2
​
​
class MyDataset(Dataset):def __init__(self, folder):
​dirnames = []self.filepaths = []self.labels = []
​for root, dirs, files in os.walk(folder):if len(dirs) > 0:dirnames = dirs
​for file in files:filepath = os.path.join(root, file)self.filepaths.append(filepath)classname = os.path.split(root)[-1]if classname in dirnames:self.labels.append(dirnames.index(classname))else:print(f'{classname} not in {dirnames}')
​self.len = len(self.filepaths)
​def __len__(self):return self.len
​def __getitem__(self, index):idx = min(max(index, 0), self.len - 1)filepath = self.filepaths[idx]img = cv2.imread(filepath)print(img.shape)# 不做图片缩放,报:RuntimeError: stack expects each tensor to be equal size, but got [3, 1333, 2000] at entry 0 and [3, 335, 600] at entry 1img = cv2.resize(img, (112, 112))t_img = torch.tensor(img)t_img = torch.permute(t_img, (2, 0, 1))
​label = self.labels[idx]return t_img, label
​
​
def test02():path = os.path.join(os.path.dirname(__file__), 'dataset')dataset = MyDataset(path)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
​for img, label in dataloader:
​print(img.shape, label)for i in range(img.shape[0]):im = torch.permute(img[i], (1, 2, 0))plt.imshow(im)plt.show()
​break
​
​
if __name__ == "__main__":test02()

优化:使用ImageFolder加载图片集

ImageFolder 会根据文件夹的结构来加载图像数据。它假设每个子文件夹对应一个类别,文件夹名称即为类别名称。例如,一个典型的文件夹结构如下:

root/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

在这个结构中:

  • root 是根目录。

  • class1class2 等是类别名称。

  • 每个类别文件夹中的图像文件会被加载为一个样本。

ImageFolder构造函数如下:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, is_valid_file=None)

参数解释

  • root:字符串,指定图像数据集的根目录。

  • transform:可选参数,用于对图像进行预处理。通常是一个 torchvision.transforms 的组合。

  • target_transform:可选参数,用于对目标(标签)进行转换。

  • is_valid_file:可选参数,用于过滤无效文件。如果提供,只有返回 True 的文件才会被加载。

import torch
from torchvision import datasets, transforms
import os
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
​
torch.manual_seed(42)
​
def load():path = os.path.join(os.path.dirname(__file__), 'dataset')print(path)
​transform = transforms.Compose([transforms.Resize((112, 112)),transforms.ToTensor()])
​dataset = datasets.ImageFolder(path, transform=transform)dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
​for x,y in dataloader:x = x.squeeze(0).permute(1, 2, 0).numpy()plt.imshow(x)plt.show()print(y[0])break
​
​
if __name__ == '__main__':load()
​

3.3 加载官方数据集

在 PyTorch 中官方提供了一些经典的数据集,如 CIFAR-10、MNIST、ImageNet 等,可以直接使用这些数据集进行训练和测试。

数据集:Datasets — Torchvision 0.21 documentation

常见数据集:

  • MNIST: 手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像。

  • CIFAR10: 包含 10 个类别的 60,000 张 32x32 彩色图像,每个类别 6,000 张图像。

  • CIFAR100: 包含 100 个类别的 60,000 张 32x32 彩色图像,每个类别 600 张图像。

  • COCO: 通用对象识别数据集,包含超过 330,000 张图像,涵盖 80 个对象类别。

torchvision.transforms 和 torchvision.datasets 是 PyTorch 中处理计算机视觉任务的两个核心模块,它们为图像数据的预处理和标准数据集的加载提供了强大支持。

transforms 模块提供了一系列用于图像预处理的工具,可以将多个变换组合成处理流水线。

datasets 模块提供了多种常用计算机视觉数据集的接口,可以方便地下载和加载。

参考如下:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets
​
​
def test():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.MNIST(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=8, shuffle=True)for x, y in trainloader:print(x.shape)print(y)break
​# 测试数据集data_test = datasets.MNIST(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=8, shuffle=True)for x, y in testloader:print(x.shape)print(y)break
​
​
def test006():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.CIFAR10(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=4, shuffle=True, num_workers=2)for x, y in trainloader:print(x.shape)print(y)break# 测试数据集data_test = datasets.CIFAR10(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=4, shuffle=False, num_workers=2)for x, y in testloader:print(x.shape)print(y)break
​
​
if __name__ == "__main__":test()test006()
​

1. 神经网络基础

1.1 生物神经元与人工神经元

神经网络的设计灵感来源于生物神经元。生物神经元通过树突接收信号,细胞核处理信号,轴突传递信号,突触连接不同的神经元。人工神经元模仿了这一过程,接收多个输入信号,经过加权求和和非线性激活函数处理后,输出结果。

1.2 人工神经元的组成

人工神经元由以下几个部分组成:

  • 输入(Inputs)​:代表输入数据,通常用向量表示。
  • 权重(Weights)​:每个输入数据都有一个权重,表示该输入对最终结果的重要性。
  • 偏置(Bias)​:一个额外的可调参数,用于调整模型的输出。
  • 加权求和:将输入乘以对应的权重后求和,再加上偏置。
  • 激活函数(Activation Function)​:将加权求和后的结果转换为输出结果,引入非线性特性。

数学表示如下:

其中,σ(z) 是激活函数。


2. 神经网络结构

2.1 基本结构

神经网络由以下三层构成:

  • 输入层(Input Layer)​:接收外部数据,不进行计算。
  • 隐藏层(Hidden Layer)​:位于输入层和输出层之间,进行特征提取和转换。隐藏层可以有多层,每层包含多个神经元。
  • 输出层(Output Layer)​:产生最终的预测结果或分类结果。

2.2 全连接神经网络

全连接神经网络(Fully Connected Neural Network,FCNN)是前馈神经网络的一种,每一层的神经元与上一层的所有神经元全连接。全连接神经网络常用于图像分类、文本分类等任务。

2.2.1 特点
  • 权重数量大:由于全连接的特点,权重数量较大,计算量大。
  • 学习能力强:能够学习输入数据的全局特征,但对高维数据的局部特征捕捉能力较弱。
2.2.2 计算步骤
  1. 数据传递:输入数据逐层传递到输出层。
  2. 激活函数:每一层的输出通过激活函数处理。
  3. 损失计算:计算预测值与真实值之间的差距。
  4. 反向传播:通过反向传播算法更新权重以最小化损失。

3. 激活函数

激活函数在神经网络中引入非线性,使网络能够处理复杂的任务。以下是几种常见的激活函数及其特点。

3.1 Sigmoid

3.1.1 公式

3.1.2 特点
  • 将输入映射到 (0, 1) 之间,适合处理概率问题。
  • 梯度消失问题严重,容易导致训练速度变慢。
  • 计算成本较高。
3.1.3 应用场景
  • 一般用于二分类问题的输出层。

3.2 Tanh

3.2.1 公式

3.2.2 特点
  • 输出范围为 (-1, 1),是零中心的,有助于加速收敛。
  • 对称性较好,适合隐藏层。
  • 仍然存在梯度消失问题。
3.2.3 应用场景
  • 适用于隐藏层,但不如 ReLU 常用。

3.3 ReLU

3.3.1 公式

3.3.2 特点
  • 计算简单,适合大规模数据训练。
  • 缓解梯度消失问题,适合深层网络。
  • 存在神经元死亡问题,即某些神经元可能永远不被激活。
3.3.3 应用场景
  • 深度学习中最常用的激活函数,适用于隐藏层。

3.4 Leaky ReLU

3.4.1 公式

3.4.2 特点
  • 解决了 ReLU 的神经元死亡问题。
  • 计算简单,但需要调整超参数 α。
3.4.3 应用场景
  • 适用于隐藏层,尤其是 ReLU 效果不佳时。

3.5 Softmax

3.5.1 公式

3.5.2 特点
  • 将输出转化为概率分布,适合多分类问题。
  • 放大差异,使概率最大的类别更突出。
  • 存在数值不稳定性问题,需进行数值调整。
3.5.3 应用场景
  • 用于多分类问题的输出层。

4. 激活函数的选择

4.1 隐藏层

  1. 优先选择 ReLU。
  2. 如果 ReLU 效果不佳,尝试 Leaky ReLU 或其他激活函数。
  3. 避免使用 Sigmoid,可以尝试 Tanh。

4.2 输出层

  1. 二分类问题选择 Sigmoid。
  2. 多分类问题选择 Softmax。

5. 总结

神经网络是深度学习的核心,理解其结构和激活函数的作用至关重要。人工神经元是神经网络的基本单元,通过加权求和和激活函数实现非线性变换。全连接神经网络是最基本的神经网络结构,广泛应用于各类任务。激活函数在神经网络中引入非线性,增强了网络的表达能力。不同激活函数适用于不同的场景,合理选择激活函数可以显著提升模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++/Python实现RGB和HSI相互转换

1--C版本 #include <opencv2/opencv.hpp> #include <iostream> #include <cmath>// RGB to HSI cv::Vec3f RGBtoHSI(cv::Vec3b rgb) {float B rgb[0] / 255.0f;float G rgb[1] / 255.0f;float R rgb[2] / 255.0f;float num 0.5f * ((R - G) (R - B));f…

【Linux我做主】make和makefile自动化构建

make和makefile自动化构建 make和makefile自动化构建github地址前言背景介绍为什么需要make和makefile&#xff1f; make和makefile解析什么是make和makefile依赖关系和依赖方法核心语法结构简单演示编译清理 多阶段编译示例 make时执行的顺序场景1&#xff1a;clean目标在前(特…

Qt 入门 5 之其他窗口部件

Qt 入门 5 之其他窗口部件 本文介绍的窗口部件直接或间接继承自 QWidget 类详细介绍其他部件的功能与使用方法 1. QFrame 类 QFrame类是带有边框的部件的基类。它的子类包括最常用的标签部件QLabel另外还有 QLCDNumber、QSplitter,QStackedWidget,QToolBox 和 QAbstractScrol…

JAVA学习-多线程

线程 线程(Thread)是一个程序内部的一条执行流程。 程序中如果只有一条执行流程&#xff0c;那这个程序就是单线程的程序。 线程的常用方法及构造器&#xff1a; Thread提供的常用方法public void run() 线程的任务方法public void start() 启动线程public String getName() …

Github 2025-04-19Rust开源项目日报 Top10

根据Github Trendings的统计,今日(2025-04-19统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Rust项目10Python项目1Rust: 构建可靠高效软件的开源项目 创建周期:5064 天开发语言:Rust协议类型:OtherStar数量:92978 个Fork数量:12000…

OpenLayers:视图变换的方法

一、什么是视图变换&#xff1f; 视图变换就是指视图的 extent&#xff08;范围&#xff09;、center&#xff08;中心点&#xff09;、zoom&#xff08;缩放级别&#xff09;、 resolution&#xff08;分辨率&#xff09;、rotation&#xff08;旋转角&#xff09;等参数发生…

数字孪生火星探测车,星际探索可视化

运用图扑构建数字孪生火星探测车&#xff0c;高精度还原外观与内部构造。实时映射探测车在火星表面的移动、探测作业及设备状态&#xff0c;助力科研人员远程监测、分析数据&#xff0c;为火星探索任务提供可视化决策支持。

【NLP 66、实践 ⑰ 基于Agent + Prompt Engineering文章阅读】

你用什么擦干我的眼泪 莎士比亚全集 工业纸巾 还是你同样泛红的眼睛 —— 4.19 一、⭐【核心函数】定义大模型调用函数 call_large_model prompt&#xff1a;用户传入的提示词&#xff08;如 “请分析这篇作文的主题”&#xff09;&#xff0c;指导模型执行任务 client&…

黑马Java基础笔记-1

JVM&#xff0c;JDK和JRE JDK是java的开发环境 JVM虚拟机&#xff1a;Java程序运行的地方 核心类库&#xff1a;Java已经写好的东西&#xff0c;我们可以直接用。 System.out.print中的这些方法就是核心库中的所包含的 开发工具: javac&#xff08;编译工具&#xff09;、java&…

PR第一课

目录 1.新建 2.PR内部设置 3.导入素材 4.关于素材窗口 5.关于编辑窗口 6.序列的创建 7.视频、图片、音乐 7.1 带有透明通道的素材 8.导出作品 8.1 打开方法 8.2 导出时&#xff0c;需要修改的参数 1.新建 2.PR内部设置 随意点开 编辑->首选项 中的任意内容&a…

Xcode16 调整 Provisioning Profiles 目录导致证书查不到

cronet demo 使用的 ninja 打包&#xff0c;查找 Provisioning Profiles 路径是 ~/Library/MobileDevice/Provisioning Profiles&#xff0c;但 Xcode16 把该路径改为了 ~/Library/Developer/Xcode/UserData/Provisioning Profiles&#xff0c;导致在编译 cronet 的demo 时找不…

【更新完毕】2025华中杯C题数学建模网络挑战赛思路代码文章教学数学建模思路:就业状态分析与预测

完整内容请看文末最后的推广群 先展示文章和代码、再给出四个问题详细的模型 基于多模型下的就业状态研究 摘要 随着全球经济一体化和信息技术的迅猛发展&#xff0c;失业问题和就业市场的匹配性问题愈加突出。为了解决这一问题&#xff0c;本文提出了一种基于统计学习和机器学…

[HOT 100] 1964. 找出到每个位置为止最长的有效障碍赛跑路线

文章目录 1. 题目链接2. 题目描述3. 题目示例4. 解题思路5. 题解代码6. 复杂度分析 1. 题目链接 1964. 找出到每个位置为止最长的有效障碍赛跑路线 - 力扣&#xff08;LeetCode&#xff09; 2. 题目描述 你打算构建一些障碍赛跑路线。给你一个 下标从 0 开始 的整数数组 obst…

2025年KBS SCI1区TOP:增强天鹰算法EBAO,深度解析+性能实测

目录 1.摘要2.天鹰算法AO原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了增强二进制天鹰算法&#xff08;EBAO&#xff09;&#xff0c;针对无线传感器网络&#xff08;WSNs&#xff09;中的入侵检测系统&#xff08;IDSs&#xff09;。由于WSNs的特点是规模…

JavaScript数据类型简介

在JavaScript中&#xff0c;理解不同的数据类型是掌握这门语言的基础。数据类型决定了变量可以存储什么样的值以及这些值能够执行的操作。JavaScript支持多种数据类型&#xff0c;每种都有其特定的用途和特点。本文将详细介绍JavaScript中的主要数据类型&#xff0c;并提供一些…

性能比拼: Elixir vs Go(第二轮)

本内容是对知名性能评测博主 Anton Putra Elixir vs Go (Golang) Performance Benchmark (Round 2) 内容的翻译与整理, 有适当删减, 相关指标和结论以原作为准 这是第二轮关于 Elixir 和 Go 的对比测试。我收到了一份来自 Elixir 创作者的 Pull Request &#xff0c;并且我认为…

接口自动化 ——fixture allure

一.参数化实现数据驱动 上一篇介绍了参数化&#xff0c;这篇 说说用参数化实现数据驱动。在有很多测试用例的时候&#xff0c;可以将测试用例都存储在文件里&#xff0c;进行读写调用。本篇主要介绍 csv 文件和 json 文件。 1.读取 csv 文件数据 首先创建 csv 文件&#xff…

`peft`(Parameter-Efficient Fine-Tuning:高效微调)是什么

peft(Parameter-Efficient Fine-Tuning:高效微调)是什么 peft库是Hugging Face推出的用于高效参数微调的库,它能在不调整模型全部参数的情况下,以较少的可训练参数对预训练模型进行微调,从而显著降低计算资源需求和微调成本。以下从核心功能、优势、常见微调方法、使用场…

编程常见错误归类

上一篇讲了调试&#xff0c;今天通过一个举例回忆一下上一篇内容吧&#xff01; 1. 回顾&#xff1a;调试举例 在VS2022、X86、Debug的环境下&#xff0c;编译器不做任何优化的话&#xff0c;下⾯代码执⾏的结果是啥&#xff1f; #include <stdio.h> int main() {int …

在windows上交叉编译opencv供RK3588使用

环境 NDK r27、RK3588 安卓板子、Android 12 步骤操作要点1. NDK 下载选择 r27 版本&#xff0c;解压到无空格路径&#xff08;如 C:/ndk&#xff09;2. 环境变量配置添加 ANDROID_NDK_ROOT 和工具链路径到系统 PATH3. CMake 参数调整指定 ANDROID_NATIVE_API_LEVEL31、ANDRO…