基于 PyTorch 的模型瘦身三部曲:量化、剪枝和蒸馏,让模型更短小精悍!

基于 PyTorch 的模型量化、剪枝和蒸馏

    • 1. 模型量化
      • 1.1 原理介绍
      • 1.2 PyTorch 实现
    • 2. 模型剪枝
      • 2.1 原理介绍
      • 2.2 PyTorch 实现
    • 3. 模型蒸馏
      • 3.1 原理介绍
      • 3.2 PyTorch 实现
    • 参考文献

在这里插入图片描述

1. 模型量化

1.1 原理介绍

模型量化是将模型参数从高精度(通常是 float32)转换为低精度(如 int8 或更低)的过程。这种技术可以显著减少模型大小、降低计算复杂度,并加快推理速度,同时尽可能保持模型的性能。
在这里插入图片描述
量化的主要方法包括:

  1. 动态量化

    • 在推理时动态地将权重从 float32 量化为 int8。
    • 激活值在计算过程中保持为浮点数。
    • 适用于 RNN 和变换器等模型。
  2. 静态量化

    • 在推理之前,预先将权重从 float32 量化为 int8。
    • 在推理过程中,激活值也被量化。
    • 需要校准数据来确定激活值的量化参数。
  3. 量化感知训练(QAT)

    • 在训练过程中模拟量化操作。
    • 允许模型适应量化带来的精度损失。
    • 通常能够获得比后量化更高的精度。

1.2 PyTorch 实现

import torch# 1. 动态量化
model_fp32 = MyModel()
model_int8 = torch.quantization.quantize_dynamic(model_fp32,  # 原始模型{torch.nn.Linear, torch.nn.LSTM},  # 要量化的层类型dtype=torch.qint8  # 量化后的数据类型
)# 2. 静态量化
model_fp32 = MyModel()
model_fp32.eval()  # 设置为评估模式# 设置量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model_fp32)# 使用校准数据进行校准
with torch.no_grad():for batch in calibration_data:model_fp32_prepared(batch)# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)# 3. 量化感知训练
model_fp32 = MyModel()
model_fp32.train()  # 设置为训练模式# 设置量化感知训练配置
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32)# 训练循环
for epoch in range(num_epochs):for batch in train_data:output = model_fp32_prepared(batch)loss = criterion(output, target)loss.backward()optimizer.step()# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)

2. 模型剪枝

2.1 原理介绍

模型剪枝是一种通过移除模型中不重要的权重或神经元来减少模型复杂度的技术。剪枝可以减少模型大小、降低计算复杂度,并可能改善模型的泛化能力。
在这里插入图片描述

主要的剪枝方法包括:

  1. 权重剪枝

    • 移除绝对值小于某个阈值的单个权重。
    • 可以大幅减少模型参数数量,但可能导致非结构化稀疏性。
  2. 结构化剪枝

    • 移除整个卷积核、神经元或通道。
    • 产生更加规则的稀疏结构,有利于硬件加速。
  3. 重要性剪枝

    • 基于权重或激活值的重要性评分来决定剪枝对象。
    • 常用的重要性度量包括权重幅度、激活值、梯度等。

2.2 PyTorch 实现

import torch
import torch.nn.utils.prune as prunemodel = MyModel()# 1. 权重剪枝
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)# 2. 结构化剪枝
prune.ln_structured(model.conv1, name='weight', amount=0.5, n=2, dim=0)# 3. 全局剪枝
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),
)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)# 4. 移除剪枝
for module in model.modules():if isinstance(module, torch.nn.Conv2d):prune.remove(module, 'weight')

3. 模型蒸馏

3.1 原理介绍

模型蒸馏是一种将复杂模型(教师模型)的知识转移到简单模型(学生模型)的技术。这种方法可以在保持性能的同时,大幅减少模型的复杂度和计算需求。
在这里插入图片描述

主要的蒸馏方法包括:

  1. 响应蒸馏

    • 学生模型学习教师模型的最终输出(软标签)。
    • 软标签包含了教师模型对不同类别的置信度信息。
  2. 特征蒸馏

    • 学生模型学习教师模型的中间层特征。
    • 可以传递更丰富的知识,但需要设计合适的映射函数。
  3. 关系蒸馏

    • 学习样本之间的关系,如相似度或排序。
    • 有助于保持教师模型学到的数据结构。

3.2 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DistillationLoss(nn.Module):def __init__(self, alpha=0.5, temperature=2.0):super().__init__()self.alpha = alphaself.T = temperaturedef forward(self, student_outputs, teacher_outputs, labels):# 硬标签损失hard_loss = F.cross_entropy(student_outputs, labels)# 软标签损失soft_loss = F.kl_div(F.log_softmax(student_outputs / self.T, dim=1),F.softmax(teacher_outputs / self.T, dim=1),reduction='batchmean') * (self.T * self.T)# 总损失loss = (1 - self.alpha) * hard_loss + self.alpha * soft_lossreturn loss# 训练循环
teacher_model = TeacherModel().eval()
student_model = StudentModel().train()
distillation_loss = DistillationLoss(alpha=0.5, temperature=2.0)for epoch in range(num_epochs):for batch, labels in train_loader:optimizer.zero_grad()with torch.no_grad():teacher_outputs = teacher_model(batch)student_outputs = student_model(batch)loss = distillation_loss(student_outputs, teacher_outputs, labels)loss.backward()optimizer.step()

通过这些技术的组合使用,可以显著减小模型大小、提高推理速度,同时尽可能保持模型性能。在实际应用中,可能需要根据具体任务和硬件限制来选择和调整这些方法。

参考文献

[1]Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H., & Kalenichenko, D. (2018). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 2704-2713).[2]Krishnamoorthi, R. (2018). Quantizing deep convolutional networks for efficient inference: A whitepaper. arXiv preprint arXiv:1806.08342.[3]Han, S., Pool, J., Tran, J., & Dally, W. (2015). Learning both Weights and Connections for Efficient Neural Network. In Advances in Neural Information Processing Systems (NeurIPS) (pp. 1135-1143).[4]Li, H., Kadav, A., Durdanovic, I., Samet, H., & Graf, H. P. (2016). Pruning Filters for Efficient ConvNets. arXiv preprint arXiv:1608.08710.[5]Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.[6]Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2014). FitNets: Hints for Thin Deep Nets. arXiv preprint arXiv:1412.6550.

创作不易,烦请各位观众老爷给个三连,小编在这里跪谢了!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/874459.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年自动驾驶规划控制面试及答案

自动驾驶行业随着发展越来越卷,在面试前有更多的准备对求职者来说是很有必要的。在即将秋招来临之前给大家整理出了一些新的 自动驾驶规划控制面试题,希望在大家找工作的过程中提供帮助。 持续更新中…… 刷更多的面试题有助于求职者展示自己与公司需求的…

Elasticsearch:Retrievers 介绍 - Python Jupyter notebook

在今天的文章里,我是继上一篇文章 “Elasticsearch:介绍 retrievers - 搜索一切事物” 来使用一个可以在本地设置的 Elasticsearch 集群来展示 Retrievers 的使用。在本篇文章中,你将学到如下的内容: 从 Kaggle 下载 IMDB 数据集…

Linux云计算 |【第一阶段】SERVICES-DAY5

主要内容: 源码编译安装、rsync同步操作、inotify实时同步、数据库服务基础 实操前骤:(所需tools.tar.gz与users.sql) 1.两台主机设置SELinnx和关闭防火墙 setenforce 0 systemctl stop firewalld.service //停止防火墙 sy…

C#类型基础Part2-对象判等

C#类型基础Part2-对象判等 参考资料引用类型判等简单值类型判等复杂值类型判等 参考资料 《.NET之美-.NET关键技术深入解析》 引用类型判等 先定义两个类型,它们代表直线上的一个点,一个是引用类型class,一个是值类型struct public class…

MySQL定时备份数据,并上传到oss

1.环境准备 1.安装阿里云的ossutil 2.安装mysql 2.编写脚本 脚本内容如下 #!/bin/bash # 数据库的配置信息,根据自己的情况进行填写 db_hostlocalhost db_usernameroot db_passwordroot db_namedb_root # oss 存贮数据的bucket地址 bucket_namerbsy-backup-buck…

软件更新的双刃剑:从”微软蓝屏”事件看网络安全的挑战与对策

引言 原文链接 近日,一场由微软视窗系统软件更新引发的全球性"微软蓝屏"事件震惊了整个科技界。这次事件源于美国电脑安全技术公司"众击"提供的一个带有"缺陷"的软件更新,如同一颗隐形炸弹在全球范围内引爆,…

/var/log/nginx/error.log

日志 : *1979 recv() failed (104: Connection reset by peer) while proxying and reading from upstream, client: 127.0.0.1, server: 0.0.0.0:60443, upstream: “10.x.x.x:6443”, bytes from/to client:64534240/731320118, bytes from/to upstream:731320118/64534240 …

Qt 实战(7)元对象系统 | 7.5、QMetaProperty详解

文章目录 一、QMetaProperty详解1、QMetaProperty的作用2、使用QMetaProperty2.1、声明属性2.2、访问属性 3、QMetaProperty成员方法4、示例4.1、通过名称获取指定属性4.2、遍历全部属性(包含从基类继承下来的)4.3、遍历当前类的全部属性(不包…

Python面试宝典第17题:Z字形变换

题目 将一个给定字符串 s 根据给定的行数numRows ,以从上往下、从左到右进行Z字形排列。比如:输入字符串为"PAYPALISHIRING",行数为3时,排列如下。最后,你的输出需要从左往右逐行读取,产生出一个…

unity 实现图片的放大与缩小(根据鼠标位置拉伸放缩)

1创建UnityHelper.cs using UnityEngine.Events; using UnityEngine.EventSystems;public class UnityHelper {/// <summary>/// 简化向EventTrigger组件添加事件的操作。/// </summary>/// <param name"_eventTrigger">要添加事件监听的UI元素上…

DevExpress中文教程 - 如何在.NET MAUI应用中实现Material Design 3?

DevExpress .NET MAUI多平台应用UI组件库提供了用于Android和iOS移动开发的高性能UI组件&#xff0c;该组件库包括数据网格、图表、调度程序、数据编辑器、CollectionView和选项卡组件等。 获取DevExpress v24.1正式版下载 Material Design是一个由Google开发的跨平台指南系统…

HydraRPC: RPC in the CXL Era——论文阅读

ATC 2024 Paper CXL论文阅读笔记整理 问题 远程过程调用&#xff08;RPC&#xff09;是分布式系统中的一项基本技术&#xff0c;它允许函数在远程服务器上通过本地调用执行来促进网络通信&#xff0c;隐藏底层通信过程的复杂性简化了客户端/服务器交互[15]。RPC已成为数据中心…

【Hot100】LeetCode—279. 完全平方数

目录 题目1- 思路2- 实现⭐完全平方数——题解思路 3- ACM 实现 题目 原题连接&#xff1a;279. 完全平方数 1- 思路 思路 动规五部曲 2- 实现 ⭐完全平方数——题解思路 class Solution {public int numSquares(int n) {// 1. 定义 dpint[] dp new int[n1];//2. 递推公式…

Mojo编程语言

Mojo编程语言作为一种新兴的、专为AI开发者设计的编程语言&#xff0c;近年来在AI领域引起了广泛关注&#xff0c;并逐渐成为AI开发者的新宠儿。以下是对Mojo编程语言的详细解析&#xff1a; 设计目的与特点 Mojo编程语言由Modular公司开发&#xff0c;旨在结合Python的易用性…

算法学习5——图算法

图算法在计算机科学中占有重要地位&#xff0c;广泛应用于网络连接、路径查找、社会网络分析等领域。本文将介绍几种常见的图算法&#xff0c;包括Dijkstra算法、Bellman-Ford算法、Floyd-Warshall算法、Kruskal算法和Prim算法&#xff0c;并提供相应的Python代码示例。 图的基…

在 WSL2 中频繁切换 PHP 版本,可以使用更简便的方法

在 WSL2 中频繁切换 PHP 版本&#xff0c;可以使用更简便的方法&#xff0c;例如使用 update-alternatives 工具。这是一种更系统化的方法&#xff0c;允许你更方便地管理和切换不同的 PHP 版本。 以下是使用 update-alternatives 工具切换 PHP 版本的步骤&#xff1a; 添加 P…

论文学习记录之一种具有边缘增强特点的医学图像分割网络

标题&#xff1a;一种具有边缘增强特点的医学图像分割网络 期刊&#xff1a;电子与信息学报-&#xff08;2022年5月出刊&#xff09; 摘要&#xff1a;针对传统医学图像分割网络存在边缘分割不清晰、缺失值大等问题&#xff0c;该文提出一种具有边缘增强特点的医学图像分割网…

社交圈子小程序搭建-源码部署-服务公司

消息通知:当有新的消息、评论或回复时&#xff0c;用户需要收到系统的推送通知&#xff0c;以便及时查看和回复 活动发布与参加:用户可以在社交圈子中发布各种类型的活动&#xff0c;如聚餐、旅游、运动等。其他用户可以参加这些活动&#xff0c;并与组织者进行交流和沟通 社交…

C#初级——输出语句和转义字符

输出语句 在C#中&#xff0c;C#的输出语句是通过Console类进行输出&#xff0c;该类是一个在控制台下的一个标准输入流、输出流和错误流。使用该类下的Write()函数&#xff0c;即可打印要输出的内容。 Console.Write("Hello World!"); //在控制台应用中打印Hell…

通过QT进行服务器和客户端之间的网络通信

客户端 client.pro #------------------------------------------------- # # Project created by QtCreator 2024-07-02T14:11:20 # #-------------------------------------------------QT core gui network #网络通信greaterThan(QT_MAJOR_VERSION, 4): QT widg…