深度学习Dropout实现

深度学习中的 Dropout 技术在代码层面上的实现通常非常直接。其核心思想是在训练过程中,对于网络中的每个神经元(或者更精确地说,是每个神经元的输出),以一定的概率 p 随机将其输出置为 0。在反向传播时,这些被“drop out”的神经元也不会参与梯度更新。

以下是 Dropout 在代码层面上的一个基本实现逻辑,以 Python 和 NumPy 为例进行说明,然后再展示在常见的深度学习框架(如 TensorFlow 和 PyTorch)中的实现方式。

1. NumPy 实现(概念演示)

假设我们有一个神经网络的某一层输出 activation,它是一个形状为 (batch_size, num_neurons) 的 NumPy 数组。我们可以通过以下步骤实现 Dropout:

Python

import numpy as npdef dropout_numpy(activation, keep_prob):"""使用 NumPy 实现 Dropout。Args:activation: 神经网络层的激活输出 (NumPy array).keep_prob: 保留神经元的概率 (float, 0 到 1 之间).Returns:经过 Dropout 处理的激活输出 (NumPy array).mask: 用于记录哪些神经元被 drop out 的掩码 (NumPy array)."""if keep_prob < 0. or keep_prob > 1.:raise ValueError("keep_prob must be between 0 and 1")# 生成一个和 activation 形状相同的随机掩码,元素值为 True 或 Falsemask = (np.random.rand(*activation.shape) < keep_prob)# 将掩码应用于激活输出,被 drop out 的神经元输出置为 0output = activation * mask# 在训练阶段,为了保证下一层的期望输入不变,需要对保留下来的神经元输出进行缩放output /= keep_probreturn output, mask# 示例
batch_size = 64
num_neurons = 128
activation = np.random.randn(batch_size, num_neurons)
keep_prob = 0.8dropout_output, dropout_mask = dropout_numpy(activation, keep_prob)print("原始激活输出的形状:", activation.shape)
print("Dropout 后的激活输出的形状:", dropout_output.shape)
print("Dropout 掩码的形状:", dropout_mask.shape)
print("被 drop out 的神经元比例:", np.sum(dropout_mask == False) / dropout_mask.size)

代码解释:

  • keep_prob: 这是保留神经元的概率。Dropout 的概率通常设置为 1 - keep_prob
  • 生成掩码 (mask): 我们使用 np.random.rand() 生成一个和输入 activation 形状相同的随机数数组,其元素值在 0 到 1 之间。然后,我们将这个数组与 keep_prob 进行比较,得到一个布尔类型的掩码。True 表示对应的神经元被保留,False 表示被 drop out。
  • 应用掩码 (output = activation * mask): 我们将掩码和原始的激活输出进行逐元素相乘。由于布尔类型的 TrueFalse 在数值运算中会被转换为 1 和 0,所以掩码中为 False 的位置对应的激活输出会被置为 0。
  • 缩放 (output /= keep_prob): 这是一个非常重要的步骤。在训练阶段,由于一部分神经元被随机置为 0,为了保证下一层神经元接收到的期望输入与没有 Dropout 时大致相同,我们需要对保留下来的神经元的输出进行放大。放大的倍数是 1 / keep_prob

需要注意的是,在模型的评估(或推理)阶段,通常不会使用 Dropout。这意味着 keep_prob 会被设置为 1,或者 Dropout 层会被禁用。这是因为 Dropout 是一种在训练时使用的正则化技术,用于减少过拟合。在评估时,我们希望模型的所有神经元都参与计算,以获得最准确的预测。

2. TensorFlow 实现

在 TensorFlow 中,Dropout 是一个内置的层:

Python

import tensorflow as tf# 在 Sequential 模型中添加 Dropout 层
model = tf.keras.models.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dropout(0.2), # Dropout 概率为 0.2 (即 keep_prob 为 0.8)tf.keras.layers.Dense(10, activation='softmax')
])# 或者在函数式 API 中使用
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model_functional = tf.keras.Model(inputs=inputs, outputs=outputs)# 训练模型
# model.compile(...)
# model.fit(...)

在 TensorFlow 的 tf.keras.layers.Dropout(rate) 层中,rate 参数指定的是神经元被 drop out 的概率。在训练时,这个层会随机将一部分神经元的输出置为 0,并对剩下的神经元进行缩放。在推理时,这个层不会有任何作用。TensorFlow 内部会自动处理训练和推理阶段的行为。

3. PyTorch 实现

在 PyTorch 中,Dropout 也是一个内置的模块:

代码段

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.dropout = nn.Dropout(p=0.2) # Dropout 概率为 0.2self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x)x = self.softmax(x)return xmodel = Net()# 设置模型为训练模式 (启用 Dropout)
model.train()# 设置模型为评估模式 (禁用 Dropout)
model.eval()# 在前向传播中使用 Dropout
# output = model(input_tensor)

在 PyTorch 的 nn.Dropout(p) 模块中,p 参数指定的是神经元被 drop out 的概率。与 TensorFlow 类似,PyTorch 的 Dropout 在训练模式 (model.train()) 下会启用,随机将神经元置零并缩放输出。在评估模式 (model.eval()) 下,Dropout 层会失效,相当于一个恒等变换。

总结

在代码层面上,Dropout 的实现主要涉及以下几个步骤:

  1. 生成一个随机的二值掩码,其形状与神经元的输出相同,掩码中每个元素以一定的概率(Dropout 概率)为 0,以另一概率(保留概率)为 1。
  2. 将这个掩码与神经元的输出逐元素相乘,从而将一部分神经元的输出置为 0。
  3. 在训练阶段,对保留下来的神经元的输出进行缩放,通常除以保留概率。
  4. 在评估阶段,禁用 Dropout,即不进行掩码操作和缩放。

现代深度学习框架已经将 Dropout 的实现封装在专门的层或模块中,用户只需要指定 Dropout 的概率即可,框架会自动处理训练和评估阶段的不同行为。这大大简化了在模型中应用 Dropout 的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/905949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AtCoder AT_abc406_c [ABC406C] ~

前言 除了 A 题&#xff0c;唯一一道一遍过的题。 题目大意 我们定义满足以下所有条件的一个长度为 N N N 的序列 A ( A 1 , A 2 , … , A N ) A(A_1,A_2,\dots,A_N) A(A1​,A2​,…,AN​) 为波浪序列&#xff1a; N ≥ 4 N\ge4 N≥4&#xff08;其实满足后面就必须满足这…

Java Web 应用安全响应头配置全解析:从单体到微服务网关的实践

背景&#xff1a;为什么安全响应头至关重要&#xff1f; 在 Web 安全领域&#xff0c;响应头&#xff08;Response Headers&#xff09;是防御 XSS、点击劫持、跨域数据泄露等攻击的第一道防线。通过合理配置响应头&#xff0c;可强制浏览器遵循安全策略&#xff0c;限制恶意行…

如何停止终端呢?ctrl+c不管用,其他有什么方法呢?

如果你在终端中运行了一个程序&#xff08;比如 Python GUI tkinter 应用&#xff09;&#xff0c;按下 Ctrl C 没有作用&#xff0c;一般是因为该程序&#xff1a; 运行了主事件循环&#xff08;例如 tkinter.mainloop()&#xff09; 或 在子线程中运行&#xff0c;而 Ctrl …

深入解析 React 的 useEffect:从入门到实战

文章目录 前言一、为什么需要 useEffect&#xff1f;核心作用&#xff1a; 二、useEffect 的基础用法1. 基本语法2. 依赖项数组的作用 三、依赖项数组演示1. 空数组 []&#xff1a;2.无依赖项&#xff08;空&#xff09;3.有依赖项 四、清理副作用函数实战案例演示1. 清除定时器…

Ubuntu 更改 Nginx 版本

将 1.25 降为 1.18 先卸载干净 # 1. 完全卸载当前Nginx sudo apt purge nginx nginx-common nginx-core# 2. 清理残留配置 sudo apt autoremove sudo rm -rf /etc/apt/sources.list.d/nginx*.list修改仓库地址 # 添加仓库&#xff08;通用稳定版仓库&#xff09; codename$(…

如何在 Windows 10 或 11 中安装 PowerShellGet 模块?

PowerShell 是微软在其 Windows 操作系统上提供的强大脚本语言,可用于通过命令行界面自动化各种任务,适用于 Windows 桌面或服务器环境。而 PowerShellGet 是 PowerShell 中的一个模块,提供了用于从各种来源发现、安装、更新和发布模块的 cmdlet。 本文将介绍如何在 PowerS…

NBA足球赛事直播源码体育直播M33模板赛事源码

源码名称&#xff1a;体育直播赛事扁平自适应M33直播模板源码 开发环境&#xff1a;帝国cms7.5 空间支持&#xff1a;phpmysql 带软件采集&#xff0c;可以挂着自动采集发布&#xff0c;无需人工操作&#xff01; 演示地址&#xff1a;NBA足球赛事直播源码体育直播M33模板赛事…

【Python】魔法方法是真的魔法! (第二期)

还不清楚魔术方法&#xff1f; 可以看看本系列开篇&#xff1a;【Python】小子&#xff01;是魔术方法&#xff01;-CSDN博客 【Python】魔法方法是真的魔法&#xff01; &#xff08;第一期&#xff09;-CSDN博客 在 Python 中&#xff0c;如何自定义数据结构的比较逻辑&…

Qt 强大的窗口停靠浮动

1、左边&#xff1a; 示例代码&#xff1a; CDockManager::setConfigFlags(CDockManager::DefaultOpaqueConfig); CDockManager::setConfigFlag(CDockManager::FocusHighlighting, true); dockManager new CDockManager(this); // Disabling the Internal Style S…

Linux进程异常退出排查指南

在 Linux 中&#xff0c;如果进程无法正常终止&#xff08;如 kill 命令无效&#xff09;或异常退出&#xff0c;可以按照以下步骤排查和解决&#xff1a; 1. 常规终止进程 尝试普通终止&#xff08;SIGTERM&#xff09; kill PID # 发送 SIGTERM 信号&#xff08;…

使用tensorRT10部署低光照补偿模型

1.低光照补偿模型的简单介绍 作者介绍一种Zero-Reference Deep Curve Estimation (Zero-DCE)的方法用于在没有参考图像的情况下增强低光照图像的效果。 具体来说&#xff0c;它将低光照图像增强问题转化为通过深度网络进行图像特定曲线估计的任务。训练了一个轻量级的深度网络…

SLAM定位常用地图对比示例

序号 地图类型 概述 1 格栅地图 将现实环境栅格化,每一个栅格用 0 和 1 分别表示空闲和占据状态,初始化为未知状态 0.5 2 特征地图 以点、线、面等几何特征来描绘周围环境,将采集的信息进行筛选和提取得到关键几何特征 3 拓扑地图 将重要部分抽象为地图,使用简单的图形表示…

【图像生成1】Latent Diffusion Models 论文学习笔记

一、背景 本文主要记录一下使用 LDMs 之前&#xff0c;学习 LDMs 的过程。 二、论文解读 Paper&#xff1a;[2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models 1. 总体描述 LDMs 将传统 DMs 在高维图像像素空间&#xff08;Pixel Space&#x…

通信安全堡垒:profinet转ethernet ip主网关提升冶炼安全与连接

作为钢铁冶炼生产线的安全检查员&#xff0c;我在此提交关于使用profinet转ethernetip网关前后对生产线连接及安全影响的检查报告。 使用profinet转ethernetip网关前的情况&#xff1a; 在未使用profinet转ethernetip网关之前&#xff0c;我们的EtherNet/IP测温仪和流量计与PR…

TIFS2024 | CRFA | 基于关键区域特征攻击提升对抗样本迁移性

Improving Transferability of Adversarial Samples via Critical Region-Oriented Feature-Level Attack 摘要-Abstract引言-Introduction相关工作-Related Work提出的方法-Proposed Method问题分析-Problem Analysis扰动注意力感知加权-Perturbation Attention-Aware Weighti…

day 20 奇异值SVD分解

一、什么是奇异值 二、核心思想&#xff1a; 三、奇异值的主要应用 1、降维&#xff1a; 2、数据压缩&#xff1a; 原理&#xff1a;图像可以表示为一个矩阵&#xff0c;矩阵的元素对应图像的像素值。对这个图像矩阵进行 SVD 分解后&#xff0c;小的奇异值对图像的主要结构贡…

符合Python风格的对象(对象表示形式)

对象表示形式 每门面向对象的语言至少都有一种获取对象的字符串表示形式的标准方 式。Python 提供了两种方式。 repr()   以便于开发者理解的方式返回对象的字符串表示形式。str()   以便于用户理解的方式返回对象的字符串表示形式。 正如你所知&#xff0c;我们要实现_…

springboot配置tomcat端口的方法

在Spring Boot中配置Tomcat端口可通过以下方法实现&#xff1a; 配置文件方式 properties格式 在application.properties中添加&#xff1a;server.port8081YAML格式 在application.yml中添加&#xff1a;server:port: 8082多环境配置 创建不同环境的配置文件&#xff08;如app…

DeepSeek指令微调与强化学习对齐:从SFT到RLHF

后训练微调的重要性 预训练使大模型获得丰富的语言和知识表达能力,但其输出往往与用户意图和安全性需求不完全匹配。业内普遍采用三阶段训练流程:预训练 → 监督微调(SFT)→ 人类偏好对齐(RLHF)。预训练阶段模型在大规模语料上学习语言规律;监督微调利用人工标注的数据…

Maven 插件扩展点与自定义生命周期

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;精通Java编…