【TensorFlow】稀疏矢量

  • 官方Document: https://tensorflow.google.cn/api_guides/python/sparse_ops
  • 开发测试环境:
    • Win10
    • Python 3.6.4
    • tensorflow-gpu 1.6.0

SparseTensor与SparseTensorValue的理解

SparseTensor(indices, values, dense_shape)

稀疏矢量的表示

  • indices shape为[N, ndims]的2-D int64矢量,用以指定非零元素的位置,比如indices=[[1,3], [2,4]]表示[1,3]和[2,4]位置的元素为非零元素。
  • values shape为[N]的1-D矢量,对应indices所指位置的元素值
  • dense_shape shape为[ndims]的1-D矢量,代表稀疏矩阵的shape
SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
>>
[[1, 0, 0, 0][0, 0, 2, 0][0, 0, 0, 0]]SparseTensor(indices=[[0], [3]], values=[4, 6], dense_shape=[7])
>>[4, 0, 0, 6, 0, 0, 0]

稀疏矢量的封装并不直观,可以通过稀疏矢量的方式构建矢量(sparse_to_dense)或者将稀疏矢量转换成矢sparse_tensor_to_dense量的方式来感受一下:

def sparse_to_dense(sparse_indices,output_shape,sparse_values,default_value=0,validate_indices=True,name=None)
  • sparse_indices sparse_indices:稀疏矩阵中那些个别元素对应的索引值。
    • sparse_indices是个数,那么它只能指定一维矩阵的某一个元素
    • sparse_indices是个向量,那么它可以指定一维矩阵的多个元素
    • sparse_indices是个矩阵,那么它可以指定二维矩阵的多个元素
  • output_shape 输出的稀疏矩阵的shape
  • sparse_value 个别元素的值
    • sparse_values是个数:所有索引指定的位置都用这个数
    • sparse_values是个向量:输出矩阵的某一行向量里某一行对应的数(所以这里向量的长度应该和输出矩阵的行数对应,不然报错)
  • default_value:未指定元素的默认值,一般如果是稀疏矩阵的话就是0了

实例展示

import tensorflow as tf  
import numpy  BATCHSIZE=6label=tf.expand_dims(tf.constant([0,2,3,6,7,9]),1)
index=tf.expand_dims(tf.range(0, BATCHSIZE),1)
# use a matrix
concated = tf.concat([index, label], 1)   # [[0, 0], [0, 2], [0, 3], [0, 6], [0, 7], [0, 9]] (6,2)
onehot_labels = tf.sparse_to_dense(concated, [BATCHSIZE,10], 1.0, 0.0)# use a vector
sparse_indices2=tf.constant([1,3,4])
onehot_labels2 = tf.sparse_to_dense(sparse_indices2, [10], 1.0, 0.0)#can use# use a scalar
sparse_indices3=tf.constant(5)
onehot_labels3 = tf.sparse_to_dense(sparse_indices3, [10], 1.0, 0.0)sparse_tensor_00 = tf.SparseTensor(indices=[[0,0,0], [1,1,2]], values=[4, 6], dense_shape=[2,2,3])
dense_tensor_00 = tf.sparse_tensor_to_dense(sparse_tensor_00)with tf.Session(config=config) as sess:result1=sess.run(onehot_labels)result2 = sess.run(onehot_labels2)result3 = sess.run(onehot_labels3)result4 = sess.run(dense_tensor_00)print ("This is result1:")print (result1)print ("This is result2:")print (result2)print ("This is result3:")print (result3)print ("This is result4:")print (result4)

输出结果如下

[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0. 0. 0.][0. 0. 0. 1. 0. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 1. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
This is result2:
[0. 1. 0. 1. 1. 0. 0. 0. 0. 0.]
This is result3:
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
This is result4:
[[[4 0 0][0 0 0]][[0 0 0][0 0 6]]]

区别

两者的区别可以通过应用来说起

If you would like to define the tensor outside the graph, e.g. define the sparse tensor for later data feed, use SparseTensorValue. In contrast, if the sparse tensor is defined in graph, use SparseTensor

在graph定义sparse_placeholder,在feed中需要使用SparseTensorValue

x_sp = tf.sparse_placeholder(dtype=tf.float32)
W = tf.Variable(tf.random_normal([6, 6]))
y = tf.sparse_tensor_dense_matmul(sp_a=x_sp, b=W)init = tf.global_variables_initializer()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(init)stv = tf.SparseTensorValue(indices=[[0, 0], [1, 2]], values=[1.1, 1.2], 
dense_shape=[2,6])
result = sess.run(y,feed_dict={x_sp:stv})print(result)

在graph中做定义需要使用SparseTensor

indices_i = tf.placeholder(dtype=tf.int64, shape=[2, 2])
values_i = tf.placeholder(dtype=tf.float32, shape=[2])
dense_shape_i = tf.placeholder(dtype=tf.int64, shape=[2])
st = tf.SparseTensor(indices=indices_i, values=values_i, dense_shape=dense_shape_i)W = tf.Variable(tf.random_normal([6, 6]))
y = tf.sparse_tensor_dense_matmul(sp_a=st, b=W)init = tf.global_variables_initializer()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(init)result = sess.run(y,feed_dict={indices_i:[[0, 0], [1, 2]], values_i:[1.1, 1.2], dense_shape_i:[2,6]})print(result)

在feed中应用SparseTensor,需要使用运算

x = tf.sparse_placeholder(tf.float32)
y = tf.sparse_reduce_sum(x)config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
with tf.Session(config=config) as sess:indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)values = np.array([1.0, 2.0], dtype=np.float32)shape = np.array([7, 9, 2], dtype=np.int64)print(sess.run(y, feed_dict={x: tf.SparseTensorValue(indices, values, shape)}))  # Will succeed.print(sess.run(y, feed_dict={x: (indices, values, shape)}))  # Will succeed.sp = tf.SparseTensor(indices=indices, values=values, dense_shape=shape)sp_value = sp.eval(session=sess)print(sp_value)print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/508646.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Numpy】array操作总结

官方Document: https://www.numpy.org/devdocs/reference/routines.array-manipulation.html开发测试环境 Win10Python 3.6.4NumPy 1.14.2 Basic operations 函数原型作用[copyto](dst, src[, casting, where])Copies values from one array to another, broadcasting as nec…

【TensorFlow】conv2d函数参数解释以及padding理解

卷积conv2d CNN在深度学习中有着举足轻重的地位,主要用于特征提取。在TensorFlow中涉及的函数是tf.nn.conv2d。 tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpuTrue, data_format“NHWC”, dilations[1, 1, 1, 1], nameNone) input 代表做卷积的…

卷积与傅立叶变换

一、卷积 1、一维的卷积 连续: 在泛函分析中,卷积是通过两个函数f(x)f(x)和g(x)g(x)生成第三个函数的一种算子,它代表的意义是:两个函数中的一个(我取g(x)g(x),可以任意取)函数,把g(x)g(x)经过翻转平移,…

海明纠错码工作原理

海明纠错码 海明码(Hamming Code)是一个可以有多个校验位,具有检测并纠正一位错误代码的纠错码,所以它也仅用于信道特性比较好的环境中,如以太局域网中,因为如果信道特性不好的情况下,出现的错…

OpenCV-Python bindings是如何生成的(1)

翻译自How OpenCV-Python Bindings Works? 目标 学习 OpenCV-Python bindings是如何生成的如何为Python扩展新的opencv模块 OpenCV-Python bindings是如何生成的 在OpenCV里,所有算法都是用C实现的。但是这些算法可以在别的语言里使用,比如Python&…

OpenCV-Python bindings是如何生成的(2)

OpenCV-Python bindings生成流程 通过上篇文章和opencv python模块中的CMakeLists.txt文件,可以了解到opencv-python bindings生成的整个流程: 生成headers.txt文件 将每个模块的头文件添加到list中,通过一些关键词过滤掉一些不需要扩展的头文件&#x…

【TensorFlow】学习资源汇总以及知识总结

官方资源 官方网站 https://tensorflow.org 非翻墙神器不能访问也(关键是我用了翻墙神器也没能访问)伪官方网站 https://tensorflow.google.cn/ 墙内的人可以查阅的资料github https://github.com/tensorflow/tensorflow官方提供的models以及tutorial h…

机器学习资源锦集

http://www.cnblogs.com/pinard 十年码农,对数学统计学,数据挖掘,机器学习,大数据平台,大数据平台应用开发,大数据可视化感兴趣。github 深度学习 【深度学习】批归一化(Batch Normalization&…

获取训练数据的方式

下载搜狗词库 https://pinyin.sogou.com/dict/ 在官网搜索相关的词库下载,比如地名等,然后使用脚本将此条转换成txt保存, 来源 # -*- coding: utf-8 -*- import os import sys import struct # 主要两部分 # 1.全局拼音表,貌似…

浅谈python MRO与Mixin模式

MRO(Method Resolution Order) In object-oriented programming languages with multiple inheritance, the diamond problem (sometimes referred to as the “deadly diamond of death”) is an ambiguity that arises when two classes B and C inherit from A, and class D…

CentOS7开发环境搭建(2)

关闭SELinux # 查看 $ getenforce Disabled $ sestatus SELinux status: enabled SELinuxfs mount: /sys/fs/selinux SELinux root directory: /etc/selinux Loaded policy name: targeted Current mode: …

IntelliJ IDEA开发环境应用

安装 下载windows压缩包获取帮助: idea.medeming.com/jihuoma 常用设置 全局设置,对新建的工程生效 【File】【Other Settings】【Setings for New Projects…】 比如配置maven的路径以及配置文件的路径,基本设置一次即可,不需要每次新建工…

tcp状态机-三次握手-四次挥手以及常见面试题

TCP状态机介绍 在网络协议栈中,目前只有TCP提供了一种面向连接的可靠性数据传输。而可靠性,无非就是保证,我发给你的,你一定要收到。确保中间的通信过程中,不会丢失数据和乱序。在TCP保证可靠性数据传输的实现来看&am…

Visual studio Code的C/C++开发环境搭建

文章目录VS CodeC/C环境配置环境准备使用实例基于 VSCode 的远程开发平台环境准备参考VS Code Visual Studio Code(简称VS Code)是一个由微软开发,同时支持Windows 、 Linux和macOS等操作系统且开放源代码的代码编辑器,它支持测试…

Linux网络编程--文件描述符

文件描述符 在Unix和Unix-like操作系统中,文件描述符(file descriptor, FD)是一个文件或者像pipe或者network socket等之类的输入/输出源的唯一标识。 文件描述符通常是一个非负整数,负数通常代表无值或者错误。 文件描述符是POSIX API的一部分。每个除…

深信服 linux软件开发面试题整理

1、结构体可以进行比较 int memcmp ( const void * ptr1, const void * ptr2, size_t num ); Compare two blocks of memory Compares the first num bytes of the block of memory pointed by ptr1 to the first num bytes pointed by ptr2, returning zero if they all match…

大端小端模式判断以及数据转换

简介 在计算机系统中,我们是以字节为单位的,每个地址单元都对应着一个字节,一个字节为 8bit。但是在C语言中除了8bit的char之外,还有16bit的short型,32bit的long型(要看具体的编译器)&#xff…

MSYS2下搭建Qt开发环境

最近随意浏览了一下俺们大省会城市的招聘信息,发现C招聘中涉及Qt经验的要求有不少,为了牛奶和面包,决心深入一下Qt开发。本篇文章由此而出。 Qt 关于Qt的人生经历在这不在累赘,资料随处可得,这里只记录干货。 环境搭…

CentOS7开发环境搭建(1)

文章目录BIOS开启VT支持U盘安装系统(2019-03-11)CentOS DNS配置CentOS网络配置配置静态IP克隆虚拟机网卡名称变更 CentOS6.5时间配置安装VMWare-tools用户管理 (2019-03-15 7.6.1810)给一般账号 root 权限Samba服务配置安装必备软件获取本机公网ipyum源和第三方库源管理配置本地…

ACM 欧拉公式

给出一个数X,求小于X的与X互质的数的个数,使用欧拉公式。 如果x1*x2*...*xnX,则个数nX*(1-1/x1)*(1-/x2)*... 使用这个的题目,超典型 相遇周期(HDOJ)