移动端深度学习部署:TFlite

1.TFlite介绍

1TFlite概念

  • tflite是谷歌自己的一个轻量级推理库。主要用于移动端。
  • tflite使用的思路主要是从预训练的模型转换为tflite模型文件,拿到移动端部署。
  • tflite的源模型可以来自tensorflowsaved model或者frozen model,也可以来自keras

2TFlite优点

用Flatbuffer序列化模型文件,这种格式磁盘占用少,加载快

可以对模型进行量化,把float参数量化为uint8类型,模型文件更小、计算更快。

可以对模型进行剪枝、结构合并和蒸馏。

对NNAPI的支持。可调用安卓底层的接口,把异构的计算能力利用起来。

3TFlite量化

a.量化的好处        

  • 较小的存储大小:小模型在用户设备上占用的存储空间更少。例如,一个使用小模型的 Android 应用在用户的移动设备上会占用更少的存储空间。
  • 较小的下载大小:小模型下载到用户设备所需的时间和带宽较少。
  • 更少的内存用量:小模型在运行时使用的内存更少,从而释放内存供应用的其他部分使用,并可以转化为更好的性能和稳定性。

b.量化的过程

        tflite的量化并不是全程使用uint8计算。而是存储每层的最大和最小值,然后把这个区间线性分成 256 个离散值,于是此范围内的每个浮点数可以用八位 (二进制) 整数来表示,近似为离得最近的那个离散值。比如,最小值是 -3 而最大值是 6 的情形,0 字节表示 -3,255 表示 6,而 128 是 1.5。每个操作都先用整形计算,输出时重新转换为浮点型。下图是量化Relu的示意图。

Tensorflow官方量化文档

 

 

c.量化的实现

训练后动态量化

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
#converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model1 = converter.convert()
open("xxx.tflite", "wb").write(tflite_model1)

训练后float16量化

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
tflite_model1 = converter.convert()
open("xxx.tflite", "wb").write(tflite_model1)

训练后int8量化

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    # Get sample input data as a numpy array in a method of your choosing.
    yield [input]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_model1 = converter.convert()
open("xxx.tflite", "wb").write(tflite_model1)

注:float32和float16量化可以运行再GPU上,int8量化只能运行再CPU上

2.TFlite模型转换

1)在训练的时候就保存tflite模型

import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
  tflite_model = tf.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

2)使用其他格式的TensorFlow模型转换成tflite模型

        首先要安装Bazel,参考:https://docs.bazel.build/versions/master/install-ubuntu.html ,只需要完成Installing using binary installer这一部分即可。然后克隆TensorFlow的源码:

git clone https://github.com/tensorflow/tensorflow.git

        接着编译转换工具,这个编译时间可能比较长:

cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco

        获得到转换工具之后,开始转换模型,以下操作是冻结图:

  • input_graph对应的是.pb文件;
  • input_checkpoint对应mobilenet_v1_1.0_224.ckpt.data-00000-of-00001,使用时去掉后缀名。
  • output_node_names这个可以在mobilenet_v1_1.0_224_info.txt中获取。

./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
  --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
  --input_binary=true \
  --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1

将冻结的图转换成tflite模型:

  • input_file是已经冻结的图;
  • output_file是转换后输出的路径;
  • output_arrays这个可以在mobilenet_v1_1.0_224_info.txt中获取;
  • input_shapes这个是预测数据的shape

./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
  --inference_type=FLOAT \
  --input_type=FLOAT \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3

3)使用检查点进行模型转换

  • 将tensorflow模型保存成.pb文件

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
 
if __name__ == "__main__":
    a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
    b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
    c = a + b
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    #
导出当前计算图的GraphDef部分
    graph_def = tf.get_default_graph().as_graph_def()
    #保存指定的节点,并将节点值保存为常数
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    #将计算图写入到模型文件中
    model_f = tf.gfile.GFile("model.pb","wb")
    model_f.write(output_graph_def.SerializeToString())

  • 模型文件的读取

Bash
 sess = tf.Session()
    #
将保存的模型文件解析为GraphDef
    model_f = gfile.FastGFile("model.pb",'rb')
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(model_f.read())
    c = tf.import_graph_def(graph_def,return_elements=["add:0"])
    print(sess.run(c))
    #[array([ 11.], dtype=float32)]

  • pb文件转tflite

Python
import tensorflow as tf
 
in_path=r'D:\tmp_mobilenet_v1_100_224_classification_3\output_graph.pb'
out_path=r'D:\tmp_mobilenet_v1_100_224_classification_3\output_graph.tflite'
 
input_tensor_name=['Placeholder']
input_tensor_shape={'Placeholder':[1,224,224,3]}
 
class_tensor_name=['final_result']
 
convertr=tf.lite.TFLiteConverter.from_frozen_graph(in_path,input_arrays=input_tensor_name
                                                   ,output_arrays=class_tensor_name
                                                   ,input_shapes=input_tensor_shape)
 
# convertr=tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=in_path,input_arrays=[input_tensor_name],output_arrays=[class_tensor_name])
tflite_model=convertr.convert()
 
with open(out_path,'wb') as f:
    f.write(tflite_model)

 

3.Android端调用TFlite模型文件

1Android studio中调用TFlite模型实现推理的流程

  • 定义一个interpreter
  • 初始化interpreter(加载tflite模型)
  • 在Android中加载图片到buffer中
  • 用解释器执行图形(推理)
  • 将推理的结果在app中进行显示

2)在Android Studio中导入TFLite模型步骤

  • 新建或打开现有Android项目工程。
  • 通过菜单项 File > New > Other > TensorFlow Lite Model 打开TFLite模型导入对话框。
  • 选择后缀名为.tflite的模型文件。模型文件可以从网上下载或自行训练。
  • 导入的.tflite模型文件位于工程的 ml/ 文件夹下面。

模型主要包括如下三种信息:

  • 模型:包括模型名称、描述、版本、作者等等。
  • 张量:输入和输出张量。比如图片需要预先处理成合适的尺寸,才能进行推理。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1845.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MotionBert论文解读及详细复现教程

MotionBert:统一视角学习人体运动表示 通过学习人体运动表征,论文原作者提出了处理以人为中心的视频任务的统一方法。使用双流时空transformer(DSTformer)网络实现运动编码器,能够全面、自适应地捕获骨骼关节之间的远…

在php中安装php_xlswriter扩展报错,找不到php_xlswriter.dll

前言:这里已经把下载的php_xlswriter.dll扩展放到了php安装目录的ext目录下,运行php -m还是报错找不到该扩展 原因:下载的扩展是nts的,而安装的php是ts的。查看当前php是nts还是ts: 在PHP中,可以利用phpin…

在线乞讨系统 Docker一键部署

begger乞讨网 在线乞讨 全球要饭系统前端界面后端界面H2 数据库 console运行命令访问信息支付平台 在线乞讨 全球要饭系统 在线乞讨全球要饭项目,支持docker一键部署,支持企业微信通知,支持文案编辑 前端界面 后端界面 H2 数据库 console 运行命令 项…

TCP/IP网络编程 第十六章:关于IO流分离的其他内容

分离I/O流 两次I/O流分离 我们之前通过2种方法分离过IO流,第一种是第十章的“TCPI/O过程(Routine)分离”。这种方法通过调用fork函数复制出1个文件描述符,以区分输入和输出中使用的文件描述符。虽然文件描述符本身不会根据输入和输…

Spring框架概述及核心设计思想

文章目录 一. Spring框架概述1. 什么是Spring框架2. 为什么要学习框架?3. Spring框架学习的难点 二. Spring核心设计思想1. 容器是什么?2. IoC是什么?3. Spring是IoC容器4. DI(依赖注入)5. DL(依赖查找&…

数据结构_进阶(1):搜索二叉树

1.内容 建议再看这节之前能对C有一定了解 二叉树在前面C的数据结构阶段时有出过,现在我们对二叉树来学习一些更复杂的类型,也为之后C学习的 map 和 set 做铺垫 1. map和set特性需要先铺垫二叉搜索树,而二叉搜索树也是一种树形结构2. 二叉搜…

Stable Diffusion生成图片参数查看与抹除

前几天分享了几张Stable Diffusion生成的艺术二维码,有同学反映不知道怎么查看图片的参数信息,还有的同学问怎么保护自己的图片生成参数不会泄露,这篇文章就来专门分享如何查看和抹除图片的参数。 查看图片的生成参数 1、打开Stable Diffus…

Ubuntu 安装 Docker

本文目录 1. 卸载旧版本 Docker2. 更新及安装工具软件2.1 更新软件包列表2.2 安装几个工具软件2.3 增加一个 docker 的官方 GPG key2.4 下载仓库文件 3. 安装 Docker3.1 再次更新系统3.2 安装 docker-ce 软件 4. 查看是否启动 Docker5. 验证是否安装成功 1. 卸载旧版本 Docker …

【iOS】—— 属性关键字及weak关键字底层原理

文章目录 先来看看常用的属性关键字有哪些:内存管理有关的的关键字:(weak,assign,strong,retain,copy)关键字weak关键字assignweak 和 assign 的区别:关键字strong&#…

React(3)

1.案例选项卡 import React, { Component } from reactexport default class App extends Component {state{tabList:[{id:1,text:"电影"},{id:2,text:"影院"},{id:3,text:"我的"}]}render() {return (<div><ul>{this.state.tabList…

【LocalSend】开源跨平台的局域网文件传输工具,支持IOS、Android、Mac、Windows、Linux

工作前提条件&#xff1a;设备使用相同的局域网。 LocalSend is a cross-platform app that enables secure communication between devices using a REST API and HTTPS encryption. Unlike other messaging apps that rely on external servers, LocalSend doesn’t require …

【经济调度】基于多目标宇宙优化算法优化人工神经网络环境经济调度研究(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f308;4 Matlab代码及数据 &#x1f4a5;1 概述 基于多目标宇宙优化算法&#xff08;Multi-Objective Universe Optimization Algorithm, MOUA&#xff09;优化人工神经网络环境经济调度是一…

预付费电表收费系统

预付费电表收费系统是一种先进的电表管理系统&#xff0c;它能够帮助电力公司更加高效地管理电表收费&#xff0c;提高用电效率&#xff0c;降低能源浪费。本文将从以下几个方面介绍预付费电表收费系统的特点和优势。 一、预付费电表收费系统的原理 预付费电表收费系统是指用户…

Hadoop集群启动常见错误

错误一 &#xff1a; 配置文件错误 解决方案&#xff1a;检查配置文件&#xff0c;修改错误。重新分发&#xff08;同步&#xff09; 常见错误二 &#xff1a; 重复格式化 DataNode NameNode 在格式化时如果发现下面的提示说明重复格式化了 datanode和namenode的集群id…

Spring Cloud 远程接口调用OpenFeign负载均衡实现原理详解

环境&#xff1a;Spring Cloud 2021.0.7 Spring Boot 2.7.12 配置依赖 maven依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-openfeign</artifactId> </dependency> <dependency&…

一百三十、海豚调度器——用DolphinScheduler定时调度HiveSQL任务

一、目标 用海豚调度器对Hive数仓各层数据库的SQL任务进行定时调度。比如&#xff0c;DWD层脱敏清洗表的动态插入数据、DWS层指标表的动态插入数据 二、工具版本 1、海豚调度器&#xff1a;apache-dolphinscheduler-2.0.5-bin.tar.gz 2、Hive&#xff1a;apache-hive-3.1.2…

长短期记忆网络(LSTM)原理解析

长短期记忆网络&#xff08;Long Short-Term Memory&#xff0c;简称LSTM&#xff09;是一种常用于处理序列数据的深度学习模型。它在循环神经网络&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;的基础上进行了改进&#xff0c;旨在解决传统RNN中的梯度消失…

PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录 pytorch 神经网络训练demoResult参考来源 pytorch 神经网络训练demo 数据集&#xff1a;MNIST 该数据集的内容是手写数字识别&#xff0c;其分为两部分&#xff0c;分别含有60000张训练图片和10000张测试图片 图片来源&#xff1a;https://tensornews.cn/mnist_intr…

Kafka消息监控管理工具Offset Explorer的使用教程

1、kafka监控管理工具 Offset Explorer是一款用于监控和管理Apache Kafka集群中消费者组偏移量的开源工具。它提供了一个简单直观的用户界面&#xff0c;用于查看和管理Kafka消费者组偏移量的详细信息。 Offset Explorer具有以下主要功能和特点&#xff1a; 实时监控&#x…

架构训练营学习笔记3-5:消息队列备选架构设计实战

本文属于架构训练营学习笔记系列&#xff1a;模块3的案例讲解 总的来说&#xff0c;这篇从更高的维度去讲&#xff0c;而不是关注消息队列的常见问题&#xff1a;比如消息如何发送&#xff0c;消息如何不丢失 &#xff0c;消息如何不重复。总体上分为2部分&#xff1a;利益干系…