面向对象方法使用gluon

一、面向过程与面向对象的优缺点

面向过程使用mxnet,就是使用gluon封装好的对象,不加改动的表达机器学习的逻辑过程,其特点是方便、快捷,缺点是不够灵活(虽然可以应对90%以上的问题了),面向对象基于继承、多态的性质,对原有的gluon类进行了继承重写,并在不改变应用接口的情况下(基于多态),灵活的改写原有类,使之更加符合用户特殊需求。本文从自定义模型、自定义层、自定义初始化三个方面说明gluon的继承重写,这三个基本操作足够用户随心所欲的创造模型了。

二、自定义模型

1、定义静态模型

静态模型就是实例化后模型的结构就不能随便改变了,其代码如下:

from mxnet import nd
from mxnet.gluon import nnclass MLP(nn.Block):# 声明带有模型参数的层,这里声明了两个全连接层def __init__(self, **kwargs):# 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数# 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数paramssuper(MLP, self).__init__(**kwargs)self.hidden = nn.Dense(256, activation='relu')  # 隐藏层self.output = nn.Dense(10)  # 输出层# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出def forward(self, x):return self.output(self.hidden(x))X = nd.random.uniform(shape=(2, 20))
net = MLP()
net.initialize()
net(X)

2、定义动态模型

动态模型就是在实例化以后,后续可以根据需要随时修改模型结构,下面只定义一个增加网络层的功能。

class MySequential(nn.Block):def __init__(self, **kwargs):super(MySequential, self).__init__(**kwargs)def add(self, block):# block是一个Block子类实例,假设它有一个独一无二的名字。我们将它保存在Block类的# 成员变量_children里,其类型是OrderedDict。当MySequential实例调用# initialize函数时,系统会自动对_children里所有成员初始化self._children[block.name] = blockdef forward(self, x):# OrderedDict保证会按照成员添加时的顺序遍历成员for block in self._children.values():x = block(x)return xnet = MySequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()
net(X)

三、定义tensor流

tensor流就是tensor之间是怎样运算的,gluon默认的tensor流是简单的tensor乘法运算,自定义就对tengsor流使用if判断、for循环手段,构造出更加复杂的tensor流,这一点在后面的卷积网络、循环网络中频繁使用。

class FancyMLP(nn.Block):def __init__(self, **kwargs):super(FancyMLP, self).__init__(**kwargs)# 使用get_constant创建的随机权重参数不会在训练中被迭代(即常数参数)self.rand_weight = self.params.get_constant('rand_weight', nd.random.uniform(shape=(20, 20)))self.dense = nn.Dense(20, activation='relu')def forward(self, x):x = self.dense(x)# 使用创建的常数参数,以及NDArray的relu函数和dot函数x = nd.relu(nd.dot(x, self.rand_weight.data()) + 1)# 复用全连接层。等价于两个全连接层共享参数x = self.dense(x)# 控制流,这里我们需要调用asscalar函数来返回标量进行比较while x.norm().asscalar() > 1:x /= 2if x.norm().asscalar() < 0.8:x *= 10return x.sum()net = FancyMLP()
net.initialize()
net(X)

说明

  1. 以上三个方法是可以结合起来使用的,基于这三点用户可以使用gluon构造出各种卷积、循环网络。
  2. 以上三种继承方式中,forward函数必须定义重写,否则出现下面的错误,就是没找到forward propagation。
print(net(X))
out = self.forward(*args)
raise NotImplementedError
NotImplementedError

三、自定义层

层与模型没有本质区别,从语言角度讲是一样的,二者的数据结构都是tensor+forward,只是用途不同而已。层可以理解为整个模型的一层或一部分,是一段网络,层的作用用来构造模型。

1、gluon的层

Dense层:forward = (X * weight + bias).relu()

g_layer = nn.Dense(2)
g_layer.initialize(init=init.One())
X = nd.array([1, 2, 3, 4]).reshape((1, 4))
y = g_layer(X)
print('weight of g_layer:', g_layer.weight.data())
print('bias of g_layer:', g_layer.bias.data())
print('X:', X)
print('g_layer(X):', y)
print('structure of g_layer:', g_layer)"""
# output
weight of g_layer: 
[[1. 1. 1. 1.][1. 1. 1. 1.]]
<NDArray 2x4 @cpu(0)>
bias of g_layer: 
[0. 0.]
<NDArray 2 @cpu(0)>
X: 
[[1. 2. 3. 4.]]
<NDArray 1x4 @cpu(0)>
g_layer(X): 
[[10. 10.]]
<NDArray 1x2 @cpu(0)>
structure of g_layer: Dense(4 -> 2, linear)
"""

说明:

  1. 再次强调一遍,层和模型的要素是tensor + forward,上面的g_layer是gluon默认的forward,即进行简单的乘法运算(X * tensor);
  2. 因为上面的层从模型的角度看只有一个层,所以查看参数的时候使用g_layer.weight.data(),而不是g_layer[0].weight.data();

2、自定义无参数层

from mxnet import gluon, nd
from mxnet.gluon import nnclass CenteredLayer(nn.Block):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()
layer = CenteredLayer()
layer(nd.array([1, 2, 3, 4, 5]))

说明: 与上面的g_layer没有区别,都是tensor+forward,这里layer.weight.data()就会报错,因为是0个层;

3、自定义含参数层

自定义的层的意思是tensor也要自定义,tensor就是weight + bias;

class MyDense(nn.Block):def __init__(self, units, in_units, **kwargs):super(MyDense, self).__init__(**kwargs)self.weight1 = self.params.get('haha_weight', shape=(in_units, units))self.bias1 = self.params.get('haha_bias', shape=(units,))def forward(self, x):linear = nd.dot(x, self.weight1.data()) + self.bias1.data()return nd.relu(linear)if __name__ == '__main__':dense = MyDense(units=3, in_units=5)dense.initialize()dense(nd.random.uniform(shape=(2, 5)))print(dense.weight1.data()[0])"""
[0.0068339  0.01299825 0.0301265 ]
<NDArray 3 @cpu(0)>
"""

说明:从这个代码中可以看出一个层的本质就是一段网络;

4、层的应用

net = nn.Sequential()
net.add(MyDense(8, in_units=64),MyDense(1, in_units=8))
net.initialize()
y = net(nd.random.uniform(shape=(2, 64)))
print('self_define tensor:', net[0].weight1.data()[0])"""
self_define tensor: 
[0.0068339  0.01299825 0.0301265  0.04819721 0.01438687 0.050112390.00628365 0.04861524]
<NDArray 8 @cpu(0)>
"""

四、自定义初始化

1、_init_weight在做什么?

# -*- coding: utf-8 -*-
from mxnet import init, nd
from mxnet.gluon import nnclass MyInit(init.Initializer):def _init_weight(self, name, data):print('Init', name, data.shape)if __name__ == '__main__':net = nn.Sequential()net.add(nn.Dense(256, activation='relu'),nn.Dense(256, activation='relu'),nn.Dense(10))net.initialize(init=MyInit())X = nd.random.uniform(shape=(2, 20))print('---------1---------')Y = net(X)print('---------2---------')net.initialize(init=MyInit(), force_reinit=True)"""
# output
---------1---------
Init dense0_weight (256, 20)
Init dense1_weight (256, 256)
Init dense2_weight (10, 256)
---------2---------
Init dense0_weight (256, 20)
Init dense1_weight (256, 256)
Init dense2_weight (10, 256)
"""

2、怎么使用_init_weight自定义初始化?

class MyInit(init.Initializer):def _init_weight(self, name, data):print('Init', name, data.shape)data[:] = nd.random.uniform(low=-10, high=10, shape=data.shape)data *= data.abs() >= 5net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[0]

说明:上面仅说明对weight初始化,gulon也提供了_init_bias,但是最后还是强制bias=0,也就是重写的_init_bias没有被调用,从机器学习的角度讲,bias一般初始化为0;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/420787.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mapreduce执行过程分析(基于Hadoop2.4)——(一)

1 概述 该瞅瞅MapReduce的内部运行原理了&#xff0c;以前只知道个皮毛&#xff0c;再不搞搞&#xff0c;不然怎么死的都不晓得。下文会以2.4版本中的WordCount这个经典例子作为分析的切入点&#xff0c;一步步来看里面到底是个什么情况。 2 为什么要使用MapReduce Map/Reduce&…

spring配置数据源

spring配置数据源1. 什么是数据源连接池2. 手动创建数据源&#xff08;c3p0&#xff0c;druid&#xff09;2.1 导入数据库连接驱动&#xff0c;数据源pom坐标2.2 创建数据源2.3 配置jdbc.properties, 解耦拿到数据源3. spring配置数据源3.1 bean方式创建数据源13.2 bean方式创建…

基于mxnet的Regression问题Kaggle比赛代码框架

一、概述 书中3.16节扩展一下可以作为kaggle比赛的框架&#xff0c;这个赛题的名字是House Prices: Advanced Regression Techniques&#xff0c;是一个Regression问题。 二、Deeplearning的一般流程 结合李航《统计学习方法》中对机器学习流程的总结&#xff0c;分为data、…

centos8安装

一. 下载centos centos下载 下载镜像版 mini版本 二&#xff0c;安装centos8 虚拟机安装 可 打开虚拟机安装centos 选择下载的镜像 配置磁盘大小 配置资源 配置虚拟机内存&#xff0c;处理器个数等. 安装成功后&#xff0c;也可配置

alert,confirm和prompt

1.警告消息框alertalert 方法有一个参数&#xff0c;即希望对用户显示的文本字符串。该字符串不是 HTML 格式。该消息框提供了一个“确定”按钮让用户关闭该消息框&#xff0c;并且该消息框是模式对话框&#xff0c;也就是说&#xff0c;用户必须先关闭该消息框然后才能继续进行…

(一)卷积网络之基础要点

一、提出问题 对于生活生产中的表格数据&#xff0c;至多也就上百维&#xff0c;而且表格数据的行与行之间没有序列和位置上的关系&#xff0c;所以用传统的机器学习算法就可轻松的解决这些问题。但是到了图片数据&#xff0c;传统机器学习就非常吃力了&#xff0c;一个普通的…

Windows Phone本地数据库(SQLCE):3、[table]attribute(翻译) (转)

这是“windows phone mango本地数据库&#xff08;sqlce&#xff09;”系列短片文章的第三篇。 为了让你开始在Windows Phone Mango中使用数据库&#xff0c;这一系列短片文章将覆盖所有你需要知道的知识点。这个时候我将谈谈有关你使用windows phone mango本地数据库时使用[ta…

Java代理模式——静态代理动态代理

proxy mode1. 什么是代理1.1 例子解释1.2 作用2. 静态代理2.1 优缺点分析2.2 以厂家卖u盘用代码说明3. 动态代理3.1 什么是动态代理3.2 jdk实现原理3.3 代码描述1. 什么是代理 1.1 例子解释 1. 生活中的例子&#xff0c;常见的商家卖东西&#xff0c; 商家就是代理&#xff0…

一、Insertion sort

1. 问题 2. 算法 2.1 伪代码 2.2 算法思想 2.3 手工演示 2.4 Python实现 《算法导论》一书数组默认从111开始&#xff0c;这种方式适合算法分析&#xff0c;从000开始适合程序实现&#xff0c;为了能和伪代码一致便于对比&#xff0c;后边所有的Python实现中数组均从111开始。…

windows 2502 2503 错误解决

1. 错误原因 1. c盘下temp文件夹权限问题 2. c盘temp文件夹环境变量配置错误&#xff0c;或者更改了2. 造成的问题 每次安装msi文件或者卸载msi程序包时&#xff0c;都会弹出此恶心的错误...3. 解决 1. 针对问题一&#xff0c;解决&#xff0c;以管理员身份安装或者卸载 win…

Hibernate学习笔记

Hibernate是什么&#xff1a; Hibernate 架构&#xff1a; 下载、安装、必要的 jar包、环境CLASSPAST的设置&#xff08;此步骤省略&#xff09; Hibernate框架的使用步骤&#xff1a;1、创建Hibernate的配置文件&#xff08;hibernate.cfg.xml&#xff09;2、创建持久化类&…

二、Merge sort

1 问题 2 算法 2.1 伪代码 2.2 算法思想 2.3 手工演示 2.4 Python实现 # -*- coding: utf-8 -*- import sysdef merge(A, p, q, r):n1 q - p 1n2 r - qL [0] * (n1 2)R [0] * (n2 2)for i in range(1, n11):L[i] A[pi-1]for j in range(1, n21):R[j] A[qj]L[n11] 6…

cglib实现动态代理

对目标方法实现前置或者后置增强&#xff0c; 是在程序动态运行时加入增强方法的。 1. 目标类 package com.lovely.proxy.cglib;/*** 目标类* author echo lovely* date 2020/7/26 15:20*/ public class Target {public void save() {System.out.println("sve running..…

fragment嵌套,viewpager嵌套 不能正确显示

转帖&#xff1a;http://blog.csdn.net/mybook1122/article/details/24003343 通常为 viewPager.setAdapter(new MyFragmentPagerAdapter(getSupportFragmentManager(), fragmentsList)); 替换为 mPager.setAdapter(new MyFragmentPagerAdapter(getChildFragmentManager(), fra…

三、递归树分析法

1 问题 2 解决思路 使用递归树猜想一个上界&#xff0c;使用归纳法证明上界也是下界。 2.1 使用递归树&#xff08;recursion tree&#xff09;猜想结论&#xff08;不严谨&#xff09; 使用递归树两点&#xff1a;1⃣️逐行展开&#xff1b;2⃣️逐行相加&#xff1b; 逐行…

Linux文件查看/编辑方法介绍

转载:https://www.centos.bz/2011/10/linux-file-view-edit/ cat 命令介绍 cat 命令的原含义为连接(concatenate)&#xff0c; 用于连接多个文件内容并输出到标准输出流中&#xff08;标准输出流默认为屏幕&#xff09;。实际运用过程中&#xff0c;我们常使用它来显示文件内容…

html5input表单标签新属性

初探h5一&#xff0c;h5 新增表单类型二&#xff0c;新增表单属性三&#xff0c;code demo一&#xff0c;h5 新增表单类型 •email 邮箱地址•url 网络地址•number 数字框•range 滑块•Date pickers (date, month, week, time, datetime, datetime-local) 日期时间框•search…