【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  2. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894648.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

本地快速部署DeepSeek-R1模型——2025新年贺岁

一晃年初六了,春节长假余额马上归零了。今天下午在我的电脑上成功部署了DeepSeek-R1模型,抽个时间和大家简单分享一下过程: 概述 DeepSeek模型 是一家由中国知名量化私募巨头幻方量化创立的人工智能公司,致力于开发高效、高性能…

C++11详解(一) -- 列表初始化,右值引用和移动语义

文章目录 1.列表初始化1.1 C98传统的{}1.2 C11中的{}1.3 C11中的std::initializer_list 2.右值引用和移动语义2.1左值和右值2.2左值引用和右值引用2.3 引用延长生命周期2.4左值和右值的参数匹配问题2.5右值引用和移动语义的使用场景2.5.1左值引用主要使用场景2.5.2移动构造和移…

在K8S中,pending状态一般由什么原因导致的?

在Kubernetes中,资源或Pod处于Pending状态可能有多种原因引起。以下是一些常见的原因和详细解释: 资源不足 概述:当集群中的资源不足以满足Pod或服务的需求时,它们可能会被至于Pending状态。这通常涉及到CPU、内存、存储或其他资…

手写MVVM框架-构建虚拟dom树

MVVM的核心之一就是虚拟dom树,我们这一章节就先构建一个虚拟dom树 首先我们需要创建一个VNode的类 // 当前类的位置是src/vnode/index.js export default class VNode{constructor(tag, // 标签名称(英文大写)ele, // 对应真实节点children,…

linux内核源代码中__init的作用?

在 Linux 内核源代码中,__init是一个特殊的宏,用于标记在内核初始化阶段使用的变量或函数。这个宏的作用是告诉内核编译器和链接器,被标记的变量或函数只在内核的初始化阶段使用,在系统启动完成后就不再需要了。因此,这…

【大数据技术】教程03:本机PyCharm远程连接虚拟机Python

本机PyCharm远程连接虚拟机Python 注意:本文需要使用PyCharm专业版。 pycharm-professional-2024.1.4VMware Workstation Pro 16CentOS-Stream-10-latest-x86_64-dvd1.iso写在前面 本文主要介绍如何使用本地PyCharm远程连接虚拟机,运行Python脚本,提高编程效率。 注意: …

pytorch实现门控循环单元 (GRU)

人工智能例子汇总:AI常见的算法和例子-CSDN博客 特性GRULSTM计算效率更快,参数更少相对较慢,参数更多结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)处理长时依赖一般适…

PAT甲级1032、sharing

题目 To store English words, one method is to use linked lists and store a word letter by letter. To save some space, we may let the words share the same sublist if they share the same suffix. For example, loading and being are stored as showed in Figure …

最小生成树kruskal算法

文章目录 kruskal算法的思想模板 kruskal算法的思想 模板 #include <bits/stdc.h> #define lowbit(x) ((x)&(-x)) #define int long long #define endl \n #define PII pair<int,int> #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); using na…

为何在Kubernetes容器中以root身份运行存在风险?

作者&#xff1a;马辛瓦西奥内克&#xff08;Marcin Wasiucionek&#xff09; 引言 在Kubernetes安全领域&#xff0c;一个常见的建议是让容器以非root用户身份运行。但是&#xff0c;在容器中以root身份运行&#xff0c;实际会带来哪些安全隐患呢&#xff1f;在Docker镜像和…

js --- 获取时间戳

介绍 使用js获取当前时间戳 语法 Date.now()

ConcurrentHashMap线程安全:分段锁 到 synchronized + CAS

专栏系列文章地址&#xff1a;https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标&#xff1a; 理解ConcurrentHashMap为什么线程安全&#xff1b;ConcurrentHashMap的具体细节还需要进一步研究 目录 ConcurrentHashMap介绍JDK7的分段锁实现JDK8的synchr…

Vue和Java使用AES加密传输

背景&#xff1a;Vue对参数进行加密&#xff0c;对响应进行解密。Java对参数进行解密&#xff0c;对响应进行解密。不拦截文件上传类请求、GET请求。 【1】前端配置 安装crypto npm install crypto-js编写加解密工具类encrypt.js import CryptoJS from crypto-jsconst KEY …

开发板目录 /usr/lib/fonts/ 中的字体文件 msyh.ttc 的介绍【微软雅黑(Microsoft YaHei)】

本文是博文 https://blog.csdn.net/wenhao_ir/article/details/145433648 的延伸扩展。 本文是博文 https://blog.csdn.net/wenhao_ir/article/details/145433648 的延伸扩展。 问&#xff1a;运行 ls /usr/lib/fonts/ 发现有一个名叫 msyh.ttc 的字体文件&#xff0c;能介绍…

[ESP32:Vscode+PlatformIO]新建工程 常用配置与设置

2025-1-29 一、新建工程 选择一个要创建工程文件夹的地方&#xff0c;在空白处鼠标右键选择通过Code打开 打开Vscode&#xff0c;点击platformIO图标&#xff0c;选择PIO Home下的open&#xff0c;最后点击new project 按照下图进行设置 第一个是工程文件夹的名称 第二个是…

述评:如果抗拒特朗普的“普征关税”

题 记 美国总统特朗普宣布对美国三大贸易夥伴——中国、墨西哥和加拿大&#xff0c;分别征收10%、25%的关税。 他威胁说&#xff0c;如果这三个国家不解决他对非法移民和毒品走私的担忧&#xff0c;他就要征收进口税。 去年&#xff0c;中国、墨西哥和加拿大这三个国家&#…

九. Redis 持久化-AOF(详细讲解说明,一个配置一个说明分析,步步讲解到位 2)

九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2) 文章目录 九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2)1. Redis 持久化 AOF 概述2. AOF 持久化流程3. AOF 的配置4. AOF 启…

C++11新特性之long long超长整形

1.介绍 long long 超长整形是C11标准新添加的&#xff0c;用于表示更大范围整数的类型。 2.用法 占用空间&#xff1a;至少64位&#xff08;8个字节&#xff09;。 对于有符号long long 整形&#xff0c;后缀用“LL”或“II”标识。例如&#xff0c;“10LL”就表示有符号超长整…

浏览器查询所有的存储信息,以及清除的语法

要在浏览器的控制台中查看所有的存储&#xff08;例如 localStorage、sessionStorage 和 cookies&#xff09;&#xff0c;你可以使用浏览器开发者工具的 "Application" 标签页。以下是操作步骤&#xff1a; 1. 打开开发者工具 在 Chrome 或 Edge 浏览器中&#xf…

基于Springboot框架的学术期刊遴选服务-项目演示

项目介绍 本课程演示的是一款 基于Javaweb的水果超市管理系统&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含&#xff1a;项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该项目附…