深度学习 - 构建神经网络

1. 自动求导机制

概念解释

  • 自动求导:PyTorch的autograd模块允许我们自动计算张量的梯度,这在反向传播算法中尤为重要。反向传播是神经网络训练的核心,用于计算每个参数的梯度并更新参数。

生活中的例子

想象你是一个厨师,正在调整一个菜谱,使它更加美味。每次你改变一个配料的量,比如盐或糖,你都会尝试这个菜,然后根据味道的变化决定是否需要进一步调整。这就像在神经网络中计算梯度:你调整网络的参数(配料),观察输出(菜的味道),然后根据输出的变化来更新参数(调整配料)。

示例代码

import torchx = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()print("x:", x)
print("y:", y)
print("z:", z)
print("out:", out)

运行结果

x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([3., 4., 5.], grad_fn=<AddBackward0>)
z: tensor([27., 48., 75.], grad_fn=<MulBackward0>)
out: tensor(50., grad_fn=<MeanBackward0>)

解释

  • requires_grad=True:表示我们需要计算x的梯度。
  • y = x + 2:对每个元素加2。
  • z = y * y * 3:每个元素先平方,再乘以3。
  • out = z.mean():计算张量的平均值。

计算图与梯度计算

out.backward()
print(x.grad)

运行结果

tensor([ 4.,  8., 12.])

解释out相对于x的梯度是4*x + 4

梯度计算公式

  1. y = x + 2 y = x + 2 y=x+2
  2. z = 3 y 2 z = 3y^2 z=3y2
  3. o u t = 1 3 ∑ z out = \frac{1}{3} \sum z out=31z

反向传播:

  1. ∂ o u t ∂ z i = 1 3 \frac{\partial out}{\partial z_i} = \frac{1}{3} ziout=31
  2. ∂ z i ∂ y i = 6 y i \frac{\partial z_i}{\partial y_i} = 6y_i yizi=6yi
  3. ∂ y i ∂ x i = 1 \frac{\partial y_i}{\partial x_i} = 1 xiyi=1

所以:

∂ o u t ∂ x i = 1 3 × 6 y i × 1 = 2 y i \frac{\partial out}{\partial x_i} = \frac{1}{3} \times 6y_i \times 1 = 2y_i xiout=31×6yi×1=2yi

y i = x i + 2 y_i = x_i + 2 yi=xi+2,所以:

∂ o u t ∂ x i = 2 ( x i + 2 ) \frac{\partial out}{\partial x_i} = 2(x_i + 2) xiout=2(xi+2)

x = [ 1 , 2 , 3 ] x = [1, 2, 3] x=[1,2,3] 时:

∂ o u t ∂ x = [ 4 , 8 , 12 ] \frac{\partial out}{\partial x} = [4, 8, 12] xout=[4,8,12]

torch.autograd.Variable

现在torch.Tensor已经取代了Variable,并且默认情况下所有张量都支持自动求导,所以Variable不再需要单独使用。

2. 构建简单神经网络

概念解释

  • 神经网络:神经网络是一种模仿人脑工作方式的计算模型。它由许多相互连接的“神经元”组成,每个神经元接收输入信号并产生输出信号。
  • nn.Module:是所有神经网络模块的基类。自定义的神经网络类需要继承nn.Module并实现其方法。

生活中的例子

想象你正在教一个机器人识别不同类型的水果。你给机器人看各种水果的图片,并告诉它们每个水果的名称。机器人通过观察这些图片并学习它们的特征(比如颜色、形状),逐渐学会区分不同的水果。这就像神经网络通过训练数据学习模式。

示例代码

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return xnet = Net()
print(net)

运行结果

Net((fc1): Linear(in_features=784, out_features=256, bias=True)(fc2): Linear(in_features=256, out_features=10, bias=True)
)

解释

  • __init__方法中定义了两个全连接层(fc1fc2)。
  • forward方法定义了前向传播的过程,首先通过第一层,然后应用ReLU激活函数,最后通过第二层。

前向传播

input = torch.randn(1, 784)
output = net(input)
print(output)

运行结果

tensor([[ 0.0520,  0.2651,  0.0512, -0.1564, -0.2470, -0.2246,  0.0936, -0.2600,0.1607,  0.1467]], grad_fn=<AddmmBackward>)

解释:生成一个随机输入张量input,通过网络得到输出output

损失函数和优化器

概念解释

  • 损失函数:用来衡量模型输出与实际目标之间的差异。
  • 优化器:通过反向传播计算梯度并更新模型参数,以最小化损失函数。

生活中的例子

想象你在考试中答错了一些题目,老师告诉你哪些题目答错了,并给你一些建议。你根据这些建议修改你的学习方法,下次考试争取做得更好。损失函数就像老师的反馈,优化器就像你调整学习方法的过程。

示例代码

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)# 示例训练步骤
optimizer.zero_grad()   # 清除梯度
output = net(input)     # 前向传播
loss = criterion(output, torch.tensor([1]))  # 计算损失
loss.backward()         # 反向传播
optimizer.step()        # 更新权重

解释

  • criterion:定义损失函数,这里使用交叉熵损失函数。
  • optimizer:定义优化器,这里使用随机梯度下降(SGD)。
  • optimizer.zero_grad():清除梯度。
  • output = net(input):前向传播。
  • loss = criterion(output, torch.tensor([1])):计算损失。
  • loss.backward():反向传播,计算梯度。
  • optimizer.step():更新权重。

3. 训练流程

概念解释

  • 数据加载与处理:使用torch.utils.data模块加载和处理数据。

生活中的例子

想象你在准备一个大餐,需要从市场购买食材。你需要将所有食材分成不同的类别,并按照菜谱中的要求进行处理和烹饪。数据加载和处理就像你从市场获取食材,并准备它们以便进一步使用。

示例代码

from torch.utils.data import DataLoader, TensorDataset# 示例数据
inputs = torch.randn(100, 784)
targets = torch.randint(0, 10, (100,))dataset = TensorDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

解释

  • inputstargets:生成示例数据。
  • TensorDataset:将输入和目标数据打包成数据集。
  • DataLoader:加载数据集,支持批处理和打乱数据。

定义模型、损失函数和优化器

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

解释:定义模型、损失函数和优化器。

训练循环

for epoch in range(2):  # 训练2个epochfor inputs, targets in dataloader:optimizer.zero_grad()   # 清除梯度outputs = net(inputs)   # 前向传播loss = criterion(outputs, targets)  # 计算损失loss.backward()         # 反向传播optimizer.step()        # 更新权重print(f"Epoch {epoch+1}, Loss: {loss.item()}")

**

运行结果**:

Epoch 1, Loss: 2.303864002227783
Epoch 2, Loss: 2.3021395206451416

解释:训练2个epoch,每个epoch中对每个批次数据进行前向传播、计算损失、反向传播和更新权重。

模型评估与验证

概念解释

  • 评估模式:在评估模式下,不计算梯度,节省内存和计算资源。

生活中的例子

想象你准备了一个大餐,现在邀请朋友来品尝。他们给你反馈,你记录这些反馈以便改进菜谱。这就像模型评估,你不再调整参数,而是观察模型在新数据上的表现。

示例代码

net.eval()  # 进入评估模式
with torch.no_grad():inputs = torch.randn(10, 784)outputs = net(inputs)predicted = torch.argmax(outputs, dim=1)print(predicted)

运行结果

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

解释

  • net.eval():将模型设置为评估模式。
  • torch.no_grad():禁用梯度计算,节省内存和计算资源。
  • predicted:预测的类别。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/23085.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java时间类(十六) -- 将一天的时间进行等步长分割

废话不多说,直接上工具类: import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.List;/*** @ClassName TimeSplitterUtil* @Description …

C语言指针与数组名的联系

目录 一、数组名的理解 a.数组名代表数组首元素的地址 b. 两个例外 二、使用指针来访问数组 三、一维数组传参的本质 一、数组名的理解 a.数组名代表数组首元素的地址 我们在使用指针访问数组的内容时&#xff0c;有这样的代码&#xff1a; int arr[10] {1,2,3,4,5,6,7,…

枚举(enum)+联合体(union)

枚举联合 一.枚举类型1.枚举类型的声明2.枚举类型的优点3.枚举类型的使用 二.联合体1.联合体类型的声明2.联合体的特点3.相同成员的结构体和联合体对比4.联合体大小的计算5.联合体的练习&#xff08;判断大小端&#xff09;6.联合体节省空间例题 一.枚举类型 1.枚举类型的声明…

Sentinel1.8.6更改配置同步到nacos(项目是Gateway)

本次修改的源码在&#xff1a;https://gitee.com/stonic-open-source/sentinel-parent 一 下载源码 地址&#xff1a;https://github.com/alibaba/Sentinel/releases/tag/1.8.6 二 导入idea&#xff0c;等待maven下载好各种依赖 三 打开sentile-dashboard这个模块&#xf…

介绍下CIDR(Classless Inter-Domain Routing)无类别域间路由

最近在搞DELL EMC XtremIO的重新初始化&#xff0c;在Stortage controller和XMS的xinstall配置的时候&#xff0c;需要配置用到CIDR&#xff0c;就是classless inter-domian routing&#xff0c;总结了一下&#xff0c;其实很多对网络设备的地方都用得到&#xff0c;以前还不知…

华为手机录屏在哪里?图文详解帮你找!

随着科技的进步&#xff0c;智能手机已成为我们日常生活中不可或缺的工具。其中&#xff0c;华为手机凭借其卓越的性能和用户体验&#xff0c;在全球范围内赢得了广泛的赞誉。在众多功能中&#xff0c;录屏功能尤为实用&#xff0c;无论是制作教程、记录游戏精彩瞬间&#xff0…

压敏电阻器是在规定温度下,当电压超过某一临界值时电导随电压的升高而急速增大的一种电阻器

压敏电阻器是在规定温度下,当电压超过某一临界值时电导随电压的升高而急速增大的一种电阻器。压敏电阻器的伏安特性是非线性的,因此,压敏电阻器亦称为非线性电阻器,非线性来自于压敏电阻器两端的外加电压,其伏安特性如图 9-1所示。从图9-1可以看出,压敏电阻器有对称型和非对称型…

网络运维简介

目录 1.网络运维的定义 2.诞生背景 3.网络运维的重要性 4.优点 5.缺点 6.应用场景 6.1.十个应用场景 6.2.数据中心运维 7.应用实例 8.小结 1.网络运维的定义 网络运维&#xff08;Network Operations&#xff09;是指管理、监控和维护计算机网络以确保其高效、安全和…

2024最新华为OD算法题目

在一个机房中,服务器的位置标识在 n*m 的整数矩阵网格中,1表示单元格上有服务器,0 表示没有。如果两台服务器位于同一行或者同一列中紧邻的位置,则认为它们之间可以组成一个局域网。请你统计机房中最大的局域网包含的服务器个数。 输入描述 第一行输入两个正整数,n和m,…

Python私教张大鹏 Vue3整合AntDesignVue之文本组件

案例&#xff1a;展示标题 核心代码&#xff1a; <a-typography><a-typography-title>Introduction</a-typography-title> </a-typography>vue3示例&#xff1a; <template><a-typography><a-typography-title>这是一个标题</…

HTTP请求过程

HTTP&#xff08;超文本传输协议&#xff09;请求过程是客户端&#xff08;通常是浏览器&#xff09;与服务器之间通信的方式&#xff0c;用于从服务器请求资源&#xff08;如网页、图片、视频等&#xff09;。以下是HTTP请求的基本步骤&#xff1a; 建立TCP连接&#xff1a; 如…

【K8s】专题四(6):Kubernetes 控制器之 Job

以下内容均来自个人笔记并重新梳理&#xff0c;如有错误欢迎指正&#xff01;如果对您有帮助&#xff0c;烦请点赞、关注、转发&#xff01;欢迎扫码关注个人公众号&#xff01; 目录 一、基本介绍 二、工作原理 三、相关特性 四、资源清单&#xff08;示例&#xff09; 五…

C语言经典习题20

一编写一个函数用于计算高于平均分的人数 编写一个函数int fun(float s[],int n)&#xff0c;用于计算高于平均分的人数&#xff0c;并作为函数值返回&#xff0c;其中数组s中存放n位学生的成绩。再编写一个主函数&#xff0c;从键盘输入一批分数&#xff08;用-1来结束输入&a…

电路分析答疑 1

三要素法求解的时候&#xff0c; 电容先求U&#xff0c;再利用求导求I 电感先求I&#xff0c;再利用求导求U 若I的头上没有点点&#xff0c;那就是求有效值 叠加定理&#xff0c;不要忘记 若电流值或者电压值已经给出来了&#xff0c;那就说明这一定是直流电。 在画画圈的时候…

数据库(25)——多表关系介绍

在项目开发中&#xff0c;进行数据库表结构设计时&#xff0c;会根据业务需求及业务模块之间的关系&#xff0c;分析并设计表结构&#xff0c;各个表之间的结构基本上分为三种&#xff1a;一对多&#xff0c;多对多&#xff0c;一对一。 一对多 例如&#xff0c;一个学校可以有…

Mac修改Mysql8.0密码

转载请标明出处&#xff1a;http://blog.csdn.net/donkor_/article/details/139392605 文章目录 前言修改密码Step1:修改my.conf文件Step2:添加配置skip-grant-tablesStep3:重启mysql服务Step4:进入mysqlStep5:刷新权限Step6:修改密码Step7:再次刷新权限Step8:删除/注释 skip-…

DNS域名

DNS域名 DNS是域名系统的简称 域名和ip地址之间的映射关系 互联网中&#xff0c;ip地址是通信的唯一标识 访问网站&#xff0c;域名&#xff0c;ip地址不好记&#xff0c;域名朗朗上口&#xff0c;好记。 域名解析的目的就是为了实现&#xff0c;访问域名就等于访问ip地址…

【Python】 获取当前日期的Python代码解析与应用

标题&#xff1a;Python中获取当前日期的简单指南 基本原理 在Python中&#xff0c;获取当前日期是一个常见的需求&#xff0c;尤其是在处理日志、数据记录和时间相关的任务时。Python提供了多种方式来获取和处理日期和时间&#xff0c;其中最常用的模块是datetime。datetime…

多客陪玩系统-开源陪玩系统平台源码-支持游戏线上陪玩家政线下预约等多场景应用支持H5+小程序+APP

多客陪玩系统-开源陪玩系统平台源码-支持游戏线上陪玩家政按摩线下预约等多场景应用支持H5小程序APP 软件架构 前端&#xff1a;Uniapp-vue2.0 后端&#xff1a;Thinkphp6 前后端分离 前端支持&#xff1a; H5小程序双端APP&#xff08;安卓苹果&#xff09; 安装教程 【商业…

学习VUE3——组件(一)

组件注册 分为全局注册和局部注册两种。 全局注册&#xff1a; 在main.js或main.ts中&#xff0c;使用 Vue 应用实例的 .component() 方法&#xff0c;让组件在当前 Vue 应用中全局可用。 import { createApp } from vue import MyComponent from ./App.vueconst app crea…