PyTorch_构建线性回归

使用 PyTorch 的 API 来手动构建一个线性回归的假设函数,数据加载器,损失函数,优化方法,绘制训练过程中的损失变化。


数据构建

import torch
from sklearn.datasets import make_regression 
import matplotlib.pyplot as plt 
import random # 构建数据集
def create_dataset():x, y, coef = make_regression(n_samples = 100, n_features = 1, noise = 10, coef= True, bias = 14.5, random_state = 0) # 将构建数据转换为张量类型x = torch.tensor(x)y = torch.tensor(y)return x, y # 构建数据加载器
def data_load(x, y, batch_size):# 计算样本数量data_len = len(y)# 构建数据索引data_index = list(range(data_len))# 数据集打乱random.shuffle(data_index)# 计算总的batch数量batch_number = data_len // batch_size for idx in range(batch_number):start = idx * batch_size end = start + batch_size batch_train_x = x[start: end]batch_train_y = y[start: end]yield batch_train_x, batch_train_ydef test01():x, y = create_dataset()plt.scatter(x, y)plt.show()for x, y in data_load(x, y, batch_size=10):print(y)if __name__ == "__main__":test01() 

构建假设函数,损失函数,优化方法

所谓的假设函数,就是线性回归的方程。

损失函数:使用平方损失

优化方法:梯度下降

# 构建假设函数
w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)def linear_regression(x):return w * x + b # 损失函数
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2 # 优化方法
def sgd(learning_rate = 0.01):# 16 是批次样本的平均梯度值。 batch sizew.data = w.data - learning_rate * w.grad.data / 16b.data = b.data - learning_rate * b.grad.data / 16

训练函数

# 训练函数
def train():# 加载数据集x, y, coef = create_dataset()# 定义训练参数epochs = 100learning_rate = 0.01# 存储损失epoch_loss = []total_loss = 0.0 train_sample = 0 for _ in range(epochs):for train_x, train_y in data_load(x, y, 16):# 训练数据送入模型进行预测y_pred = linear_regression(train_x)# 计算预测值和真实值的平方损失loss = square_loss(y_pred, train_y.reshape(-1, 1)).sum()total_loss += loss.item()train_sample += len(train_y)# 梯度清零if w.grad is not None:w.grad.data.zero_() if b.grad is not None:b.grad.data.zero_() # 自动微分loss.backward()# 更新参数sgd(learning_rate)print('loss: %.10f' % (total_loss / train_sample))epoch_loss.append(total_loss / train_sample)# 绘制拟合直线print(coef, w.data.item())plt.scatter(x, y)x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * w + 14.5 for v in x])y2 = torch.tensor([v * coef + 14.5 for v in x])plt.plot(x, y1, label = '训练')plt.plot(x, y2, label = '真实')plt.grid()plt.legend()plt.show()# 打印损失变化曲线plt.plot(range(epochs), epoch_loss)plt.title('损失变化曲线')plt.grid()plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78744.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

005-nlohmann/json 基础方法-C++开源库108杰

《二、基础方法》:节点访问、值获取、显式 vs 隐式、异常处理、迭代器、类型检测、异常处理……一节课搞定C处理JSON数据85%的需求…… JSON 字段的简单类型包括:number、boolean、string 和 null(即空值);复杂类型则有…

HarmonyOS 5.0 分布式数据协同与跨设备同步​​

大家好,我是 V 哥。 使用 Mate 70有一段时间了,系统的丝滑使用起来那是爽得不要不要的,随着越来越多的应用适配,目前使用起来已经和4.3的兼容版本功能差异无碍了,还有些纯血鸿蒙独特的能力很是好用,比如&am…

Linux云计算训练营笔记day02(Linux、计算机网络、进制)

Linux 是一个操作系统 Linux版本 RedHat Rocky Linux CentOS7 Linux Ubuntu Linux Debian Linux Deepin Linux 登录用户 管理员 root a 普通用户 nsd a 打开终端 放大: ctrl shift 缩小: ctrl - 命令行提示符 [rootlocalhost ~]# ~ 家目录 /root 当前登录的用户…

macOS 安装了Docker Desktop版终端docker 命令没办法使用

macOS 安装了Docker Desktop版终端docker 命令没办法使用 1、检查Docker Desktop能否正常运行。 确保Docker Desktop能正常运行。 2、检查环境变量是否添加 1、添加环境变量 如果环境变量中没有包含Docker的路径,你可以手动添加。首先,找到Docker的…

Gradio全解20——Streaming:流式传输的多媒体应用(5)——基于WebRTC的摄像头实时目标检测

Gradio全解20——Streaming:流式传输的多媒体应用(5)——基于WebRTC的摄像头实时目标检测 本篇摘要20. Streaming:流式传输的多媒体应用20.5 基于WebRTC的摄像头实时目标检测20.5.1 环境配置及说明1. WebRTC2. TURN服务器 20.5.2 …

OSCP - Proving Grounds - NoName

主要知识点 linux命令注入SUID find提权 具体步骤 从nmap开始搜集信息,只开放了一个80端口 Nmap scan report for 192.168.171.15 Host is up (0.40s latency). Not shown: 65534 closed tcp ports (reset) PORT STATE SERVICE VERSION 80/tcp open http …

c++_csp-j算法 (6)_高精度算法(加减乘除)

高精度算法 C++高精度算法是指在C++编程语言中实现高精度计算的算法。在C++中,通常整数的范围是有限的,超出这个范围的整数计算会导致溢出。高精度算法的出现,使得C++程序能够处理超出常规整数范围的大整数计算,包括高精度加法、减法、乘法、除法等运算。 在C++中实现高精…

从OpenMP中的不兼容,窥探AI应用开发中的并行编程

在AI相关的项目开发中,偶然遇到下面这个问题: OMP: Error #15: Initializing libomp.dylib, but found libiomp5.dylib already initialized. OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the progr am. That is dangerous, sin…

vue2+element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能

vue2element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能 文章目录 vue2element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能前言一、准备工作二、行内编辑1.嵌入Input文本输入框1.1遇到问题1.文本框内容修改失…

c#OdbcDataReader的数据读取

先有如下c#示例代码: string strconnect "DSNcustom;UIDsa;PWD123456;" OdbcConnection odbc new OdbcConnection(strconnect); odbc.Open(); if (odbc.State ! System.Data.ConnectionState.Open) { return; } string strSql "select ID from my…

【HTML5】老式放映机原理-实现图片无缝滚动

老式放映机原理-实现图片无缝滚动 实现思路: 页面设计部分——先将视口div设置为相对定位,再视口div里面嵌套一个类似“胶卷”的div,把该div设置为绝对定位,此时“胶卷"会挂靠在视口上面,再将“胶卷”的left属性设置为负值…

LeetCode 1781. 所有子字符串美丽值之和 题解

示例 输入:s "aabcb" 输出:5 解释:美丽值不为零的字符串包括 ["aab","aabc","aabcb","abcb","bcb"] ,每一个字符串的美丽值都为 1这题光用文字解说还是无法达到讲…

2025ACTF Web部分题解

文章目录 ACTF uploadnot so web 1not so web 2 ACTF upload 前面登录随便输入可以进入文件上传页面, 随便上传一张图片, 发现路由存在file_path参数, 尝试路径穿越读取文件 发现可以成功读取 读取源码 /upload?file_path../app.pyimport uuid import os import hashlib im…

双目标清单——AI与思维模型【96】

一、定义 双目标清单思维模型是一种将决策或任务分解为两个主要目标,并分别列出相关要素和行动步骤的思维方式。这两个目标通常具有相互关联又有所侧重的特点,通过明确并列出与每个目标相关的具体事项,有助于更清晰地分析问题、制定计划和分…

深度学习系统学习系列【6】之深度学习技巧

文章目录 数据集准备数据集扩展数据预处理1. 0均值(Zero Centralization)代码实现 2. 归一化(Normalization)代码实现 3. 主成分分析(Principal Component Analysis, PCA)实现步骤代码实现 4. 白化&#xf…

rfsoc petalinux适配调试记录

1。安装虚拟机 2.设置共享文件夹 https://xinzhi.wenda.so.com/a/1668239544201149先设置文件夹路径 vmware 12 下安装 ubuntu 16.04 后,按往常的惯例安装 vmware-tools,安装时提示建议使用 open-vm-tools,于是放弃 vmware-tools 的安装&am…

# YOLOv1:开启实时目标检测的新时代

YOLOv1:开启实时目标检测的新时代 在计算机视觉领域,目标检测一直是研究的热点和难点问题。它不仅需要准确地识别出图像中的物体,还需要确定这些物体的位置。YOLO(You Only Look Once)系列算法以其高效的实时目标检测…

uni-app vue3 实现72小时倒计时功能

功能介绍 &#xff0c;数组项有一个下单时间 &#xff0c;比如今天下单在72小时内可以继续支付&#xff0c;超过则默认取消订单 页面按钮处 加上倒计时 <!-- 倒计时 --> <text v-if"item.timeLeft > 0">{{ formatTime(item.remaining) }}</text&g…

一周学会Pandas2 Python数据处理与分析-Pandas2数据类型转换操作

锋哥原创的Pandas2 Python数据处理与分析 视频教程&#xff1a; 2025版 Pandas2 Python数据处理与分析 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili Pandas 提供了灵活的方法来处理数据类型转换&#xff0c;以下是常见操作及代码示例&#xff1a; 1. 查看数据类型 …

LLM损失函数面试会问到的

介绍一下KL散度 KL&#xff08;Kullback-Leibler散度衡量了两个概率分布之间的差异。其公式为&#xff1a; D K L ( P / / Q ) − ∑ x ∈ X P ( x ) log ⁡ 1 P ( x ) ∑ x ∈ X P ( x ) log ⁡ 1 Q ( x ) D_{KL}(P//Q)-\sum_{x\in X}P(x)\log\frac{1}{P(x)}\sum_{x\in X}…