实用指南:深度学习:从零开始手搓一个浅层神经网络(Single Hidden Layer Neural Network)

news/2025/12/9 18:12:40/文章来源:https://www.cnblogs.com/ljbguanli/p/19327803

实用指南:深度学习:从零开始手搓一个浅层神经网络(Single Hidden Layer Neural Network)

本文带你一步步用 Python 实现一个最基础的神经网络模型,理解前向传播、反向传播、梯度下降等核心概念,真正“动手”掌握深度学习的本质。


一、引言:为什么我们要自己写神经网络?

在今天,我们有 TensorFlow、PyTorch 等强大的框架可以快速构建复杂的模型。但如果你只是初学者,或者想深入理解神经网络的内部机制,亲手实现一个简单的神经网络是绝佳的学习方式。

本篇博客将带你从头到尾实现一个含有一个隐藏层的神经网络,用于解决二维平面数据分类问题(如非线性可分的数据)。我们将使用 NumPy 和 Matplotlib 手动完成所有计算,不依赖任何高级框架。


二、整体架构概览

我们构建的是如下结构的神经网络:

输入层 (2个特征) → 隐藏层 (n_h个神经元) → 输出层 (1个神经元)↓                  ↓                   ↓x₁, x₂             tanh(Z₁)           tanh(Z₂)
  • 输入:每个样本有两个特征(横纵坐标)
  • 隐藏层:n_h 个神经元,激活函数为 tanh
  • 输出层:1 个神经元,激活函数也为 tanh(也可改为 sigmoid,这里为了简化)

(注:图中红色椭圆表示输入特征,蓝色箭头表示权重连接)


三、环境准备与数据加载

import matplotlib
matplotlib.use('TkAgg')  # 如果有GUI支持
import numpy as np
import matplotlib.pyplot as plt
import sklearn
import sklearn.linear_model
from planar_utils import plot_decision_boundary, sigmoid, load_planar_dataset, load_extra_datasets
from testCases import *
np.random.seed(1)
X, Y = load_planar_dataset()
plt.scatter(X[0,:], X[1,:], c=Y.ravel(), s=40, cmap=plt.cm.Spectral)
plt.title("原始数据分布")
plt.show()

✅ 解释:

  • load_planar_dataset() 是一个辅助函数,生成了带有标签的二维点集(通常是环形或花形数据,非线性可分)。
  • 使用 scatter 可视化数据,可以看到颜色不同的两类点无法用直线分开。
  • 这正是神经网络发挥作用的地方!

️ 四、参数初始化

def initialize_parameters(n_x, n_h, n_y):np.random.seed(2)W1 = np.random.randn(n_h, n_x) * 0.01b1 = np.zeros((n_h, 1))W2 = np.random.randn(n_y, n_h) * 0.01b2 = np.zeros((n_y, 1))parameters = {"W1": W1,"b1": b1,"W2": W2,"b2": b2}return parameters

✅ 说明:

  • n_x: 输入维度(这里是 2)
  • n_h: 隐藏层神经元数量(可调参)
  • n_y: 输出维度(这里是 1)
  • 权重 W 用小随机数初始化(乘以 0.01),防止梯度爆炸
  • 偏置 b 初始为 0

⏭️ 五、前向传播(Forward Propagation)

def forward_propagation(X, parameters):m = X.shape[1]W1 = parameters['W1']b1 = parameters['b1']W2 = parameters['W2']b2 = parameters['b2']Z1 = np.dot(W1, X) + b1A1 = np.tanh(Z1)Z2 = np.dot(W2, A1) + b2A2 = np.tanh(Z2)cache = {"Z1": Z1, "A1": A1, "Z2": Z2, "A2": A2}return A2, cache

✅ 数学公式:

  • A2 是最终输出,代表对每个样本的预测值
  • cache 保存中间变量,供后续反向传播使用

六、损失函数计算

def compute_cost(A2, Y, parameters):m = Y.shape[1]logprobs = np.multiply(np.log(A2), Y) + np.multiply(1 - Y, np.log(1 - A2))cost = -np.sum(logprobs) / mreturn cost

✅ 注意:

  • 这里用了 二元交叉熵损失函数(Binary Cross-Entropy),适用于二分类任务
  • 但注意:由于我们用了 tanh 作为激活函数,其输出范围是 [-1, 1],而标准交叉熵要求 [0,1],所以严格来说应该改用 sigmoid 或调整标签
  • 实际上,在这个练习中,我们更关注流程而非精确性,因此仍可用此形式

七、反向传播(Backward Propagation)

def backward_promagation(parameters, cache, X, Y):m = X.shape[1]W1 = parameters['W1']W2 = parameters['W2']A1 = cache['A1']A2 = cache['A2']dZ2 = A2 - YdW2 = (1 / m) * np.dot(dZ2, A1.T)db2 = (1 / m) * np.sum(dZ2, axis=1, keepdims=True)dZ1 = np.multiply(np.dot(W2.T, dZ2), 1 - np.power(A1, 2))dW1 = (1 / m) * np.dot(dZ1, X.T)db1 = (1 / m) * np.sum(dZ1, axis=1, keepdims=True)grads = {"dW1": dW1,"db1": db1,"dW2": dW2,"db2": db2}return grads

✅ 数学推导简述:

  • $ \frac{\partial L}{\partial Z_2} = A_2 - Y $
  • $ \frac{\partial L}{\partial W_2} = \frac{1}{m} dZ2 \cdot A1^T $
  • $ \frac{\partial L}{\partial Z_1} = W_2^T \cdot dZ2 \cdot (1 - A1^2) $ (因为 $ \frac{d}{dz}\tanh(z) = 1 - \tanh^2(z) $

⬇️ 八、参数更新(梯度下降)

def update_parameters(parameters, grads, learning_rate=1.2):W1 = parameters['W1'] - learning_rate * grads['dW1']b1 = parameters['b1'] - learning_rate * grads['db1']W2 = parameters['W2'] - learning_rate * grads['dW2']b2 = parameters['b2'] - learning_rate * grads['db2']parameters = {"W1": W1,"b1": b1,"W2": W2,"b2": b2}return parameters

✅ 说明:

  • 使用标准梯度下降法更新参数
  • 学习率设为 1.2,可根据训练情况调整

九、整合训练流程

def nn_model(X, Y, n_h, num_iterations=10000, print_cost=False):np.random.seed(3)n_x = X.shape[0]n_y = Y.shape[0]parameters = initialize_parameters(n_x, n_h, n_y)for i in range(num_iterations):A2, cache = forward_propagation(X, parameters)cost = compute_cost(A2, Y, parameters)grads = backward_promagation(parameters, cache, X, Y)parameters = update_parameters(parameters, grads)if print_cost and i % 1000 == 0:print(f"第 {i} 次迭代后,成本为: {cost}")return parameters

✅ 功能:

  • 封装整个训练流程
  • 多次迭代优化参数
  • 打印成本变化趋势

十、测试与可视化

X_assess, Y_assess = nn_model_test_case()
parameters = nn_model(X_assess, Y_assess, 4, num_iterations=10000, print_cost=False)
# 绘制决策边界
plot_decision_boundary(lambda x: predict(x, parameters), X, Y)
plt.title("训练后的决策边界")
plt.show()

✅ 效果:

  • 使用 plot_decision_boundary 绘制出模型划分的区域
  • 可以看到,即使数据是非线性可分的,神经网络也能画出弯曲的边界进行分类

十一、关键知识点总结

概念作用
前向传播计算预测值
损失函数衡量预测误差
反向传播计算梯度
梯度下降更新参数以最小化损失
隐藏层提供非线性表达能力

十二、常见问题与改进方向

  1. 激活函数选择:建议将 tanh 改为 sigmoidReLU 更符合实际应用
  2. 损失函数:应使用 sigmoid + binary cross entropy 保证数值稳定性
  3. 正则化:可加入 L2 正则项防止过拟合
  4. 优化器:可尝试 Adam、RMSProp 等现代优化算法
  5. 多层网络:扩展为更深的网络(如 2 层隐藏层)

✅ 结语

通过这篇文章,你已经亲手实现了一个完整的浅层神经网络!虽然它简单,但它包含了深度学习的核心思想:

前向传播 → 计算损失 → 反向传播 → 参数更新

这正是所有深度学习框架背后的底层逻辑。


记住:真正的理解来自于亲手编码。
当你能写出这段代码,并读懂每一行的意义时,你就不再是“黑箱使用者”,而是真正的 AI 探索者。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/995333.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux《Socket编程Tcp》 - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年中国珍珠奶茶加盟TOP10一线品牌榜

2025年珍珠奶茶行业步入结构化竞争新阶段,全球市场规模达 24.8 亿美元,国内现制茶饮市场逼近 3000亿元。消费端健康化诉求凸显,低糖、植物基、高蛋白产品成为主流,Z世代主导定制化消费与社交传播。技术层面,AI 配…

大学生必备APP精选:助力学业与生活的实用工具

大学生必备APP精选:助力学业与生活的实用工具工欲善其事,必先利其器。一款得心应手的APP,能让大学生活事半功倍。 在当今数字化校园中,选择合适的应用程序能极大提升学习效率和生活便利性。本文将为你介绍几款在语…

什么是 Spring AOP - Higurashi

AOP(面向切面编程)是 Spring 两大核心之一,它是一种编程思想,是对 OOP 的一种补充。它通过横向抽取共性功能(如日志、事务),解决代码重复和耦合问题,提升代码复用性和可维护性。它的底层是通过动态代理实现的。…

2025最新油田助剂厂家推荐榜:实力企业赋能油气开发,全国优质供应商精选

在油气勘探开发的钻井、采油、压裂等关键环节,油田助剂的性能直接关系到作业效率、采收率与作业安全。选择技术成熟、供应稳定、服务完善的厂家,是油田企业实现降本增效的重要保障。以下结合企业综合实力、产品适配性…

如何在Flutter中使用CustomPainter实现自定义绘制?

在 Flutter 中,CustomPainter是实现自定义绘制的核心组件,可灵活绘制图形、路径、文本、渐变甚至复杂动效,其核心逻辑是通过重写paint()(定义绘制逻辑)和shouldRepaint()(控制重绘时机)来实现自定义视觉效果。以…

Linux 中文本显示字体以颜色突出

Linux 中文本显示字体以颜色突出 001、绿色002、红色

博弈论模型中的学习与算法设计

本文探讨了博弈论模型中的学习问题,特别是算法博弈论的应用。文章分析了在重复博弈中,参与者如何通过学习最大化自身奖励,以及如何设计游戏结构以同时优化个人与集体利益,并研究了存在遗留效应环境下的学习算法。v…

《Zephyr RTOS 深度学习指南与生成式AI结合方法探讨》第六章 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025 年 12 月上海逃生装备厂家权威推荐榜:聚焦逃生滑道、缓降管、应急器材,解析智能与柔性技术的安全守护之选

2025 年 12 月上海逃生装备厂家权威推荐榜:聚焦逃生滑道、缓降管、应急器材,解析智能与柔性技术的安全守护之选 随着城市化进程的加速,高层及超高层建筑日益增多,火灾等突发性公共安全事件的应急逃生需求变得前所未…

HiAgent vs Coze:企业级智能体平台大对比

HiAgent vs Coze:企业级智能体平台大对比Posted on 2025-12-09 18:00 Java后端的Ai之路 阅读(0) 评论(0) 收藏 举报HiAgent vs Coze:企业级智能体平台的深度对比 专业术语解释 HiAgent HiAgent是字节跳动火山引…

关于敏感信息检测技术的理论知识

在之前的文章中,探索了不同的检测敏感信息的方法,并通过Demo进行了学习,对算法、模型等一些概念有一些初步认知,这片文章想更加完整的学习涉及的概念,以及知识框架。 信息识别 “敏感信息检测”本质上是一种信息识…

自定义拦截器不生效问题记录

新项目里面我把之前的告警添加了进来,添加后发现有个问题:我新增的拦截器一直不生效:我的代码如下Configuration public class OraDingdingConfigurer implements WebMvcConfigurer, Interceptor {/*** 拦截器参数校…

2025年地毯品牌最新推荐榜,聚焦企业技术创新、原料品质与市场口碑深度解析羊毛,无胶,可拆洗双层,客厅,卧室,中古风,儿童房,可拆洗,床边,无胶防水地毯公司推荐

引言 随着家居消费升级,健康环保与设计美学成为地毯选购核心诉求,为精准筛选优质品牌,本次推荐榜依托中国家用纺织品行业协会(CNTAC)2024-2025 年度地毯品类测评数据,结合第三方检测机构 SGS 的 128 项指标检测结…

中美跨境国际快递配送清单:轻小件低价寄,带电_特货合规清关

2025 年中美跨境电商轻小件需求同比增长 45%,饰品、3C 配件等 0.5-10KG 包裹占比超 60%,但 “低价难寻合规渠道、带电特货清关险、轨迹追踪不透明” 仍是核心痛点。第三方数据显示,48% 卖家曾因 “敏感货扣关” 损失…

Elasticsearch:如何为 Elastic Stack 部署 E5 模型 - 下载及隔离环境 - 详解

Elasticsearch:如何为 Elastic Stack 部署 E5 模型 - 下载及隔离环境 - 详解2025-12-09 17:51 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-…

JVM运营内存清空查看

ps -ef | grep "java" 找到定开服务PID 然后 jmap -heap PID 可以看到对应jvm 内存分配情况

Flutter 应该如何实现 iOS 26 的 Liquid Glass

要在 Flutter 中实现 iOS 26 的Liquid Glass(液态玻璃) 视觉交互效果,需先明确 Liquid Glass 的核心特征:iOS 26 推出的液态玻璃质感聚焦「动态流体形变、玻璃拟态(Glassmorphism)进阶版、触控反馈的液态柔化、层…

IIS反向代理

模块安装 首先安装代理需要的模块,Application Request Routing Cache和URL重写(URL Rewrite)两个模块 下载地址: Application Request Routing Cache URL重写(URL Rewrite) 注:Application Request Routing …