基于简单神经网络的线性回归

一、概述

本代码实现了一个简单的神经网络进行线性回归任务。通过生成包含噪声的线性数据集,定义一个简单的神经网络类,使用梯度下降算法训练网络以拟合数据,并最终通过可视化展示原始数据、真实线性关系以及模型的预测结果。

二、依赖库

  1. numpy:用于数值计算,包括生成数组、进行随机数操作、执行数学运算等。
  2. matplotlib.pyplot:用于数据可视化,绘制散点图和折线图以展示数据和模型的预测结果。

三、代码详解

1. 生成数据集

python

np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪声

  • np.random.seed(42):设置随机数种子,确保每次运行代码时生成的随机数序列相同,从而使结果可复现。
  • np.linspace(-10, 10, 100):生成一个包含 100 个元素的一维数组x,元素均匀分布在 - 10 到 10 之间。
  • x + np.random.normal(0, 1, x.shape):生成因变量y,它基于真实的线性关系y = x,并添加了均值为 0、标准差为 1 的高斯噪声。np.random.normal(0, 1, x.shape)生成与x形状相同的随机噪声数组。

2. 定义神经网络(线性回归)

python

class SimpleNN:def __init__(self):self.w = np.random.randn()  # 权重self.b = np.random.randn()  # 偏置def forward(self, x):return self.w * x + self.b  # 前向传播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2)  # 均方误差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred))  # 权重的梯度db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw  # 更新权重self.b -= lr * db  # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')

  • __init__方法:初始化神经网络的权重self.w和偏置self.b,使用np.random.randn()生成随机的初始值。
  • forward方法:实现前向传播,根据输入x、权重self.w和偏置self.b计算输出y_pred,即y_pred = self.w * x + self.b
  • loss方法:计算预测值y_pred和真实值y_true之间的均方误差(MSE),公式为np.mean((y_true - y_pred) ** 2)
  • gradient方法:计算权重self.w和偏置self.b的梯度。dw是权重的梯度,计算公式为-2 * np.mean(x * (y_true - y_pred))db是偏置的梯度,计算公式为-2 * np.mean(y_true - y_pred)
  • train方法:使用梯度下降算法训练神经网络。在指定的epochs(训练轮数)内,每次迭代进行前向传播计算预测值y_pred,然后计算梯度dwdb,根据学习率lr更新权重self.w和偏置self.b。每 100 轮打印一次当前轮数和损失值。

3. 训练模型

python

model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)

  • SimpleNN():创建一个SimpleNN类的实例model
  • model.train(x, y, lr=0.01, epochs=1000):调用modeltrain方法,使用生成的数据集xy,学习率lr=0.01,训练轮数epochs=1000进行训练。

4. 可视化结果

python

y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

  • model.forward(x):使用训练好的模型model对数据集x进行前向传播,得到预测值y_pred
  • plt.scatter(x, y, label='Data points'):绘制原始数据集的散点图,标签为Data points
  • plt.plot(x, x, color='red', label='y = x'):绘制真实的线性关系y = x的折线图,颜色为红色,标签为y = x
  • plt.plot(x, y_pred, color='green', label='Predicted'):绘制模型预测结果的折线图,颜色为绿色,标签为Predicted
  • plt.legend():显示图例,方便区分不同的图形。
  • plt.show():显示绘制好的图形。

四、注意事项

  1. 本代码实现的是一个简单的线性回归神经网络,实际应用中可能需要更复杂的模型结构和优化方法。
  2. 学习率lr和训练轮数epochs是超参数,可能需要根据具体数据和任务进行调整以获得更好的训练效果。
  3. 代码中使用的均方误差损失函数和梯度计算公式是针对线性回归问题的常见选择,但在其他问题中可能需要使用不同的损失函数和梯度计算方法。

完整代码

import numpy as np
import matplotlib.pyplot as plt# 1. 生成数据集
np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪声# 2. 定义神经网络(线性回归)
class SimpleNN:def __init__(self):self.w = np.random.randn()  # 权重self.b = np.random.randn()  # 偏置def forward(self, x):return self.w * x + self.b  # 前向传播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2)  # 均方误差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred))  # 权重的梯度db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw  # 更新权重self.b -= lr * db  # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')# 3. 训练模型
model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)# 4. 可视化结果
y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899703.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

‌19.思科路由器:OSPF协议引入直连路由的实验研究

思科路由器:OSPF协议引入直连路由的实验研究 一、实验拓扑二、基本配置2.1、sw1的配置2.2、开启交换机三层功能三、ospf的配置3.1、R1的配置3.2、R2的配置3.3、重启ospf进程四、引入直连路由五、验证结果随着互联网技术的不断发展,路由器作为网络互联的关键设备,其性能与稳定…

USB——删除注册表信息

文章目录 背景工具下载地址工具使用删除注册表信息背景 注测表中已记录这个设备的信息,但现在设备描述符又指定为了 WinUSB 设备,所以当设备再次插入的时候,不会发送 0xEE 命令,造成了枚举失败。 两种处理方式: 修改枚举时候的 VID/PID删除 USB 的注册表信息工具下载地址…

如何快速解决django报错:cx_Oracle.DatabaseError: ORA-00942: table or view does not exist

我们在使用django连接oracle进行编程时,使用model进行表映射对接oracle数据时,默认表名组成结构为:应用名_类名(如:OracleModel_test),故即使我们库中存在表test,运行查询时候&#…

从 0 到跑通的 Qt + OpenGL + VS 项目的完整流程

🧩 全流程目标: 在 Visual Studio 中成功打开、编译并运行一个 Qt OpenGL 项目(.vcxproj 格式) ✅ 第 1 步:安装必要环境 工具说明Visual Studio 2017 / 2019 / 2022必须勾选 “使用 C 的桌面开发” 和 “MSVC 工具…

鸿蒙开发03样式相关介绍(二)

文章目录 一、样式复用1.1 Styles修饰符1.2 Extend修饰符 二、多态样式 一、样式复用 在页面开发过程中,会出出现大量重复的样式设置代码,可以使用Styles和Extend修饰符将帮助我们进行样式复用。 1.1 Styles修饰符 Styles装饰器可以将多条样式设置提炼…

装饰器模式与模板方法模式实现MyBatis-Plus QueryWrapper 扩展

pom <dependency><groupId>com.github.yulichang</groupId><artifactId>mybatis-plus-join-boot-starter</artifactId> <!-- MyBatis 联表查询 --> </dependency>MPJLambdaWrapperX /*** 拓展 MyBatis Plus Join QueryWrapper 类&…

05-031-自考数据结构(20331)- 哈希表 - 例题分析

哈希表考题主要涵盖四大类型:1)函数设计类(如除留余数法计算地址,需掌握质数p的选择技巧);2)冲突处理类(线性探测法要解决堆积现象,链地址法需绘制链表结构);3)性能分析类(重点计算ASL,理解装填因子α的影响规律);4)综合应用类(如设计ISBN查询系统,需结合实际问…

rustdesk 自建服务器 key不匹配

请确保id_ed25519文件的权限为&#xff1a; -rw------- 1 root root 88 Apr 31 10:02 id_ed25519在rustdesk安装目录执行命令&#xff1a; chmod 700 id_ed25519

Dify 深度集成 MCP实现灾害应急响应

一、架构设计 1.1 分层架构 #mermaid-svg-5dVNjmixTX17cCfg {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-5dVNjmixTX17cCfg .error-icon{fill:#552222;}#mermaid-svg-5dVNjmixTX17cCfg .error-text{fill:#552222…

AI与.NET技术实操系列(三):在 .NET 中使用大语言模型(LLMs)

1. 引言 在技术迅猛发展的今天&#xff0c;大语言模型&#xff08;Large Language Models, LLMs&#xff09;已成为人工智能领域的核心驱动力之一。从智能对话系统到自动化内容生成&#xff0c;LLMs的应用正在深刻改变我们的工作与生活方式。对于.NET开发者而言&#xff0c;掌…

一个极简的词法分析器实现

文章目录 推荐&#xff1a;Tiny Lexer - 一个极简的C语言词法分析器特点核心代码实现学习价值扩展建议 用Java实现一个简单的词法分析器完整实现代码代码解析示例输出扩展建议 用Go实现极简词法分析器完整实现代码代码解析示例输出扩展建议 最近两天搞一个DSL&#xff0c;不得不…

强制用户裸奔,微软封锁唯一后门操作

周末刚结束&#xff0c;那个常年将「用户为中心」挂嘴边的微软又双叒叕开始作妖&#xff01; 不错&#xff0c;大伙儿今后可能再没法通过「OOBE\BYPASSNRO」命令绕过微软强制联网要求了。 熟悉 Windows 11 操作系统的都知道&#xff0c;除硬件上诸多限制外&#xff1b; 软件层…

大模型备案:拦截关键词列表与敏感词库深度解析

随着《生成式人工智能服务管理暂行办法》正式实施&#xff0c;大模型上线备案成为企业合规运营的核心环节。其中&#xff0c;敏感词库建设与拦截关键词列表管理直接关系内容安全红线&#xff0c;今天我们就来详细解析一下大模型备案的这一部分&#xff0c;希望对想要做备案的朋…

快速上手Linux系统输入输出

一、管理系统中的输入输出 1.什么是重定向&#xff1f; 将原本要输出到屏幕上的内容&#xff0c;重新输入到其他设备中或文件中 重定向类型包括 输入重定向输出重定向 2.输入重定向 指定设备&#xff08;通常是文件或命令的执行结果&#xff09;来代替键盘作为新的输入设…

文小言全新升级!多模型协作与智能语音功能带来更流畅的AI体验

文小言全新升级&#xff01;多模型协作与智能语音功能带来更流畅的AI体验 在3月31日的百度AI DAY上&#xff0c;文小言正式宣布了一系列令人兴奋的品牌焕新与功能升级。此次更新不仅带来了全新的品牌视觉形象&#xff0c;更让文小言在智能助手的技术和用户体验方面迈上了一个新…

C++基础算法(插入排序)

1.插入排序 插入排序&#xff08;Insertion Sort&#xff09;介绍&#xff1a; 插入排序是一种简单直观的排序算法&#xff0c;它的工作原理类似于我们整理扑克牌的方式。 1.基本思想 插入排序的基本思想是&#xff1a; 1.将数组分为已排序和未排序两部分 2.每次从未排序部分…

k近邻算法K-Nearest Neighbors(KNN)

算法核心 KNN算法的核心思想是“近朱者赤&#xff0c;近墨者黑”。对于一个待分类或预测的样本点&#xff0c;它会查找训练集中与其距离最近的K个样本点&#xff08;即“最近邻”&#xff09;。然后根据这K个最近邻的标签信息来对当前样本进行分类或回归。 在分类任务中&#…

【Feign】⭐️使用 openFeign 时传递 MultipartFile 类型的参数参考

&#x1f4a5;&#x1f4a5;✈️✈️欢迎阅读本文章❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;本篇文章阅读大约耗时三分钟。 ⛳️motto&#xff1a;不积跬步、无以千里 &#x1f4cb;&#x1f4cb;&#x1f4cb;本文目录如下&#xff1a;&#x1f381;&#x1f381;&a…

zk基础—1.一致性原理和算法二

大纲 1.分布式系统特点 2.分布式系统的理论 3.两阶段提交Two-Phase Commit(2PC) 4.三阶段提交Three-Phase Commit(3PC) 5.Paxos岛的故事来对应ZooKeeper 6.Paxos算法推导过程 7.Paxos协议的核心思想 8.ZAB算法简述 6.Paxos算法推导过程 (1)Paxos的概念 (2)问题描述 …

216. 组合总和 III 回溯

目录 问题描述 解决思路 关键点 代码实现 代码解析 1. 初始化结果和路径 2. 深度优先搜索&#xff08;DFS&#xff09; 3. 遍历候选数字 4. 递归与回溯 示例分析 复杂度与优化 回溯算法三部曲 1. 路径选择&#xff1a;记录当前路径 2. 递归探索&#xff1a;进入下…