FLOPs、FLOPS、Params的含义及PyTorch中的计算方法

FLOPs、FLOPS、Params的含义及PyTorch中的计算方法

含义解释

  1. FLOPS:注意全大写,是floating point operations per second的缩写(这里的大S表示second秒),表示每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。

  2. FLOPs:注意s小写,是floating point operations的缩写(这里的小s则表示复数),表示浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

  3. Params:没有固定的名称,大小写均可,表示模型的参数量,也是用来衡量算法/模型的复杂度。通常我们在论文中见到的是这样:# Params,那个井号是表示 number of 的意思,因此 # Params 的意思就是:参数的数量。

在这里插入图片描述

FLOPs与模型时间复杂度、GPU利用率有关,Params与模型空间复杂度、显存占用有关。即我们常见的nvidia-smi命令中的GPU利用率(红框)和显存占用(篮框)。

MAC

MAC:Multiply Accumulate,乘加运算。乘积累加运算(英语:Multiply Accumulate, MAC)是在数字信号处理器或一些微处理器中的特殊运算。实现此运算操作的硬件电路单元,被称为“乘数累加器”。这种运算的操作,是将乘法的乘积结果和累加器的值相加,再存入累加器:
a←a+b×ca\leftarrow a+b\times c aa+b×c
使用MAC可以将原本需要的两个指令操作减少到一个指令操作,从而提高运算效率。

FLOPs的计算

以下不考虑激活函数的计算量。

卷积层

(2×Ci×K2−1)×H×W×C0(2\times C_i\times K^2-1)\times H\times W\times C_0(2×Ci×K21)×H×W×C0

CiC_iCi=输入通道数, KKK=卷积核尺寸,H,WH,WH,W=输出特征图空间尺寸,CoC_oCo=输出通道数。

一个MAC算两个个浮点运算,所以在最前面×2\times 2×2。不考虑bias时有−1-11,有bias时没有−1-11。由于考虑的一般是模型推理时的计算量,所以上述公式是针对一个输入样本的情况,即batch size=1。

理解上面这个公式分两步,括号内是第一步,计算出输出特征图的一个pixel的计算量,然后再乘以 H×W×CoH\times W\times C_oH×W×Co 拓展到整个输出特征图。

括号内的部分又可以分为两步,(2⋅Ci⋅K2−1)=(Ci⋅K2)+(Ci⋅K2−1)(2\cdot C_i\cdot K^2-1)=(C_i\cdot K^2)+(C_i\cdot K^2-1)(2CiK21)=(CiK2)+(CiK21)。第一项是乘法运算数,第二项是加法运算数,因为 nnn 个数相加,要加 n−1n-1n1 次,所以不考虑bias,会有一个−1-11,如果考虑bias,刚好中和掉,括号内变为 2⋅Ci⋅K22\cdot C_i\cdot K^22CiK2

全连接层

全连接层: (2×I−1)×O(2\times I-1)\times O(2×I1)×O

III=输入层神经元个数 ,OOO=输出层神经元个数。

还是因为一个MAC算两个个浮点运算,所以在最前面×2\times 2×2。同样不考虑bias时有−1-11,有bias时没有−1-11。分析同理,括号内是一个输出神经元的计算量,拓展到OOO了输出神经元。

NVIDIA Paper [2017-ICLR]

笔者在这里放上 NVIDIA 在 【2017-ICLR】的论文:PRUNING CONVOLUTIONAL NEURAL NETWORKS FOR RESOURCE EFFICIENT INFERENCE 的附录部分FLOPs计算方法截图放在下面供读者参考。
在这里插入图片描述

使用PyTorch直接输出模型的Params(参数量)

完整统计参数量

import torch 
from torchvision.models import resnet50
import numpy as npTotal_params = 0
Trainable_params = 0
NonTrainable_params = 0model = resnet50()
for param in model.parameters():mulValue = np.prod(param.size())  # 使用numpy prod接口计算参数数组所有元素之积Total_params += mulValue  # 总参数量if param.requires_grad:Trainable_params += mulValue  # 可训练参数量else:NonTrainable_params += mulValue  # 非可训练参数量print(f'Total params: {Total_params / 1e6}M')
print(f'Trainable params: {Trainable_params/ 1e6}M')
print(f'Non-trainable params: {NonTrainable_params/ 1e6}M')

输出:

Total params: 25.557032M
Trainable params: 25.557032M
Non-trainable params: 0.0M

简单统计可训练的参数量

通常,我们想知道的只是可训练的参数量,我们也可以简单地直接一行统计出可训练的参数量:

import torchvision.models as modelsmodel = models.resnet50(pretrained=False)Trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable params: {Trainable_params/ 1e6}M')

输出:

Trainable params: 25.557032M

统计每一层的参数量

倘若想要统计每一层的参数量,参考代码如下:

model = vgg16()
for name, parameters in model.named_parameters():print(name, ':', np.prod(parameters.size()))

会打印出每一层的名称及参数量:

features.0.weight : 1728
features.0.bias : 64
features.2.weight : 36864
features.2.bias : 64
features.5.weight : 73728
...

使用thop库来获取模型的FLOPs(计算量)和Params(参数量)

安装

直接pypi安装即可

pip install thop

使用

我们使用thop库来计算vgg16模型的计算量和参数量。

import torch
from thop import profile
from archs.ViT_model import get_vit, ViT_Aes
from torchvision.models import resnet50model = resnet50()
input1 = torch.randn(4, 3, 224, 224) 
flops, params = profile(model, inputs=(input1, ))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')

输出:

FLOPs = 16.446058496G
Params = 25.557032M

Ref:

https://openreview.net/forum?id=SJGCiw5gl

https://www.zhihu.com/question/65305385/answer/451060549

https://www.cnblogs.com/chuqianyu/p/14254702.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532797.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设置中文linux输入ubuntu,Linux_ubuntu怎么设置成中文?ubuntu中文设置图文方法,  很多朋友安装ubuntu后,发 - phpStudy...

ubuntu怎么设置成中文?ubuntu中文设置图文方法很多朋友安装ubuntu后,发现都是英文,看不懂要怎么办?其实ubuntu是可以设置成中文的,下文小编就为大家带来ubuntu中文的设置方法,一起去看下设置方法吧。ubuntu中文设置方…

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么?

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么? 转自:https://zhuanlan.zhihu.com/p/93812784 我们提到圆周率 π 的时候,它有很多种表达方式,既可以用数学常数3.14159表示,也可以用一长串1和0的二进制长串表示。 …

linux 磁盘分配 简书,linux 磁盘分区

1物理磁盘的构成: 盘面:由一圈一圈的磁道组成机械手臂:读取数据主轴马达:帮助机械手臂转动2 扇区:磁盘上存取数据的最小单位512字节按照扇区分配大小,如果数据只有一字节也会占用512字节簇:用若…

条件控制与条件传送详解

条件控制与条件传送详解 提要 CSAPP3e中文译本 3.6.5 用条件控制来实现条件分支 3.6.6 用条件传送来实现条件分支 CSAPP3e第三章前面主要是介绍了机器级代码的二进制形式和汇编形式、反汇编、x86汇编的基础指令、条件码及其访问方式等。 在介绍到汇编语言的条件分支时分了两…

联合体(union)的使用方法及其本质

联合体(union)的使用方法及其本质 转自:https://blog.csdn.net/huqinwei987/article/details/23597091 有些基础知识快淡忘了,所以有必要复习一遍,在不借助课本死知识的前提下做些推理判断,温故知新。 1…

linux设备驱动之串口移植,Linux设备驱动之UART驱动结构

一、对于串口驱动Linux系统中UART驱动属于终端设备驱动,应该说是实现串口驱动和终端驱动来实现串口终端设备的驱动。要了解串口终端的驱动在Linux系统的结构就先要了解终端设备驱动在Linux系统中的结构体系,一方面自己了解的不够,另一发面关于…

linux python复制安装,复制一个Python全部环境到另一个环境,python另一个,导出此环境下安装的包...

复制一个Python全部环境到另一个环境,python另一个,导出此环境下安装的包导出此环境下安装的包的版本信息清单pipfreeze>requirements.txt联网,下载清单中的包到all-packet文件夹[[email protected] ~]# pip download -d ./all-packet -r requirement…

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL 笔者注:NCCL 开源项目地址:https://github.com/NVIDIA/nccl 转自:https://www.zhihu.com/question/63219175/answer/206697974 NCCL是Nvidia Collective multi-GPU Communication Library的简称&…

C语言n个坐标点间的最大距离,c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。...

c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容,让我们赶快一起来看一下吧!c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。#…

[分布式训练] 单机多卡的正确打开方式:理论基础

[分布式训练] 单机多卡的正确打开方式:理论基础 转自:https://fyubang.com/2019/07/08/distributed-training/ 瓦砾由于最近bert-large用的比较多,踩了很多分布式训练的坑,加上在TensorFlow和PyTorch之间更换,算是熟…

s3c2416开发板 linux,S3C2416移植内核Linux3.1的wm9713声卡过程

移植内核的声卡驱动。原因没有声卡驱动,WM9713声卡驱动移植(原来的内核有UDA1341声卡驱动,我们再次基础上直接修改)1、直接复制内核得到三个文件:s3c2416_wm9713.c , wm9713.c , s3c2416_ac97.c.linux-3.1\sound\soc\codecs\Wm9713.c---->wm9713.c;li…

Linux查看文件内容命令:cat, tail, head, more, less

Linux查看文件内容命令:cat, tail, head, more, less cat 直接显示整个文件。 cat直接显示全部文件内容,没有换页等交互。 cat filenamemore more命令,功能类似 cat ,cat命令是整个文件的内容从上到下显示在屏幕上。 more会…

linux查看队列 msg,linux第10天 msg消息队列

cat /proc/sys/kernel/msgmax最大消息长度限制cat /proc/sys/kernel/msgmnb消息队列总的字节数cat /proc/sys/kernel/msgmni消息条目数消息队列综合案例//server#include #include #include #include #include #include #include #include #define ERR_EXIT(m)do{perror(m);}wh…

Linux中 C++ main函数参数argc和argv含义及用法

Linux中 C main函数参数argc和argv含义及用法 简介 argc 是 argument count的缩写,表示传入main函数的参数个数; argv 是 argument vector的缩写,表示传入main函数的参数序列或指针,并且第一个参数argv[0]一定是程序的名称&…

c语言六位抢答器课程设计,51单片机八路抢答器课程设计

;说明:本人的这个设计改进后解决了前一个版本中1号抢答优先的问题,并增加了锦囊的设置,当参赛选手在回答问题时要求使用锦囊,则主持人按下抢答开始键,计时重新开始。;八路抢答器电路请看下图是用ps仿真的,已…

ELF文件详解—初步认识

ELF文件详解—初步认识 转自:https://blog.csdn.net/daide2012/article/details/73065204 一、 引言 在讲解ELF文件格式之前,我们来回顾一下,一个用C语言编写的高级语言程序是从编写到打包、再到编译执行的基本过程,我们知道在C…

埃及分数问题c语言,埃及分数问题(转)

今日,小雨和小明来到网络中心,继续与刘老师讨论“数的认识”问题。刘老师说:“还有一种‘埃及分数’需要认识。这是一类分裂分数的思维题,对思维能力的训练很有价值。”小明说:“有意思,愿洗耳恭听。”刘老…

linux常用命令--开发调试篇

前言 Linux常用命令中有一些命令可以在开发或调试过程中起到很好的帮助作用,有些可以帮助了解或优化我们的程序,有些可以帮我们定位疑难问题。本文将简单介绍一下这些命令。 转自:https://www.yanbinghu.com/2018/09/26/61877.html 示例程序…

简单有趣的c语言小程序,一个有趣的小程序

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼源码:#include #include #include #include #include HINSTANCE g_hInstance 0;LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);int WINAPI WinMain(HINSTANCE hInstance,HINSTANCE hPreInstance,LPSTR lpCmdLine,int nSh…

linux下ora 01110,ORA-01003ORA-01110

Oracle 9i数据库登录时,提示ORA-01003&ORA-01110,大概意思是数据文件存储介质损坏。startup nomount,正常;alter database mount,也正常;alter database open,提示如下:alter database open*ERROR 位于第 1 行:ORA…