[译] RNN 循环神经网络系列 2:文本分类

  • 原文地址:RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION
  • 原文作者:GokuMohandas
  • 译文出自:掘金翻译计划
  • 本文永久链接:github.com/xitu/gold-m…
  • 译者:Changkun Ou
  • 校对者:yanqiangmiffy, TobiasLee

本系列文章汇总

  1. RNN 循环神经网络系列 1:基本 RNN 与 CHAR-RNN
  2. RNN 循环神经网络系列 2:文本分类
  3. RNN 循环神经网络系列 3:编码、解码器
  4. RNN 循环神经网络系列 4:注意力机制
  5. RNN 循环神经网络系列 5:自定义单元

RNN 循环神经网络系列 2:文本分类

在第一篇文章中,我们看到了如何使用 TensorFlow 实现一个简单的 RNN 架构。现在我们将使用这些组件并将其应用到文本分类中去。主要的区别在于,我们不会像 CHAR-RNN 模型那样输入固定长度的序列,而是使用长度不同的序列。

文本分类

这个任务的数据集选用了来自 Cornell 大学的语句情绪极性数据集 v1.0,它包含了 5331 个正面和负面情绪的句子。这是一个非常小的数据集,但足够用来演示如何使用循环神经网络进行文本分类了。

我们需要进行一些预处理,主要包括标注输入、附加标记(填充等)。请参考完整代码了解更多。

预处理步骤

  1. 清洗句子并切分成一个个 token;
  2. 将句子转换为数值 token;
  3. 保存每个句子的序列长。

Screen Shot 2016-10-05 at 7.32.36 PM.png

如上图所示,我们希望在计算完成时立即对句子的情绪做出预测。引入额外的填充符会带来过多噪声,这样的话你模型的性能就会不太好。注意:我们填充序列的唯一原因是因为需要以固定大小的批量输入进 RNN。下面你会看到,使用动态 RNN 还能避免在序列完成后的不必要计算。

模型

代码:

class model(object):def __init__(self, FLAGS):# 占位符self.inputs_X = tf.placeholder(tf.int32,shape=[None, None], name='inputs_X')self.targets_y = tf.placeholder(tf.float32,shape=[None, None], name='targets_y')self.dropout = tf.placeholder(tf.float32)# RNN 单元stacked_cell = rnn_cell(FLAGS, self.dropout)# RNN 输入with tf.variable_scope('rnn_inputs'):W_input = tf.get_variable("W_input",[FLAGS.en_vocab_size, FLAGS.num_hidden_units])inputs = rnn_inputs(FLAGS, self.inputs_X)#initial_state = stacked_cell.zero_state(FLAGS.batch_size, tf.float32)# RNN 输出seq_lens = length(self.inputs_X)all_outputs, state = tf.nn.dynamic_rnn(cell=stacked_cell, inputs=inputs,sequence_length=seq_lens, dtype=tf.float32)# 由于使用了 seq_len[0],state 自动包含了上一次的对应输出# 因为 state 是一个带有张量的元组outputs = state[0]# 处理 RNN 输出with tf.variable_scope('rnn_softmax'):W_softmax = tf.get_variable("W_softmax",[FLAGS.num_hidden_units, FLAGS.num_classes])b_softmax = tf.get_variable("b_softmax", [FLAGS.num_classes])# Logitslogits = rnn_softmax(FLAGS, outputs)probabilities = tf.nn.softmax(logits)self.accuracy = tf.equal(tf.argmax(self.targets_y,1), tf.argmax(logits,1))# 损失函数self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits, self.targets_y))# 优化self.lr = tf.Variable(0.0, trainable=False)trainable_vars = tf.trainable_variables()# 使用梯度截断来避免梯度消失和梯度爆炸grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainable_vars), FLAGS.max_gradient_norm)optimizer = tf.train.AdamOptimizer(self.lr)self.train_optimizer = optimizer.apply_gradients(zip(grads, trainable_vars))# 下面是用于采样的值# (在每个单词后生成情绪)# 取所有输出作为第一个输入序列# (由于采样,只需一个输入序列)sampling_outputs = all_outputs[0]# Logitssampling_logits = rnn_softmax(FLAGS, sampling_outputs)self.sampling_probabilities = tf.nn.softmax(sampling_logits)# 保存模型的组件self.global_step = tf.Variable(0, trainable=False)self.saver = tf.train.Saver(tf.all_variables())def step(self, sess, batch_X, batch_y=None, dropout=0.0,forward_only=True, sampling=False):input_feed = {self.inputs_X: batch_X,self.targets_y: batch_y,self.dropout: dropout}if forward_only:if not sampling:output_feed = [self.loss,self.accuracy]elif sampling:input_feed = {self.inputs_X: batch_X,self.dropout: dropout}output_feed = [self.sampling_probabilities]else: # 训练output_feed = [self.train_optimizer,self.loss,self.accuracy]outputs = sess.run(output_feed, input_feed)if forward_only:if not sampling:return outputs[0], outputs[1]elif sampling:return outputs[0]else: # 训练return outputs[0], outputs[1], outputs[2]复制代码

上面的代码就是我们的模型代码,它在训练的过程中使用了输入的文本。注意:为了清楚起见,我们决定将批量数据的大小保存在我们的输入和目标占位符中,但是我们应该让它们独立于一个特定的批量大小之外。由于这个特定的批量大小依赖于 batch_size,如果我们这么做,那么我们就还得输入一个 initial_state。我们通过嵌入他们来为每个数据序列来输入 token。实践策略表明,我们在输入文本上使用 skip-gram 模型预训练嵌入权重能够取得更好的性能。

在此模型中,我们再次使用 dynamic_rnn,但是这次我们提供了sequence_length 参数的值,它是一个包含每个序列长度的列表。这样,我们就可以避免在输入序列的最后一个词之后进行的不必要的计算。length 函数就用来获取这个列表的长度,如下所示。当然,我们也可以在外面计算seq_len,再通过占位符进行传递。

def length(data):relevant = tf.sign(tf.abs(data))length = tf.reduce_sum(relevant, reduction_indices=1)length = tf.cast(length, tf.int32)return length复制代码

由于我们填充符 token 为 0,因此可以使用每个 token 的 sign 性质来确定它是否是一个填充符 token。如果输入大于 0,则 tf.sign 为 1;如果输入为 0,则为 tf.sign 为 0。这样,我们可以逐步通过列索引来获得 sign 值为正的 token 数量。至此,我们可以将这个长度提供给 dynamic_rnn 了。

注意:我们可以很容易地在外部计算 seq_lens,并将其作为占位符进行传参。这样我们就不用依赖于 PAD_ID = 0 这个性质了。

一旦我们从 RNN 拿到了所有的输出和最终状态,我们就会希望分离对应输出。对于每个输入来说,将具有不同的对应输出,因为每个输入长度不一定不相同。由于我们将 seq_len 传给了 dynamic_rnn,而 state 又是最后一个对应输出,我们可以通过查看 state 来找到对应输出。注意,我们必须取 state[0],因为返回的 state 是一个张量的元组。

其他需要注意的事情:我并没有使用 initial_state,而是直接给 dynamic_rnn 设置 dtype。此外,dropout 将根据 forward_only 与否,作为参数传递给 step()

推断

总的来说,除了单个句子的预测外,我还想为具有一堆样本句子整体情绪进行预测。我希望看到的是,每个单词都被 RNN 读取后,将之前的单词分值保存在内存中,从而查看预测分值是怎样变化的。举例如下(值越接近 0 表明越靠近负面情绪):

Screen Shot 2016-10-05 at 8.34.51 PM.png

注意:这是一个非常简单的模型,其数据集非常有限。主要目的只是为了阐明它是如何搭建以及如何运行的。为了获得更好的性能,请尝试使用数据量更大的数据集,并考虑具体的网络架构,比如 Attention 模型、Concept-Aware 词嵌入以及隐喻(symbolization to name)等等。

损失屏蔽(这里不需要)

最后,我们来计算 cost。你可能会注意到我们没有做任何损失屏蔽(loss masking)处理,因为我们分离了对应输出,仅用于计算损失函数。然而,对于其他诸如机器翻译的任务来说,我们的输出很有可能还来自填充符 token。我们不想考虑这些输出,因为传递了 seq_lens 参数的 dynamic_rnn 将返回 0。下面这个例子比较简单,只用来说明这个实现大概是怎么回事;我们这里再一次使用了填充符 token 为 0 的性质:

# 向量化 logits 和目标
targets = tf.reshape(targets, [-1]) # 将张量 targets 转为向量
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
mask = tf.sign.(tf.to_float(targets)) # targets 为 0 则输出为 0, target < 0 则输出为 -1, 否则 为 1
masked_losses = mask*losses # 填充符所在位置的贡献为 0复制代码

首先我们要将 logits 和 targets 向量化。为了使 logits 向量化,一个比较好的办法是将 dynamic_rnn 的输出向量化为 [-1,num_hidden_units] 的形状,然后乘以 softmax 权重 [num_hidden_units,num_classes]。通过损失屏蔽操作,就可以消除填充符所在位置贡献的损失。

代码

GitHub 仓库 (正在更新,敬请期待!)

张量形状变化的参考

原始未处理过的文本 X 形状为 [N,]y 的形状为 [N, C],其中 C 是输出类别的数量(这些是手动完成的,但我们需要使用独热编码来处理多类情况)。

然后 X 被转化为 token 并进行填充,变成了 [N, <max_len>]。我们还需要传递形状为 [N,]seq_len 参数,包含每个句子的长度。

现在 Xseq_leny 通过这个模型首先嵌入为 [NXD],其中 D 是嵌入维度。X 便从 [N, <max_len>] 转换为了 [N, <max_len>, D]。回想一下,X 在这里有一个中间表示,它被独热编码为了 [N, <max_len>, <num_words>]。但我们并不需要这么做,因为我们只需要使用对应词的索引,然后从词嵌入权重中取值就可以了。

我们需要将这个嵌入后的 X 传递给 dynamic_rnn 并返回 all_outputs[N, <max_len>, D])以及 state[1, N, D])。由于我们输入了 seq_lens,对于我们而言它就是最后一个对应的状态。从维度的角度来说,你可以看到, all_outputs 就是来自 RNN 的对于每个句子中的每个词的全部输出结果。然而,state 仅仅只是每个句子的最后一个对应输出。

现在我们要输入 softmax 权重,但在此之前,我们需要通过取第一个索引(state[0])来把状态从 [1,N,D] 转换为[N,D]。如此便可以通过与 softmax 权重 [D,C] 的点积,来得到形状为 [N,C] 的输出。其中,我们做指数级 softmax 运算,然后进行正则化,最终结合形状为 [N,C]target_y 来计算损失函数。

注意:如果你使用了基本的 RNN 或者 GRU,从 dynamic_rnn 返回的 all_outputsstate 的形状是一样的。但是如果使用 LSTM 的话,all_outputs 的形状就是 [N, <max_len>, D]state 的形状为 [1, 2, N, D]


掘金翻译计划 是一个翻译优质互联网技术文章的社区,文章来源为 掘金 上的英文分享文章。内容覆盖 Android、iOS、React、前端、后端、产品、设计 等领域,想要查看更多优质译文请持续关注 掘金翻译计划、官方微博、知乎专栏。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/257206.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[置顶] Android开发者官方网站文档 - 国内踏得网镜像

Mark 一下&#xff1a; 镜像地址&#xff1a;http://wear.techbrood.com/index.html Android DevelopTools: http://www.androiddevtools.cn/ 转载于:https://www.cnblogs.com/superle/p/4561856.html

Java实现选择排序

选择排序思想就是选出最小或最大的数与第一个数交换&#xff0c;然后在剩下的数列中重复完成该动作。 package Sort;import java.util.Arrays;public class SelectionSort {public static int selectMinKey(int[] list, int beginIdx) {int idx beginIdx;int temp list[begin…

ASP.NET MVC中ViewData、ViewBag和TempData

1.ViewData 1.1 ViewData继承了IDictionary<string, object>,因此在设置ViewData属性时,传入key必须要字符串型别,value可以是任意类型。 1.2 ViewData它只会存在这次的HTTP要求而已,而不像Session可以将数据带到下HTTP要求。 public class TestController : Controller{…

java 正则表达式验证邮箱格式是否合规 以及 正则表达式元字符

package com.ykmimi.testtest; /*** 测试邮箱地址是否合规* author ukyor**/ public class EmailTest {public static void main(String[] args) {//定义要匹配的Email地址的正则表达式//其中\w代表可用作标识符的字符,不包括$. \w表示多个// \\.\\w表示点.后面有\w 括号{2,3}…

镜头选型

景深&#xff1a; 光圈越大&#xff0c;光圈值越小&#xff0c;景深越小 光圈越小&#xff0c;光圈值越大&#xff0c;景深越深 焦距越长&#xff0c;视角越小&#xff0c;主体像越大&#xff0c;景深越小 主体越近&#xff0c;景深越小

迅雷账号

账号 jiangchnangli:1 密码 892812 网址 http://www.s8song.net/read-htm-tid-4906661.html漫晴xydcq7681转载于:https://www.cnblogs.com/wlzhang/p/4563118.html

【Swift学习】Swift编程之旅---ARC(二十)

Swift使用自动引用计数(ARC)来跟踪并管理应用使用的内存。大部分情况下&#xff0c;这意味着在Swift语言中&#xff0c;内存管理"仍然工作"&#xff0c;不需要自己去考虑内存管理的事情。当实例不再被使用时&#xff0c;ARC会自动释放这些类的实例所占用的内存。然而…

像元大小及精度

说完了光学系统的分辨率之后我们来看看相机的图像分辨率。图像分辨率比较好理解&#xff0c;就是单位距离内的像用多少个像素来显示。以我们的ORCA-Flash4.0为例&#xff0c;芯片的像元大小为 6.5 μm&#xff0c;在 40X物镜的放大倍率下&#xff0c;1 μm的物经光学系统放大为…

转:传入的表格格式数据流(TDS)远程过程调用(RPC)协议流不正确 .

近期在做淘宝客的项目&#xff0c;大家都知道&#xff0c;淘宝的商品详细描述字符长度很大&#xff0c;所以就导致了今天出现了一个问题 VS的报错是这样子的 ” 传入的表格格式数据流(TDS)远程过程调用(RPC)协议流不正确“ 还说某个desricption 过长之类的话 直觉告诉我&#…

合并bin文件-----带boot发布版本比较好用的bat(便捷版)

直接上图上代码&#xff08;代码在结尾&#xff09;&#xff0c;有不会用的可以留言&#xff1a; 第一步&#xff1a;工程介绍&#xff0c;关键点--- 1.bat文件放所在app和boot工程的同级目录下 2.release为运行bat自动生成文件夹 第二步&#xff1a;合版.bat 针对具体项目需…

第五天 断点续传和下载

1 断点续传&#xff0c; 2.多线程下载原理 3.httpUtils 多线程断点下载的使用。 ------------- 1.拿到需要下载的文件的大小&#xff0c;和需要初始的线程数 2.得到每个线程需要下载的大小&#xff0c;最后一个线程负责将剩下的数据全部下载。 3.同时需要设置一个与下载文件同大…

关于cmake从GitHub上下载的源码启动时报错的问题

关于cmake从GitHub上下载的源码启动时报错的问题&#xff1a; 由于cmake会产生all_build和zero_check两个project&#xff0c;此时需要右击鼠标将需要运行的项目设为启动项&#xff0c;在进行编译&#xff0c;现只针对“找不到all_build文件“的出错信息&#xff0c;若有相关编…

一个人的Scrum之准备工作

在2012年里&#xff0c;我想自己一人去实践一下Scrum&#xff0c;所以才有了这么一个开篇。 最近看了《轻松的Scrum之旅》这本书&#xff0c;感觉对我非常有益。书中像讲述故事一样描述了在执行Scrum过程中的点点滴滴&#xff0c; 仿佛我也跟着进行了一次成功的Scrum。同样的&a…

Elementary OS安装Chrome

elementary os 官方网站&#xff1a;https://elementary.io/ 这os是真好看&#xff01;首先这是基于ubuntu的&#xff0c;所以可以安装ubuntu的软件&#xff01; 电脑必备浏览器必须是chrome呀&#xff01;下载地址&#xff1a; https://www.chrome64bit.com/index.php/google…

vs+opencv编译出现内存问题

将图片路径改为项目下的相对路径&#xff0c;如 …\data\01.jpg; 其中…表示项目所在目录的上级目录&#xff0c;不要用绝对路径&#xff0c;具体原因未知&#xff0c;同时&#xff0c;出现opencv_worldxxx.lib找不到情况&#xff0c;1.链接中依赖项是否写错&#xff08;英文输…

runtime--实现篇02(Category增加属性)

在iOS设计Category中&#xff0c;默认不能直接添加属性&#xff0c;如果分类中通过property修饰的属性&#xff0c;只会生成setter和getter的声明&#xff0c; 不会生成其实现&#xff1b;因此&#xff0c;如果一定要添加属性的话&#xff0c;需要借助runtime特性&#xff0c;通…

spark、oozie、yarn、hdfs、zookeeper、

为什么80%的码农都做不了架构师&#xff1f;>>> spark、 oozie:任务调度 yarn:资源调度 hdfs:分布式文件系统 zookeeper、 转载于:https://my.oschina.net/u/3709135/blog/1556661

关于halcon多区域挑选有关算法的自我理解(tuple_sort_index)

多区域根据面积挑选想要的obj area_center&#xff08;regions&#xff0c;areas&#xff09; tuple_sort_index(areas&#xff0c;indexs) tuple_sort_index算子将一组数组进行升序排列&#xff0c;然后将其在原数组的index按升序放入indexs中&#xff0c; 例如原数组areas[20…

JLOI2016 方

bzoj4558 真是一道非常excited的题目啊…JLOI有毒 题目大意&#xff1a;给一个(N1)*(M1)的网格图&#xff0c;格点坐标为(0~N,0~M)&#xff0c;现在挖去了K个点&#xff0c;求剩下多少个正方形&#xff08;需要注意的是正方形可以是斜着的&#xff0c;多斜都可以&#xff09; N…

opencv 直方图反向投影

转载至&#xff1a;http://www.cnblogs.com/zsb517/archive/2012/06/20/2556508.html 直方图反向投影式通过给定的直方图信息&#xff0c;在图像找到相应的像素分布区域&#xff0c;opencv提供两种算法&#xff0c;一个是基于像素的&#xff0c;一个是基于块的。 使用方法不写了…