李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块

import random
import torch
import matplotlib_inline
from matplotlib import pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets
import sys
from collections import OrderedDict# ---------------------------------------------------------------------------------------------
# 图表展示
def use_svg_display():# 用矢量图表示matplotlib_inline.backend_inline.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsize# ---------------------------------------------------------------------------------------------
# 读取数据
# 获取总的样本数量,然后打乱顺序,用batch-size获取每一部分索引去索引对应样本中的数据,使用yield返回
'''
函数详解:
torch.linspace(start, end, steps, dtype) → Tensor  从start开始到end结束,生成steps个数据点,数据类型为dtype
torch.index_select(input, dim, index)   索引张量中的子集
**input:需要进行索引操作的输入张量dim:张量维度  0,1index:索引号,是张量类型
**
yield: 使用yield的函数返回迭代器对象,每次使用时会保存变量信息,使用next()或者使用for可以循环访问迭代器中的内容
'''def data_iter(batch_size, features, labels):num_examples = len(features)  # features   nxmindices = list(range(num_examples))  # 借助range生成索引序列random.shuffle(indices)  # 把list列表中的值打乱顺序for i in range(0, num_examples, batch_size):j = torch.LongTensor(indices[i:min(i + batch_size, num_examples)])  # 这里的i是对标乱序表中的下标索引号yield features.index_select(0, j), labels.index_select(0, j)  # 0维度,有1000个样本,j就是他们的下标# ---------------------------------------------------------------------------------------------# 定义模型
def linreg(X, w, b):return torch.mm(X, w) + b  # 传进来的参数和样本特征都符合矩阵形式 w,b都是列矩阵  X:1000x2  w:2x1  b:1x1# 这里使用了广播# ---------------------------------------------------------------------------------------------# 定义损失函数
def square_loss(y_hat, y):# 保证y_hat和y同型,pytorch中的MSELoss没有除以2的操作return (y_hat - y.view(y_hat.size())) ** 2 / 2# 这里的得到的也是一个小批量的样本的损失张量# ---------------------------------------------------------------------------------------------
# 定义优化算法
# 这里使用的是sgd算法,使用小批量梯度和(参数求导后的和:梯度会自动累加,不用自己加和梯度)除以小批量样本个数来求小批量平均值
def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size  # 这里更改param时使用的是param.data,这样就不会影响反向梯度# 这里的param指的是w1,w2,b# 这里应该是小批量中的每个loss运行完,得到小批量每个样本的梯度然后pytorch自动进行了梯度累加,之后一个小批量得到一个累加和后的
# 梯度w1,w2,b
# ---------------------------------------------------------------------------------------------'''
FashionMNIST 数据集
'''# ----------------------------------------------------------将数值标签转换成文本标签
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal','shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# -----------------------------------------------------在一行里画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):use_svg_display()# 这里的_表示忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))  # 设置一行 len(images)个数量,每个figsize大小的画布# figs 返回的是一个画布对象,这个对象有imshow,set_tittle,axes_get_xasis().set_visible,# axes.get_yaxis().set_visible()这几种函数调用方式,用来给figs里面添加图像for f, img, lbl, in zip(figs, images, labels):  # 这个画布对象循环往里面添加图像信息f.imshow(img.view((28, 28)).numpy())  # img承接图像信息,将tensor转化为numpy  这里参数为数组元素f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.savefig("路径")# ----------------------------------------------------------------获取并读取FashionMNIST数据集函数,返回小批量train,test
def load_data_fashion_mnist(batch_size):mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())'''上面的mnist_train,mnist_test都是torch.utils.data.Dataset的子类,所以可以使用len()获取数据集的大小训练集和测试集中的每个类别的图像数分别是6000,1000,两个数据集分别有10个类别'''# mnist是torch.utils.data.dataset的子类,因此可以将其传入torch.utils.data.DataLoader来创建一个DataLoader实例来读取数据# 在实践中,数据读取一般是训练的性能瓶颈,特别是模型较简单或者计算硬件性能比较高的时候# DataLoader一个很有用的功能就是允许多进程来加速读取  使用num_works来设置4个进程读取数据if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=num_workers)return train_iter, test_iter# -------------------------------------------------------------查看mnist前10个图像和标签
def check_mnist():mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())X, y = [], []for i in range(10):X.append(mnist_train[i][0])  # 循环获取图像张量矩阵y.append(mnist_train[i][1])  # 循环获取图像对应数值标签show_fashion_mnist(X, get_fashion_mnist_labels(y))# feature, label = mnist_train[0]# print(feature.shape, label)  CxHxW# feature对应高和宽均为28像素的图像,因为使用了transforms.ToTensor(),所以每个像素的数值对应于【0.0,1.0】的32位浮点数# C 是通道数,RGB,灰色图像,通道数为1,H,W分别为高,宽# mnist_train[0] 是一个元祖,它包含两部分,图像数据结构和图像标签值,图像的数据结构是1x28x28结构,是一个浮点数矩阵,代表一个图像# -------------------------------------------------------------------------评价模型net在数据集data_iter上的准确率
def evaluate_accuracy(test_iter, net):acc_sum, n, x = 0.0, 0, 0.0for X, y in test_iter:  # 返回一个批量的数据元组迭代对象acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()  # 将net模型的预测y与标签y进行了准确率比较n += y.shape[0]  # 累加获得样本个数x = acc_sum / nreturn x# -------------------------------------------------------------------------训练模型函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):for epochs in range(num_epochs):  # 循环周期train_l_sum, train_acc_sum, n = 0.0, 0.0, 0  # 预先定义 训练损失,训练精度,批量个数for X, y in train_iter:  # 批量更新y_hat = net(X)l = loss(y_hat, y).sum()  # 损失计算# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:  # 权重存在并且权重的梯度存在for param in params:param.grad.data.zero_()l.backward()  # 反向传播# 梯度更新操作if optimizer is None:sgd(params, lr, batch_size)  # 调用sgd进行梯度下降操作else:optimizer.step()  # softmax回归的简洁实现将要用到train_l_sum += l.item()  # 损失累加train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()  # (y_hat.argmax(dim=1) == y)# 取出y_hat每一行中最大的概率索引和y比较,结果为tensor,元素值为0/1n += y.shape[0]  # 计算一个批量中标签的个数test_acc = evaluate_accuracy(test_iter, net)  # 一个循环之后进行测试集的准确度计算print(f'epoch %d,loss %.4f,train_acc %.3f,test_acc %.3f'% (epochs + 1, train_l_sum / n, train_acc_sum / n, test_acc))# x = torch.tensor([[0.1, 0.4, 0.2], [1, 0.06, 0.5]])
# print((x.argmax(dim=1)==torch.tensor([[1,1]])).float())# -------------------------------------------------------------------------x的形状转换功能函数
class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()  # 初始化函数,自动调用forward函数def forward(self, x):  # x shape: (batch,*,*,....)return x.view(x.shape[0], -1)  # 转换成(batch_size,特征数)形状# 这样就方便定义模型
net = torch.nn.Sequential(# FlattenLayer()# torch.nn.Linear(num_inputs,num_outputs)OrderedDict([('flatten', FlattenLayer()),('linear', torch.nn.Linear(2, 3))])
)'''
-------------------------------------------------------------------作图函数
'''def semilogy(x_vals, y_vals, xlabel, ylabel, label, x2_vals=None, y2_vals=None, legend=None):plt.xlabel(xlabel)plt.ylabel(ylabel)plt.semilogy(x_vals, y_vals)  # y轴使用对数尺度if x2_vals and y2_vals:plt.semilogy(x2_vals, y2_vals, linestyle=':')plt.legend(legend)plt.savefig("路径/多项式" + label + "模拟.png")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/637559.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.11马原

同一性是事物存在和发展的前提,一方的发展以另一方的发展为条件 同一性使矛盾双方相互吸收有利于自身的因素,在相互作用中各自得到发展 是事物发展根本规律,唯物辩证法的实质和核心 揭示了事物普遍联系的根本内容和变化发展的内在动力 是贯…

Visual Studio 设置编辑框(即代码编辑器)的背景颜色

在Visual Studio 中设置编辑框(即代码编辑器)的背景颜色,可以按照以下步骤进行: 打开Visual Studio。在菜单栏上找到并点击“工具”(Tools)选项。在下拉菜单中选择“选项”(Options)。在“选项”对话框中,导航至“环境…

设计模式-资源库模式

设计模式专栏 模式介绍模式特点应用场景资源库模式与关系型数据库的区别代码示例Java实现资源库模式Python实现资源库模式 资源库模式在spring中的应用 模式介绍 资源库模式是一种架构模式,介于领域层与数据映射层(数据访问层)之间。它的存在…

Django REST Framework入门之序列化器

文章目录 一、概述二、安装三、序列化与反序列化介绍四、之前常用三种序列化方式jsonDjango内置Serializers模块Django内置JsonResponse模块 五、DRF序列化器序列化器工作流程序列化(读数据)反序列化(写数据) 序列化器常用方法与属…

Flink(十四)【Flink SQL(中)查询】

前言 接着上次写剩下的查询继续学习。 Flink SQL 查询 环境准备: # 1. 先启动 hadoop myhadoop start # 2. 不需要启动 flink 只启动yarn-session即可 /opt/module/flink-1.17.0/bin/yarn-session.sh -d # 3. 启动 flink sql 的环境 sql-client ./sql-client.sh …

力扣每日一题---1547. 切棍子的最小成本

//当我们将棍子分段之后,我们是不是想到了怎么组合这些棍子 //并且这些棍子有一个性质就是只能与相邻的进行组合 //暴力搜索的话复杂度很高 //在思考暴力搜索的时候,我们发现一个规律 //比如棍子长度1 2 1 1 2 //那么与最后一个2组合的棍子有&#xff0c…

【大数据分析与挖掘技术】Mahout推荐算法

目录 一、推荐的定义与评估 (一)推荐的定义 (二)推荐的评估 二、Mahout中的常见推荐算法 (一)基于用户的推荐算法 (二)基于物品的推荐算法 (三)基于S…

SQL注入实战:http报文包讲解、http头注入

一:http报文包讲解 HTTP(超文本传输协议)是今天所有web应用程序使用的通信协议。最初HTTP只是一个为获取基于文本的静态资源而开发的简单协议,后来人们以各种形式扩展和利用它.使其能够支持如今常见的复杂分布式应用程序。HTTP使用一种用于消息的模型:客…

NLP论文阅读记录 - 2021 | WOS 使用预训练的序列到序列模型进行土耳其语抽象文本摘要

文章目录 前言0、论文摘要一、Introduction1.1目标问题1.2相关的尝试1.3本文贡献 二.相关工作2.1 预训练的序列到序列模型2.2 抽象文本摘要 三.本文方法3.1 总结为两阶段学习3.1.1 基础系统 3.2 重构文本摘要 四 实验效果4.1数据集4.2 对比模型4.3实施细节4.4评估指标4.5 实验结…

maven 基本知识/1.17

maven ●maven是一个基于项目对象模型(pom)的项目管理工具,帮助管理人员自动化构建、测试和部署项目 ●pom是一个xml文件,包含项目的元数据,如项目的坐标(GroupId,artifactId,version )、项目的依赖关系、构建过程 ●生命周期&…

[VulnHub靶机渗透]:billu_b0x 快速通关

🍬 博主介绍👨‍🎓 博主介绍:大家好,我是 hacker-routing ,很高兴认识大家~ ✨主攻领域:【渗透领域】【应急响应】 【python】 【VulnHub靶场复现】【面试分析】 🎉点赞➕评论➕收藏 == 养成习惯(一键三连)😋 🎉欢迎关注💗一起学习👍一起讨论⭐️一起进步…

在可执行文件中追加资源文件(C语言)

咦,2018年写的竟然放在草稿夹里了。。。 本来是想研究下怎么把已经定义好的数据库追加到可执行文件中的,但是转念又想总归是要重新编译,不如直接把预定义的数据参数直接写到代码里更简单一些,研究的过程中顺便总结了下这篇文章。 …

数据库性能优化的解决方案

目录​​​​​​​ 1、什么是数据库性能优化 1.1 数据库性能优化的概念 1.2 为何需要进行数据库性能优化 1.3 数据库性能优化的好处 2、数据库性能优化的基本原理 2.1 数据库查询优化 2.2 数据库索引优化 2.3 数据库表结构优化 2.4 数据库硬件优化 3、数据库查询优化…

OpenHarmony AI框架开发指导

一、概述 1、 功能简介 AI业务子系统是OpenHarmony提供原生的分布式AI能力的子系统。AI业务子系统提供了统一的AI引擎框架,实现算法能力快速插件化集成。 AI引擎框架主要包含插件管理、模块管理和通信管理模块,完成对AI算法能力的生命周期管理和按需部…

Tensorflow2 GPU版本-极简安装方式

Tensorflow2 GPU版本-极简安装方式: 1、配置conda环境加速 https://wtl4it.blog.csdn.net/article/details/135723095https://wtl4it.blog.csdn.net/article/details/135723095 2、tensorflow-gpu安装 conda create -n STZZWANG_TF2 tensorflow-gpu2.0

[AutoSar]BSW_OS 02 Autosar OS_STACK

目录 关键词平台说明一、 task stack1.1 Task stack 的共享1.2 task stack 的实际使用大小 二、ISR stack2.1 ISR stack 的共享 三、Single-stack(单一栈)和multi-stack (多栈)策略3.1 Single-stack3.2 multi-stack 四、Stack Che…

Datawhale 强化学习笔记(三)基于策略梯度(policy-based)的算法

文章目录 参考基于价值函数的缺点策略梯度算法REINFORCE 算法策略梯度推导进阶策略函数的设计离散动作的策略函数连续动作的策略函数 参考 第九章 策略梯度 之前介绍的 DQN 算法属于基于价值(value-based)的算法,基于策略梯度的算法直接对策略本身进行优化。 将策…

HackTheBox - Medium - Linux - BackendTwo

BackendTwo BackendTwo在脆弱的web api上通过任意文件读取、热重载的uvicorn从而访问目标,之后再通过猜单词小游戏获得root 外部信息收集 端口扫描 循例nmap Web枚举 feroxbuster扫目录 /api/v1列举了两个节点 /api/v1/user/1 扫user可以继续发现login和singup 注…

Java设计模式-抽象工厂模式(5)

大家好,我是馆长!从今天开始馆长开始对java设计模式的创建型模式中的单例模式、原型模式、工厂方法、抽象工厂、建造者的抽象工厂模式进行讲解和说明。 抽象工厂模式(Abstract Factory Pattern) 定义 是一种为访问类提供一个创建一组相关或相互依赖对象的接口,且访问类…

Webpack5入门到原理18:Plugin 原理

Plugin 的作用 通过插件我们可以扩展 webpack,加入自定义的构建行为,使 webpack 可以执行更广泛的任务,拥有更强的构建能力。 Plugin 工作原理 webpack 就像一条生产线,要经过一系列处理流程后才能将源文件转换成输出结果。 这条…