机器学习深度学习——线性回归的简洁实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——线性回归的从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

由于数据迭代器、损失函数、优化器以及神经网络很常用,现代深度学习库也为我们实现了这些组件。

线性回归的简洁实现

  • 生成数据集
  • 读取数据集
  • 定义模型
  • 初始化模型参数
  • 定义损失函数
  • 定义优化算法
  • 训练

生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
# d2l.synthetic_data将会生成y=Xw+b,其中函数原理可以看上一章节
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

读取数据集

可以调用框架中现有API来读取数据。我们将features和labels作为API的参数传递,并通过数据迭代器指定batch_size。此外,布尔值is_train表示是够希望数据迭代器对象在每个迭代周期内打乱数据。

def load_array(data_arrays, batch_size, is_train=True):  #@save"""构造一个pytorch数据迭代器"""# 传入数据。其中*表示对list解开入参,也就是把列表元素分别当做参数传入dataset = data.TensorDataset(*data_arrays)# 随机从数据集中取出batch_size个数量的参数return data.TensorDataset(dataset, batch_size, shuffle=is_train)# 读取数据集
batch_size = 10  # 小批量样本大小
data_iter = load_array((features, labels), batch_size)

如果我们要读取并打印第一个小批量样本,因为我们这边是使用了iter来构造了Python迭代器,因此用next来从迭代器中获取第一项:

print(next(iter(data_iter)))

结果:

[tensor([[-0.1996, 0.5686],
[ 0.6253, -0.1051],
[ 0.4497, -0.2051],
[ 0.3645, -0.4241],
[-2.6413, -0.4506],
[ 1.4606, 1.3924],
[-0.2853, -0.4866],
[ 1.0096, 0.5627],
[ 0.4851, -0.0612],
[ 0.5598, 1.4693]]), tensor([[1.8608],
[5.8130],
[5.8021],
[6.3728],
[0.4482],
[2.3855],
[5.2871],
[4.3162],
[5.3818],
[0.3218]])]

定义模型

对于标准深度学习模型,我们可以使用框架的预定好的层,我们只需要知道使用哪些层来构造模型,而不必关注层的实现细节。
我们首先定义一个模型变量net,这是一个Sequential(顺序)类的实例。Sequential类将多个层串联在一起。当给定输入数据时,Sequential实例将数据传入第一层,然后将第一层输出作为第二层的输入,一次类推。
这里的例子里面,模型只包含一个层,因此实际上不需要Sequential。
回归之前所说的单层网络架构,这一单层被称为全连接层,因为其每个输入都通过矩阵-向量乘法得到它的每个输出。
在PyTorch中,全连接层在Linear类中定义。 我们将两个参数传递到nn.Linear中。 第一个指定输入特征形状,即2,第二个指定输出特征形状,输出特征形状为单个标量,因此为1。

# nn是神经网络的缩写
from torch import nn# 定义神经网络模型
net = nn.Sequential(nn.Linear(2, 1))

初始化模型参数

在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样, 偏置参数将初始化为零。
正如我们在构造nn.Linear时指定输入和输出尺寸一样, 现在我们能直接访问参数以设定它们的初始值。
我们通过net[0]选择网络中的第一个图层, 然后使用weight.data和bias.data方法访问参数。

# 初始化模型参数
net[0].weight.data.normal_(0, 0.01)  # 均值为0,标准差为0.01的正态分布
net[0].bias.data.fill_(0)  # 偏置参数设为0

定义损失函数

利用MSELoss来计算均方误差:

# 定义损失函数
loss = nn.MSELoss()

定义优化算法

小批量随机梯度下降算法是一种优化神经网络的工具,pytorch在optim模块中实现了改算法的多个变种。
当我们要实例化一个SGD实例时,需要指定优化的参数(可通过net.parameters()从我们的模型中获得)及优化算法所需要的超参数字典。而SCG只需要设置lr即可:

# 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

训练

在每个迭代周期,我们将完整遍历一次数据集,不停地从中获取一个小批量的输入和相应的标签。而后,对于每个小批量,都会进行以下步骤:
1、通过调用net(X)生成预测并计算损失l(前向传播)。
2、通过进行反向传播来计算梯度。
3、通过调用优化器来更新模型参数。

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)  # 计算损失trainer.zero_grad()  # 梯度清零l.backward()  # 损失后向传播trainer.step()  # 更新网络参数l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')

结果:

epoch 1, loss 0.000185
epoch 2, loss 0.000097
epoch 3, loss 0.000098

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/11178.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全球程序员需要知道的50+网址,有多少你第一次听说?

作为程序员,需要知道的50网址,有多少你第一次听说 GitHub (github.com): 最大的代码托管平台,开源项目和代码分享的社区。程序员可以在这里找到各种有趣的项目,参与开源贡献或托管自己的代码。 Stack Overflow (stackoverflow.co…

Python[parquet文件 转 json文件]

将Python中的Parquet文件转换为JSON文件 引言 Parquet是一种高效的列式存储格式,而JSON是一种常见的数据交换格式。我们将使用pandas和pyarrow库来实现这个转换过程,并且提供相关的代码示例。 安装所需库 首先,请确保您已经安装了pandas和…

Rust: Vec类型的into_boxed_slice()方法

比如&#xff0c;我们经常看到Vec类型&#xff0c;但取转其裸指针&#xff0c;经常会看到into_boxed_slice()方法&#xff0c;这是为何&#xff1f; use std::{fmt, slice};#[derive(Clone, Copy)] struct RawBuffer {ptr: *mut u8,len: usize, }impl From<Vec<u8>&g…

垃圾回收之三色标记法(Tri-color Marking)

关于垃圾回收算法&#xff0c;基本就是那么几种&#xff1a;标记-清除、标记-复制、标记-整理。在此基础上可以增加分代&#xff08;新生代/老年代&#xff09;&#xff0c;每代采取不同的回收算法&#xff0c;以提高整体的分配和回收效率。 无论使用哪种算法&#xff0c;标记…

【libevent】http客户端2:使用post 发送本地文件到服务器

HttpClient2POST的例子 看起来只post了一次?#include <stdio.h> #include <assert.h> #include <stdlib.h> #include

深入浅出Pytorch函数——torch.maximum

分类目录&#xff1a;《深入浅出Pytorch函数》总目录 相关文章&#xff1a; 深入浅出Pytorch函数——torch.max 深入浅出Pytorch函数——torch.maximum 计算input和other的元素最大值。 语法 torch.maximum(input, other, *, outNone) -> Tensor参数 input&#xff1a;…

C# OpenCvSharpe 二值化工具 阈值 自适应阈值 局部阈值 InRange

效果 阈值 自适应阈值 局部阈值 InRange 项目 VS2010.net4.0OpenCvSharper3 Demo下载

Educational Codeforces Round 152 (Rated for Div. 2)

B. Monsters 题意&#xff1a;你的攻击力为k&#xff0c;你优先攻击血量最多的怪物&#xff0c;血量相同击杀编号小的&#xff0c;问怪物被击杀的顺序&#xff0c; 思路&#xff1a;我们可以知道最后肯定存在一个状态&#xff0c;所有怪物就差一次攻击就死了&#xff0c;这个…

AWS / VPC 云流量监控

由于安全性、数据现代化、增长、灵活性和成本等原因促使更多企业迁移到云&#xff0c;将数据存储在本地的组织正在使用云来存储其重要数据。亚马逊网络服务&#xff08;AWS&#xff09;仍然是最受追捧和需求的服务之一&#xff0c;而亚马逊虚拟私有云&#xff08;VPC&#xff0…

LED芯片 VAS1260IB05E 带内部开关LED驱动器 汽车硬灯带灯条解决方案

VAS1260IB05E深力科LED芯片是一种连续模式电感降压转换器&#xff0c;设计用于从高于LED电压的电压源高效驱动单个或多个串联连接的LED。该设备在5V至60V之间的输入电源下工作&#xff0c;并提供高达1.2A的外部可调输出电流。包括输出开关和高侧输出电流感测电路&#xff0c;该…

UE4/5C++多线程插件制作(十七、封装协程管理)

目录 MTPThreadInterface.h MTPManageBase.h MTPCoroutinesManage.h MTPManage.cpp MTPManage.h 添加继承: cpp实现: MTPThreadTaskMan

双系统的一些设置

1、windows和ubuntu双系统时间不同步的问题&#xff1a; 在安装Windows和Ubuntu双系统时&#xff0c;两个操作系统会分别使用自己的时间设置。Windows默认使用本地时间&#xff08;Local Time&#xff09;&#xff0c;而Ubuntu则默认使用协调世界时&#xff08;Coordinated Un…

TypeScript 在前端开发中的应用实践

TypeScript 在前端开发中的应用实践 TypeScript 已经成为前端开发领域越来越多开发者的首选工具。它是一种静态类型的超集&#xff0c;由 Microsoft 推出&#xff0c;为开发者提供了强大的静态类型检查、面向对象编程和模块化开发的特性&#xff0c;解决了 JavaScript 的动态类…

趋动科技携手星辰天合,推出针对人工智能领域的两款联合解决方案

近日&#xff0c;趋动科技与 XSKY星辰天合联合宣布&#xff0c;结合双方优势能力和产品&#xff0c;携手推出高性能数据湖一站式方案及全协议存算一体化方案&#xff0c;帮助客户简化 AI 工作的 IT 基础设施部署&#xff0c;实现 AI 相关工作更加灵活和便捷。 全协议存算一体化…

janus-Gateway的服务端部署

janus-Gateway 需求是前后端的webRTC推拉流&#xff0c;但是后端用的是c&#xff0c;于是使用了这个库做视频流的推送和拉取&#xff0c;记录踩坑过程。 如果你也需要自己部署janus的服务端并在前端拉流测试&#xff0c;希望对你有所帮助。 由于janus的服务器搭建需要linux环境…

树莓派Pico|RP2040|官方文档|在MS Windows上构建“Hello World”及环境配置

9.2. 在MS Windows上构建 在Microsoft Windows 10或Windows 11上安装工具链与其他平台有些不同。然而安装后&#xff0c;RP2040的构建代码基本类似。  警告 官方不支持在Windows 7或8上使用Raspberry Pi Pico&#xff0c;但在Windows 7或8上可以使其工作。 9.2.1. 安装工具…

docker中设置容器健康检查

文章目录 一、docker-compose方式二、Dockerfile方式三、docker run方式四、查看检查日志 一、docker-compose方式 在docker-compose中加入healthcheck healthcheck 支持下列选项&#xff1a; test&#xff1a;健康检查命令&#xff0c;例如 ["CMD", "curl&quo…

向npm注册中心发布包(上)

目录 1、创建package.json文件 1.1 fields 字段 1.2 Author 字段 1.3 创建 package.json 文件 1.4 自定义 package.json 的问题 1.5 从当前目录提取的默认值 1.6 通过init命令设置配置选项 2、创建Node.js 模块 2.1 创建一个package.json 文件 2.2 创建在另一个应用程…

5G时代的APP开发:机遇与挑战

APP开发是互联网行业中的重要组成部分&#xff0c;随着5G时代的到来&#xff0c;移动 APP开发也迎来了新的机遇和挑战。 5G时代不仅会为移动 APP开发带来新的发展机遇&#xff0c;也会给移动 APP开发带来新的挑战。对于企业和开发者而言&#xff0c;5G时代带来的机遇和挑战是并…

【雕爷学编程】MicroPython动手做(02)——尝试搭建K210开发板的IDE环境5

#尝试搭建K210的Micropython开发环境&#xff08;Win10&#xff09; #实验程序之三&#xff1a;更新频率演示 #尝试搭建K210的Micropython开发环境&#xff08;Win10&#xff09; #实验程序之三&#xff1a;更新频率演示from Maix import freqcpu_freq, kpu_freq freq.get() …