神经网络代码入门解析

神经网络代码入门解析

import torch
import matplotlib.pyplot as pltimport randomdef create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))y = torch.matmul(x, w) + b  # 矩阵相乘再加bnoise = torch.normal(0, 0.01, y.shape)  # 为y添加噪声y += noisereturn x, ynum = 500true_w = torch.tensor([8.1, 2, 2, 4])
true_b = 1.1X, Y = create_data(true_w, true_b, num)# plt.scatter(X[:, 3], Y, 1)  # 画散点图 对X取全部的行的第三列,标签Y,点大小
# plt.show()def data_provider(data, label, batchsize):  # 每次取batchsize个数据length = len(label)indices = list(range(length))# 这里需要把数据打乱random.shuffle(indices)for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]get_data = data[get_indices]get_label = label[get_indices]yield get_data, get_label  # 有存档点的returnbatchsize = 16
# for batch_x, batch_y in data_provider(X, Y, batchsize):
#     print(batch_x, batch_y)
#     break# 定义模型
def fun(x, w, b):pred_y = torch.matmul(x, w) + breturn pred_y# 定义loss
def maeLoss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零lr = 0.01
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)
b_0 = torch.tensor(0.01, requires_grad=True)
print(w_0, b_0)epochs = 50
for epoch in range(epochs):data_loss = 0for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)loss = maeLoss(pred_y, batch_y)loss.backward()sgd([w_0, b_0], lr)data_loss += lossprint("epoch %03d: loss: %.6f" % (epoch, data_loss))print("真实函数值:", true_w, true_b)
print("训练得到的函数值:", w_0, b_0)idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx].detach().numpy(), Y, 1)
plt.show()

逐步分析代码

1.数据生成

image-20250301120222530

首先设计一个函数create_data,提供我们所需要的数据集的x与y

def create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))  # 生成特征数据,形状为 (data_num, len(w))y = torch.matmul(x, w) + b  # 计算目标值 y = x * w + bnoise = torch.normal(0, 0.01, y.shape)  # 生成噪声,形状与 y 相同y += noise  # 为 y 添加噪声,模拟真实数据中的随机误差return x, y
  • torch.normal() 生成一个张量

    • torch.normal(0, 1, (data_num, len(w))):生成一个形状为 (data_num, len(w)) 的张量,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样的。
  • torch.matmul() 让矩阵相乘

    matmul: matrix multiply

  • 再使用torch.normal()生成一个张量,添加到y上,相当于为y添加了随机的噪声

    噪声的引入是为了模拟真实数据中的随机误差,使生成的数据更接近现实场景。

2.设计一个数据加载器

def data_provider(data, label, batchsize):  # 每次取 batchsize 个数据length = len(label)indices = list(range(length))random.shuffle(indices)  # 打乱数据顺序,避免模型学习到顺序特征for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 获取当前批次的索引get_data = data[get_indices]  # 获取当前批次的数据get_label = label[get_indices]  # 获取当前批次的标签yield get_data, get_label  # 返回当前批次的数据和标签

data_provider可以分批提供数据,并通过yield来返回已实现记忆功能

首先把list y顺序打乱,这样就相当于从生成的训练集y中随机读取,若不打乱数据,可能造成训练结果的不理想

打乱数据可以避免模型在训练过程中学习到数据的顺序特征,从而提高模型的泛化能力。

之后分段遍历打乱的y,返回对应的局部的数据集来给神经网络进行训练

3.定义模型函数

image-20250301122853184

def fun(x, w, b):pred_y = torch.matmul(x, w) + b  # 计算预测值 y = x * w + breturn pred_y

fun(x, w, b) 是一个线性模型,形式为 y = x * w + b,其中 x 是输入特征,w 是权重,b 是偏置。

4.定义Loss函数

image-20250301122958888

def maeLoss(pre_y, y):return torch.sum(abs(pre_y - y)) / len(y)  # 计算平均绝对误差 (MAE)
  • maeLoss 是平均绝对误差(Mean Absolute Error, MAE),它计算预测值 pre_y 和真实值 y 之间的绝对误差的平均值。
  • 公式为:MAE = (1/n) * Σ|pre_y - y|,其中 n 是样本数量。

5.梯度下降sgd函数

# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零

这里需要使用torch.no_grad()来避免重复计算梯度

image-20250301123531781

在前向过程中已经累计过一次梯度了,如果在梯度下降过程中又累计了梯度,那么就会造成不必要的麻烦

PyTorch 会累积梯度,如果不手动清零,梯度会不断累积,导致参数更新错误。

para -= para.grad * lr就是将参数w修正的过程(w=w-(dy^/dw)*learningRate)

torch.no_grad() 是一个上下文管理器,用于禁用梯度计算。在参数更新时,禁用梯度计算可以避免不必要的计算和内存占用。

5.开始训练

epochs = 50
for epoch in range(epochs):data_loss = 0num_batches = len(Y) // batchsize  # 计算批次数量for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)  # 前向传播loss = maeLoss(pred_y, batch_y)  # 计算损失loss.backward()  # 反向传播sgd([w_0, b_0], lr)  # 更新参数data_loss += loss.item()  # 累积损失print("epoch %03d: loss: %.6f" % (epoch, data_loss / num_batches))  # 打印平均损失

先定义一个训练轮次epochs=50,表示训练50轮

在每轮训练中将loss记录下来,以此评价训练的效果

首先用data_provider来获取数据集中随机的一部分

接着传入相应数据给模型函数,通过前向传播获得预测y值pred_y

调用Loss计算函数,获取这次的loss,再通过反向传播loss.backward()计算梯度

loss.backward() 是反向传播的核心步骤,用于计算损失函数对模型参数的梯度。

再通过梯度下降sgd([w_0, b_0], lr)来更新模型的参数

最终将这组数据的loss累加到这轮数据的loss中

6.结果绘制

idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy() * w_0[idx].detach().numpy() + b_0.detach().numpy())  # 绘制预测直线
plt.scatter(X[:, idx].detach().numpy(), Y, 1)  # 绘制真实数据点
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/71113.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DeepSeek 开源狂欢周(一)FlashMLA:高效推理加速新时代

上周末,DeepSeek在X平台(Twitter)宣布将开启连续一周的开源,整个开源社区为之沸腾,全球AI爱好者纷纷为关注。没错,这是一场由DeepSeek引领的开源盛宴,推翻了传统推理加速的种种限制。这周一&…

EfficientViT模型详解及代码复现

核心架构 在EfficientViT模型的核心架构中,作者设计了一种创新的 sandwich布局 作为基础构建块,旨在提高内存效率和计算效率。这种布局巧妙地平衡了自注意力层和前馈神经网络层的比例,具体结构如下: 基于深度卷积的Token Interaction :通过深度卷积操作对输入特征进行初步…

大语言模型(LLM)如何赋能时间序列分析?

引言 近年来,大语言模型(LLM)在文本生成、推理和跨模态任务中展现了惊人能力。与此同时,时间序列分析作为工业、金融、物联网等领域的核心技术,长期依赖传统统计模型(如ARIMA)或深度学习模型&a…

Java 设计模式:软件开发的精髓与艺

目录 一、设计模式的起源二、设计模式的分类1. 创建型模式2. 结构型模式3. 行为型模式三、设计模式的实践1. 单例模式2. 工厂模式3. 策略模式四、设计模式的优势五、设计模式的局限性六、总结在软件开发的浩瀚星空中,设计模式犹如一颗颗璀璨的星辰,照亮了开发者前行的道路。它…

【基于Raft的KV共识算法】-序:Raft概述

本文目录 1.为什么会有Raft?CAP理论 2.Raft基本原理流程为什么要以日志作为中间载体? 3.实现思路任期领导选举日志同步 1.为什么会有Raft? 简单来说就是数据会随着业务和时间的增长,单机不能存的下,这个时候需要以某种…

【愚公系列】《Python网络爬虫从入门到精通》040-Matplotlib 概述

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主,腾讯云内容共创官,掘金优秀博主,亚马逊技领云博主,51CTO博客专家等。近期荣誉2022年度…

EasyRTC嵌入式WebRTC技术与AI大模型结合:从ICE框架优化到AI推理

实时通信技术在现代社会中扮演着越来越重要的角色,从视频会议到在线教育,再到远程医疗,其应用场景不断拓展。WebRTC作为一项开源项目,为浏览器和移动应用提供了便捷的实时通信能力。而EasyRTC作为基于WebRTC的嵌入式解决方案&…

javaEE初阶————多线程初阶(5)

本期是多线程初阶的最后一篇文章了,下一篇就是多线程进阶的文章了,大家加油! 一,模拟实现线程池 我们上期说过线程池类似一个数组,我们有任务就放到线程池中,让线程池帮助我们完成任务,我们该如…

工业AR眼镜的‘芯’动力:FPC让制造更智能【新立电子】

随着增强现实(AR)技术的快速发展,工业AR智能眼镜也正逐步成为制造业领域的重要工具。它不仅为现场工作人员提供了视觉辅助,还极大地提升了远程协助的效率、优化了仓储管理。FPC在AI眼镜中的应用,为工业AR智能眼镜提供了…

FPGA开发,使用Deepseek V3还是R1(5):temperature设置

以下都是Deepseek生成的答案 FPGA开发,使用Deepseek V3还是R1(1):应用场景 FPGA开发,使用Deepseek V3还是R1(2):V3和R1的区别 FPGA开发,使用Deepseek V3还是R1&#x…

网站内容更新后百度排名下降怎么办?有效策略有哪些?

转自 网站内容更新后百度排名下降怎么办?有效策略有哪些? 网站内容更新是促进网站优化的关键环节,但是频繁修改网站内容会对网站的搜索引擎排名造成很大的影响。为了保持网站排名,我们需要采取一些措施来最小化对百度排名的影响。…

安装 cpolar 内网穿透工具的步骤

安装 cpolar 内网穿透工具的步骤 1. 下载 cpolar 软件安装包 步骤: 前往 cpolar 官方下载页面。 根据您的操作系统(Windows、macOS、Linux 等),选择对应的安装包进行下载。 2. 注册 cpolar 账号 步骤: 访问 cpolar…

Linux :进程状态

目录 1 引言 2 操作系统的资源分配 3进程状态 3.1运行状态 3.2 阻塞状态 3.3挂起状态 4.进程状态详解 4.1 运行状态R 4.2 休眠状态S 4.3深度睡眠状态D 4.4僵尸状态Z 5 孤儿进程 6 进程优先级 其他概念 1 引言 🌻在前面的文章中,我们已…

openwebUI访问vllm加载deepseek微调过的本地大模型

文章目录 前言一、openwebui安装二、配置openwebui环境三、安装vllm四、启动vllm五、启动openwebui 前言 首先安装vllm,然后加载本地模型,会起一个端口好。 在安装openwebui,去访问这个端口号。下面具体步骤的演示。 一、openwebui安装 rootautodl-co…

DeepSeek-V3:AI语言模型的高效训练与推理之路

参考:【论文学习】DeepSeek-V3 全文翻译 在人工智能领域,语言模型的发展日新月异。从早期的简单模型到如今拥有数千亿参数的巨无霸模型,技术的进步令人瞩目。然而,随着模型规模的不断扩大,训练成本和推理效率成为了摆在…

Spring单例模式 Spring 中的单例 饿汉式加载 懒汉式加载

目录 核心特性 实现方式详解 1. 饿汉式(Eager Initialization) 2. 懒汉式(Lazy Initialization) 3. 静态内部类(Bill Pugh 实现) 4. 枚举(Enum) 破坏单例的场景及防御 Sprin…

DeepSeek MLA(Multi-Head Latent Attention)算法浅析

目录 前言1. 从MHA、MQA、GQA到MLA1.1 MHA1.2 瓶颈1.3 MQA1.4 GQA1.5 MLA1.5.1 Part 11.5.2 Part 21.5.3 Part 3 结语参考 前言 学习 DeepSeek 中的 MLA 模块,究极缝合怪,东抄抄西抄抄,主要 copy 自苏神的文章,仅供自己参考&#…

uniapp 中引入使用uView UI

文章目录 一、前言:选择 uView UI的原因二、完整引入步骤1. 安装 uView UI2. 配置全局样式变量(关键!)3. 在 pages.json中添加:4. 全局注册组件5. 直接使用组件 五、自定义主题色(秒换皮肤) 一、…

zookeeper-docker版

Zookeeper-docker版 1 zookeeper概述 1.1 什么是zookeeper Zookeeper是一个分布式的、高性能的、开源的分布式系统的协调(Coordination)服务,它是一个为分布式应用提供一致性服务的软件。 1.2 zookeeper应用场景 zookeeper是一个经典的分…

【量化金融自学笔记】--开篇.基本术语及学习路径建议

在当今这个信息爆炸的时代,金融领域正经历着一场前所未有的变革。传统的金融分析方法逐渐被更加科学、精准的量化技术所取代。量化金融,这个曾经高不可攀的领域,如今正逐渐走进大众的视野。它将数学、统计学、计算机科学与金融学深度融合&…