torch.nn

torch.nn 与 torch.nn.functional

说起torch.nn,不得不说torch.nn.functional!
这两个库很类似,都涵盖了神经网络的各层操作,只是用法有点不同,比如在损失函数Loss中实现交叉熵! 但是两个库都可以实现神经网络的各层运算。其他包括卷积、池化、padding、激活(非线性层)、线性层、正则化层、其他损失函数Loss,两者都可以实现不过nn.functional毕竟只是nn的子库,nn的功能要多一些,还可以实现如Sequential()这种将多个层弄到一个序列这样复杂的操作。

nn.functional.xxx 是**函数接口**,nn.Xxx 是 .nn.functional.xxx 的**类封装**
nn.Xxx 除了具有 nn.functional.xxx 功能之外,内部附带 nn.Module 相关的属性和方法,eg. train(), eval(), load_state_dict, state_dict

如何看关于torch.nn的 API

首先看左侧,这些都是功能的分类! 比如Containers中就包含了 集合各种神经网络的操作的 容器;
而Convolution Layers则是抱哈了各种卷积操作! 如果我们想要调用的话,不是torch.nn.Containers,而是torch.nn.Moule()、torch.nn.Sequential() 等

torch.nn中某些函数

1. torch.nn.Parameter()

一种张量,被认为是一个模参数。
ParameterTensor的子类,当与Module s一起使用时,它们有一个非常特殊的属性——当它们被分配为模块属性时,它们会自动添加到参数列表中,并将出现在Parameters()迭代器中(可以通过nn.Moudle.Parameter()获得)。
赋值一个张量没有这样的效果。这是因为人们可能想要在模型中缓存一些临时状态,比如RNN的最后一个隐藏状态。如果没有Parameter这样的类,这些临时对象也会被注册。
参数

  • data(张量)–parameter tensor。
  • Requires_grad (bool,optional)—如果参数需要渐变。有关更多细节,请参阅局部禁用梯度计算。默认值:True

Parameter Vs Tensor
首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

所以如果只是单纯的线性层or卷积层,是可以使用tensor的;
但是如果是在模型中,也就是Containers容器中的,那么就必须是Parameter类型

在concat注意力机制中,权值V是不断学习的所以要是parameter类型
通过做下面的实验发现,linear里面的weight和bias就是parameter类型,且不能够使用tensor类型替换,还有linear里面的weight甚至可能通过指定一个不同于初始化时候的形状进行模型的更改。

与torch.tensor([1,2,3],requires_grad=True)的区别,这个只是将参数变成可训练的,并没有绑定在module的parameter列表中

Containers()

2. torch.nn.Module()

所有神经网络模块的基类。
可以看博客

3. torch.nn.Sequential()

顺序容器。
顺序容器。Modules将按照它们在构造函数中传递的顺序添加到它中。另外,Modules的OrderedDict可以被传入。Sequentialforward()方法接受任何输入并将其转发到它包含的第一个Module。然后,它将输出按顺序链接到每个后续Module的输入,最后返回最后一个Module的输出。
一个
Sequential
通过手动调用模块序列提供的值是,它允许将整个容器视为单个Module,这样在Sequential上执行转换将应用于它存储的每个Module(每个Module都是Sequential的注册subModule)。

Sequentialtorch.nn.ModuleList的区别是什么?ModuleList就是一个用来存储Module的列表!另一方面,Sequential中的各层以级联方式连接。

举例

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))

4. torch.nn.ModuleDict()

在字典中保存子模块。
ModuleDict可以像普通的Python字典一样被索引,但它包含的Module是正确注册的,所有Module方法都可以看到它。
ModuleDict是一个有序字典,它反映了

  • 插入的顺序,和
  • update()中,合并的OrderedDictdict(从Python 3.6开始)或另一个ModuleDict (**update()**的参数)的顺序。

注意,对于其他无序映射类型的update()(例如,Python 3.6版之前的普通dict)不会保留合并映射的顺序。

参数
Modules (iterable, optional)——(string: module)的映射(字典)或键值对类型的可迭代对象(string, module)

实例

class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict([['lrelu', nn.LeakyReLU()],['prelu', nn.PReLU()]])def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return x

项目中的实例

self.conv_filters = nn.ModuleDict({str(x): nn.Conv2d(3 if self.config.use_context else 2,self.config.num_filters,(x, self.config.word_embedding_dim))for x in self.config.window_sizes})

方法
1、 clear()
从ModuleDict中删除所有项。

2、items()
返回ModuleDict键/值对的可迭代对象。

3、keys()
返回ModuleDict键的可迭代对象。

4、pop(key)
从ModuleDict中移除key并返回它的模块。
参数:key (string) -从ModuleDict中弹出的键

5、update(modules)
使用映射或可迭代对象的键值对更新ModuleDict,覆盖现有的键。
如果modules是OrderedDict、ModuleDict或键值对的可迭代对象,则其中新元素的顺序将保留。
参数: modules (iterable)——一个从字符串到模块的映射(字典),或者一个键值对类型的可迭代对象(string, Module)

6、values()
返回ModuleDict值的可迭代对象。

Convolution Layers

5. torch.nn.Conv1d

6. torch.nn.Conv2d

方法
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

参数

  • in_channels (int) -输入图像中的通道数

  • out_channels (int) -由卷积产生的信道数

  • kernel_size (int或tuple) -卷积内核的大小

  • stride (int或tuple,optional)——卷积的stride。默认值:1

  • padding (int, tuple或str,optional)-添加到输入所有四方的padding。默认值:0

  • Padding_mode (string,optional)- ‘zeros’‘reflect’’ replication ‘’circular’。默认值:“0”

  • dilation(int或tuple,optional)-内核元素之间的间距。默认值:1

  • groups (int,optional)-从输入通道到输出通道的阻塞连接数。默认值:1

  • bias (bool,optional)-如果为True,则在输出中添加一个可学习的偏差。默认值:True

在这里插入图片描述

变量

  • ~Conv2d.weight (Tensor):(out_channels,in−channelsgroups\frac{ in-channels}{groups}groupsinchannels , kernel_size[0], kernel_size[1]). 这些权重的值是从U(−k,k)\mathcal{U}(-\sqrt{k}, \sqrt{k})U(k,k)取样,其中k=groups Cin ∗∏i=01kernel-size [i]k=\frac{\text { groups }}{C_{\text {in }} * \prod_{i=0}^{1} \text { kernel-size }[i]}k=Cin i=01 kernel-size [i] groups 
  • ~Conv2d.bias (Tensor) : 可学习bias的形状(out_channels)。如果偏差是True, 那么这些权重的值将从从U(−k,k)\mathcal{U}(-\sqrt{k}, \sqrt{k})U(k,k)取样,其中k=groups Cin ∗∏i=01kernel-size [i]k=\frac{\text { groups }}{C_{\text {in }} * \prod_{i=0}^{1} \text { kernel-size }[i]}k=Cin i=01 kernel-size [i] groups 

举例

>>> # 正方形的核和相同的步长
>>> m = nn.Conv2d(16, 33, 3, stride=2)
>>> # 非正方形的核和不一样的步长和填充
>>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)

etc

Pooling Layers

torch.nn.MaxPool1d

torch.nn.MaxPool2d

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/476208.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ORACLE使用JOB定时备份数据库

Oracle的备份一般都是在操作系统上完成,因此定时备份Oracle的功能一般都是由操作系统功能完成,比如crontab。但是Oracle的PIPE接口使得在Oracle数据库中通过JOB来备份Oracle变得可能。 这篇文章给出一个简单的例子,说明如何在JOB中定期备份数…

mysql 装载dump文件_mysql命令、mysqldump命令找不到解决

1、解决bash: mysql: command not found 的方法[rootDB-02 ~]# mysql -u root-bash: mysql: command not found原因:这是由于系统默认会查找/usr/bin下的命令,如果这个命令不在这个目录下,当然会找不到命令,我们需要做的就是映射一个链接到/u…

LeetCode 796. 旋转字符串

1. 题目 给定两个字符串, A 和 B。 A 的旋转操作就是将 A 最左边的字符移动到最右边。 例如, 若 A ‘abcde’,在移动一次之后结果就是’bcdea’ 。如果在若干次旋转操作之后,A 能变成B,那么返回True。 示例 1: 输入: A abcde, B cdeab …

【DKN】(一)KCN详解

_ init _()函数 参数: self, config, pretrained_word_embedding, pretrained_entity_embedding, pretrained_context_embedding config: 设置的固定的参数! pretrained_word_embedding: 根据下面的使用是…

搜索引擎优化经验谈

转自:http://blog.donews.com/zszwyds/archive/2009/08/24/1551179.aspx 费话少说,直入正题。 1. “白马非马”的关键字(词) 很多客户对于自己网站的关键词无从下手,大部分的客户选择都是大而全的关键词,很多的关键词如果选择…

iphone版 天行skyline_Skyline QT

应用标题Skyline QT应用描述An information and feedback gathering tool for our Skyline Queenstown visitor to discover the complex and its array of activities and food and beverage outlets.Welcome to the world of SkylineAre you looking for things to do in New…

LeetCode 788. 旋转数字

1. 题目 我们称一个数 X 为好数, 如果它的每位数字逐个地被旋转 180 度后,我们仍可以得到一个有效的,且和 X 不同的数。要求每位数字都要被旋转。 如果一个数的每位数字被旋转以后仍然还是一个数字, 则这个数是有效的。 0, 1, 和 8 被旋转后…

pycharm中无法识别相对路径的问题

这种情况如果在Windows下操作如下: 第一步: 往往拷贝下来的程序是在linux上运行的 第二步: 设置根路径 要调整有python.exe文件的地方! 这两个路径要设置成为自己的项目根目录!

vue变量传值_Vue各类组件之间传值的实现方式

1、父组件向子组件传值首先在父组件定义好数据,接着将子组件导入到父组件中。父组件只要在调用子组件的地方使用v-bind指令定义一个属性,并传值在该属性中即可,此时父组件的使命完成,请看下面关键代码::content"i…

Linux常用指令自己备用

~ 和 / 的区别: ~ 是当前用户的目录地址 / 是根目录的地址(一般称呼为root,/ 和 /root/ 是有区别的) 当用户是root用户时 ~ 代表/root/,即根目录下的root目录 / 代表/ ,即根目录 当用户是jack用户时 ~…

『号外』 排名进入3000,特致感谢!

开博半个月来,老孙项目管理成功地闯入了博客园3000名!! 谢谢博客园的朋友们!非常感谢!!“老孙项目管理”今日排名2975。这样的成绩,老孙没有预料到,开心极了。比奥巴马当选总统&…

qt如和调用linux底层驱动_擅长复杂硬件体系设计,多核系统设计,以及基于RTOS或者Linux,QT等进行相关底层驱动。...

双向可控硅在使用时,其触发限流电阻的阻值和封装应该怎么选取?(1)首先我们在进行TRIAC其驱动电路设计的时候,我们一般不直接进行驱动,而是通过DIAC或者Photo-TRIAC即光学的双向可控硅配合来使用进行驱动电路的设计,为什…

学习:Web安装项目创建桌面快捷方式及重写安装类(转)

一、WEB安装项目部署1、新建: 新建项目-安装和部署项目-WEB安装项目 2、部署: (1)进入文件系统视图,"项目-右键-视图-文件系统";也可以直接点"解决方案资源管理器"上部的快捷图标(2)在"WEB应用程序文件夹"添加文件,例如aspx文件,ico文…

12c oracle 激活_Oracle 12C 安装教程

Oracle 12c,全称Oracle Database 12c,是Oracle 11g的升级版,新增了很多新的特性。本章节就为大家介绍Oracle 12c的下载和安装步骤。Oracle 12c下载打开Oracle的官方中文网站,选择相应的版本即可。注意:下载时&#xff…

运行试错合集

试错: 在服务器训练好的参数直接被pycharm映射给覆盖了! 记得把这里取消掉! 如果在py文件中修改了代码,手动上传! 就是上面的upload! 运行结果: 运行train的结果 评估阶段: 出错…

LeetCode 806. 写字符串需要的行数

1. 题目 我们要把给定的字符串 S 从左到右写到每一行上,每一行的最大宽度为100个单位,如果我们在写某个字母的时候会使这行超过了100 个单位,那么我们应该把这个字母写到下一行。 我们给定了一个数组 widths ,这个数组 widths[0…

【转载】揭开硬件中断请求IRQ所有秘密(图解)

转载自:http://news.csdn.net/n/20040517/45868.html IRQ(Interrupt Request)的作用就是在我们所用的电脑中,执行硬件中断请求的动作,用来停止其相关硬件的工作状态。比如我们要打印一份文件,在打印结束时就需要由系统对打印机提出…

(七)DKN:用于新闻推荐的深度知识感知网络

摘要: 背景: 新闻语言是高度浓缩的,充满了知识实体和常识。然而,现有的方法并没有意识到一些外在的知识,也不能充分发现新闻之间潜在的知识层面的联系。因此,推荐给用户的结果仅限于简单的模式&#xff0c…

平面方程(Plane Equation)

平面方程(Plane Equation) 原文链接:http://www.songho.ca/math/plane/plane.html翻译:罗朝辉 (http://www.cnblogs.com/kesalin/)本文遵循“署名-非商业用途-保持一致”创作公用协议平面方程平面上的一点以及垂直于该平面的法线唯一定义了 3D 空间的一个…

【DKN】(三)data_preprogress.py

内容 try: # 以绝对导入的方式导入cofig对象,并获取其{model_name}Config! config getattr(importlib.import_module(config), f"{model_name}Config") except AttributeError:print(f"{model_name} not included!")exit()这里…