RNN的理解

对于RNN的理解

import torch
import torch.nn as nn
import torch.nn.functional as F# 手动实现一个简单的RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()# 定义权重矩阵和偏置项self.hidden_size = hidden_sizeself.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))  # 输入到隐藏层的权重#

#注:
input_size = 4
hidden_size = 3
W_xh = torch.randn(input_size, hidden_size)
生成的 W_xh 会是一个形状为 (4, 3) 的张量,可能是这样的(数字是随机生成的):
tensor([[ 0.2973, -1.1254, 0.7172],
[ 0.0983, 0.2856, -0.4586],
[-0.0105, 0.2317, 0.2716],
[ 1.0431, -1.3894, -0.1525]])
这个张量有 4 行 3 列。

        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))  # 隐藏层到隐藏层的权重self.b_h = nn.Parameter(torch.zeros(hidden_size))  # 隐藏层偏置self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))  # 隐藏层到输出层的权重self.b_y = nn.Parameter(torch.zeros(output_size))  # 输出层偏置def forward(self, x):# 初始化隐藏状态为0h_t = torch.zeros(x.size(0), self.hidden_size)  # 初始隐藏状态 [[5]]

注:
x 是输入数据,形状是 (3, 5, 4),其中:
3 是批量大小(batch_size),即我们一次性输入网络的样本数是 3。
5 是序列长度(seq_len),每个样本有 5 个时间步。
4 是每个时间步的输入特征数量。
self.hidden_size 假设是 6,表示隐藏层的维度是 6。
x.size(0) 获取输入张量 x 的第一个维度的大小,也就是批量大小 3。
torch.zeros(3, 6) 会创建一个形状为 (3, 6) 的张量,表示有 3 个样本,每个样本有 6 个隐藏状态神经元(即隐状态的维度是 6)。所有的元素都初始化为 0。

    # 遍历时间步,逐个处理输入序列for t in range(x.size(1)):  # x.size(1) 是序列长度x_t = x[:, t, :]  # 获取当前时间步的输入 (batch_size, input_size)

`
注:x = torch.tensor([[[0.1, 0.2, 0.3, 0.4], # 第 0 时间步的输入 (第一个样本)
[0.5, 0.6, 0.7, 0.8], # 第 1 时间步的输入 (第一个样本)
[0.9, 1.0, 1.1, 1.2]], # 第 2 时间步的输入 (第一个样本)
[[1.3, 1.4, 1.5, 1.6], # 第 0 时间步的输入 (第二个样本)
[1.7, 1.8, 1.9, 2.0], # 第 1 时间步的输入 (第二个样本)
[2.1, 2.2, 2.3, 2.4]]]) # 第 2 时间步的输入 (第二个样本)

第一次循环 t=0:
x_t = x[:, 0, :]
x[:, 0, :] 会提取出所有样本在第 0 时间步的输入:

第一个样本在第 0 时间步的输入是 [0.1, 0.2, 0.3, 0.4]。

第二个样本在第 0 时间步的输入是 [1.3, 1.4, 1.5, 1.6]。

因此,x_t 的值是:

tensor([[0.1, 0.2, 0.3, 0.4],
[1.3, 1.4, 1.5, 1.6]])

        # 更新隐藏状态:h_t = tanh(W_xh * x_t + W_hh * h_t + b_h)h_t = torch.tanh(x_t @ self.W_xh + h_t @ self.W_hh + self.b_h)  # [[4]]

`
注:1. 计算 x_t @ W_xh
x_t @ W_xh 是输入 x_t 和权重矩阵 W_xh 的矩阵乘法。我们有 2 个样本,每个样本有 3 个输入特征,权重矩阵 W_xh 的形状是 (3, 4),所以乘法的结果是一个形状为 (2, 4) 的张量,即每个样本的隐藏状态更新的部分。

对于第一个样本:

[0.5, 0.6, 0.7] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
我们可以计算它的结果:

= [0.5 * 0.1 + 0.6 * 0.3 + 0.7 * 0.7,
0.5 * 0.2 + 0.6 * 0.5 + 0.7 * (-0.1),
0.5 * -0.1 + 0.6 * 0.2 + 0.7 * 0.3,
0.5 * 0.4 + 0.6 * (-0.2) + 0.7 * 0.5]

= [0.05 + 0.18 + 0.49,
0.1 + 0.3 - 0.07,
-0.05 + 0.12 + 0.21,
0.2 - 0.12 + 0.35]

= [0.72, 0.33, 0.28, 0.43]
对于第二个样本:

[1.0, 1.2, 1.3] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
计算结果:

= [1.0 * 0.1 + 1.2 * 0.3 + 1.3 * 0.7,
1.0 * 0.2 + 1.2 * 0.5 + 1.3 * (-0.1),
1.0 * -0.1 + 1.2 * 0.2 + 1.3 * 0.3,
1.0 * 0.4 + 1.2 * (-0.2) + 1.3 * 0.5]

= [0.1 + 0.36 + 0.91,
0.2 + 0.6 - 0.13,
-0.1 + 0.24 + 0.39,
0.4 - 0.24 + 0.65]

= [1.37, 0.67, 0.53, 0.81]
因此,x_t @ W_xh 的结果是:

tensor([[0.72, 0.33, 0.28, 0.43],
[1.37, 0.67, 0.53, 0.81]])

x_t @ self.W_xh:
x_t 是当前时间步的输入,形状是 (batch_size, input_size)。
self.W_xh 是输入到隐藏层的权重矩阵,形状是 (input_size, hidden_size)。
h_t @ self.W_hh:
h_t 是前一时间步的隐藏状态,形状是 (batch_size, hidden_size)。
self.W_hh 是隐藏层到隐藏层的权重矩阵,形状是 (hidden_size, hidden_size)。

    # 最后一个时间步的隐藏状态通过全连接层得到输出y_t = h_t @ self.W_hy + self.b_y  # 输出层return y_t

超参数设置

input_size = 10 # 输入特征维度
hidden_size = 20 # 隐藏层维度
output_size = 5 # 输出类别数
seq_length = 5 # 序列长度
batch_size = 3 # 批量大小

实例化模型

model = RNN(input_size, hidden_size, output_size)

打印模型结构

print(model)

创建随机输入数据 (batch_size, seq_length, input_size)

x = torch.randn(batch_size, seq_length, input_size)

前向传播

output = model(x)
print(“Output shape:”, output.shape) # 输出形状应为 (batch_size, output_size)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二叉查找树和B树

二叉查找树(Binary Search Tree, BST)和 B 树(B-tree)都是用于组织和管理数据的数据结构,但它们在结构、应用场景和性能方面有显著区别。 二叉查找树(Binary Search Tree, BST) 特点&#xff1…

一段式端到端自动驾驶:VAD:Vectorized Scene Representation for Efficient Autonomous Driving

论文地址:https://github.com/hustvl/VAD 代码地址:https://arxiv.org/pdf/2303.12077 1. 摘要 自动驾驶需要对周围环境进行全面理解,以实现可靠的轨迹规划。以往的方法依赖于密集的栅格化场景表示(如:占据图、语义…

OpenCV训练题

一、创建一个 PyQt 应用程序,该应用程序能够: 使用 OpenCV 加载一张图像。在 PyQt 的窗口中显示这张图像。提供四个按钮(QPushButton): 一个用于将图像转换为灰度图一个用于将图像恢复为原始彩色图一个用于将图像进行…

opencv函数展示4

一、形态学操作函数 1.基本形态学操作 (1)cv2.getStructuringElement() (2)cv2.erode() (3)cv2.dilate() 2.高级形态学操作 (1)cv2.morphologyEx() 二、直方图处理函数 1.直方图…

iPhone 13P 换超容电池,一年实记的“电池循环次数-容量“柱状图

继上一篇 iPhone 13P 更换"移植电芯"和"超容电池"🔋体验,详细记录了如何更换这两种电池,以及各自的优略势对比。 一晃一年过去,时间真快,这次分享下记录了使用超容电池的 “循环次数 - 容量(mAh)…

基于 pnpm + Monorepo + Turbo + 无界微前端 + Vite 的企业级前端工程实践

基于 pnpm Monorepo Turbo 无界微前端 Vite 的企业级前端工程实践 一、技术演进:为什么引入 Vite? 在微前端与 Monorepo 架构落地后,构建性能成为新的优化重点: Webpack 构建瓶颈:复杂配置导致开发启动慢&#…

(五)机器学习---决策树和随机森林

在分类问题中还有一个常用算法:就是决策树。本文将会对决策树和随机森林进行介绍。 目录 一.决策树的基本原理 (1)决策树 (2)决策树的构建过程 (3)决策树特征选择 (4&#xff0…

Vue3使用AntvG6写拓扑图,可添加修改删除节点和边

npm安装antv/g6 npm install antv/g6 --save 上代码 <template><div id"tpt1" ref"container" style"width: 100%;height: 100%;"></div> </template><script setup>import { Renderer as SVGRenderer } from …

Arduino编译和烧录STM32——基于J-link SWD模式

一、安装Stm32 Arduino支持 在arduino中添加stm32的开发板地址&#xff1a;https://github.com/stm32duino/BoardManagerFiles/raw/main/package_stmicroelectronics_index.json 安装stm32开发板支持 二、安装STM32CubeProgrammer 从stm32网站中安装&#xff1a;https://ww…

智慧城市气象中台架构:多源天气API网关聚合方案

在开发与天气相关的应用时&#xff0c;获取准确的天气信息是一个关键需求。万维易源提供的“天气预报查询”API为开发者提供了一个高效、便捷的工具&#xff0c;可以通过简单的接口调用查询全国范围内的天气信息。本文将详细介绍如何使用该API&#xff0c;以及其核心功能和调用…

Vue 组件化开发

引言 在当今的 Web 开发领域&#xff0c;构建一个功能丰富且用户体验良好的博客是许多开发者的目标。Vue.js 作为一款轻量级且高效的 JavaScript 框架&#xff0c;其组件化开发的特性为我们提供了一种优雅的解决方案。通过将博客拆分成多个独立的组件&#xff0c;我们可以提高代…

Deno 统一 Node 和 npm,既是 JS 运行时,又是包管理器

Deno 是一个现代的、一体化的、零配置的 JavaScript 运行时、工具链&#xff0c;专为 JavaScript 和 TypeScript 开发设计。目前已有数十万开发者在使用 Deno&#xff0c;其代码仓库是 GitHub 上 star 数第二高的 Rust 项目。 Stars 数102620Forks 数5553 主要特点 内置安全性…

应用篇02-镜头标定(上)

本节主要介绍相机的标定方法&#xff0c;包括其内、外参数的求解&#xff0c;以及如何使用HALCON标定助手实现标定。 计算机视觉——相机标定(Camera Calibration)_摄像机标定-CSDN博客 1. 原理 本节介绍与相机标定相关的理论知识&#xff0c;不一定全&#xff0c;可以参考相…

PG CTE 递归 SQL 翻译为 达梦版本

文章目录 PG SQLDM SQL总结 PG SQL with recursive result as (select res_id,phy_res_code,res_name from tbl_res where parent_res_id (select res_id from tbl_res where phy_res_code org96000#20211203155858) and res_type_id 1 union all select t1.res_id, t1.p…

C# Where 泛型约束

在C#中&#xff0c;Where关键字主要有两种用途 1、在泛型约束中限制类型参数 2、在LINQ查询中筛选数据 本文主要介绍where关键字在在泛型约束中的使用 泛型定义中的 where 子句指定对用作泛型类型、方法、委托或本地函数中类型参数的参数类型的约束。通过使用 where 关键字和…

《MySQL:MySQL表的约束-主键/复合主键/唯一键/外键》

表的约束&#xff1a;表中一定要有各种约束&#xff0c;通过约束&#xff0c;让未来插入数据库表中的数据是符合预期的。约束本质是通过技术手段&#xff0c;倒逼程序员插入正确的数据。即&#xff0c;站在mysql的视角&#xff0c;凡是插入进来的数据&#xff0c;都是符合数据约…

Qt 创建QWidget的界面库(DLL)

【1】新建一个qt库项目 【2】在项目目录图标上右击&#xff0c;选择Add New... 【3】选择模版&#xff1a;Qt->Qt设计师界面类&#xff0c;选择Widget&#xff0c;填写界面类的名称、.h .cpp .ui名称 【4】创建C调用接口&#xff08;默认是创建C调用接口&#xff09; #ifnd…

汽车免拆诊断案例 | 2011款雪铁龙世嘉车刮水器偶尔自动工作

故障现象 一辆2011款雪铁龙世嘉车&#xff0c;搭载1.6 L 发动机&#xff0c;累计行驶里程约为19.8万km。车主反映&#xff0c;该车刮水器偶尔会自动工作&#xff0c;且前照灯偶尔会自动点亮。 故障诊断 接车后试车发现&#xff0c;除了上述故障现象以外&#xff0c;当用遥控器…

【Linux】NAT、代理服务、内网穿透

NAT、代理服务、内网穿透 一. NAT1. NAT 技术2. NAT IP 转换过程3. NAPT 技术4. NAT 技术的缺陷 二. 代理服务器1. 正向代理2. 反向代理3. NAT 和代理服务器 内网穿透内网打洞 一. NAT NAT&#xff08;Network Address Translation&#xff0c;网络地址转换&#xff09;技术&a…

MobaXterm连接Ubuntu(SSH)

1.查看Ubuntu ip 打开终端,使用指令 ifconfig 由图可知ip地址 2.MobaXterm进行SSH连接 点击session,然后点击ssh,最后输入ubuntu IP地址以及用户名