Pytorch解决 多元回归 问题的算法

Pytorch解决 多元回归 问题的算法

回归是一种基本的统计建模技术,用于建立因变量与一个或多个自变量之间的关系。
我们将使用 PyTorch(一种流行的深度学习框架)来开发和训练线性回归模型。

在这里插入图片描述

二元回归的简单示例

训练数据集(可获取)

对于此分析,我们将使用scikit-learn 库中的 make regression() 函数生成的合成数据集。数据集由输入特征和目标变量组成。输入特征代表自变量,而目标变量代表我们想要预测的因变量

import seaborn as sns
import numpy as sns
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn
from sklearn import datasets
import pandas as pddata=datasets.make_regression()    # from sklearn we are going to select one dataset
df = pd.DataFrame(data[0], columns=[f"feature_{i+1}" for i in range(data[0].shape[1])])
df["target"] = data[1]

在这里插入图片描述

数据的结构,100 rows × 101 columns,最后 1 column为目标值

准备训练集与测试集

PyTorch 是一个功能强大的开源深度学习框架,提供了一种灵活的方式来构建和训练神经网络。它提供了一系列张量运算、自动微分和优化算法的功能。

使用 sklearn Train-Test-split 准备数据以开发模型

x=df.iloc[: , :-1]   # 除目标数据身下所以的
y=df.iloc[: , -1]    # targetfrom sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=42)
print(type(X_train))
# X_train=torch.tensor(X_train,dtype=torch.float32)X_train = torch.tensor(X_train.values, dtype=torch.float32)  # 转化为 tensor
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.float32)

模型架构

数据准备好了,可以准备模型了

我们的线性回归模型是作为PyTorch 中nn.Module类的子类实现的。该模型由多个按顺序连接的完全连接(线性)层组成。

class linearRegression(nn.Module): 
# 所有来自torch的依赖项将被传递给这个类[父类] 
# nn.Module 包含了神经网络的所有构建模块:def __init__(self,input_dim):super(linearRegression,self).__init__()   # building connection with parent and child classesself.fc1=nn.Linear(input_dim,10)          # hidden layer 1self.fc2=nn.Linear(10,5)                  # hidden layer 2self.fc3=nn.Linear(5,3)                   # hidden layer 3self.fc4=nn.Linear(3,1)                   # last layerdef forward(self,d):out=torch.relu(self.fc1(d))              # input * weights + bias for layer 1out=torch.relu(self.fc2(out))            # input * weights + bias for layer 2out=torch.relu(self.fc3(out))            # input * weights + bias for layer 3out=self.fc4(out)                        # input * weights + bias for last layerreturn out                               # final outcomeinput_dim=X_train.shape[1]     # 获取 input_dim 变量的数量
torch.manual_seed(42)          # to make initilized weights stable:
model=linearRegression(input_dim)
# select loss and optimizersloss=nn.MSELoss() # loss function
optimizers=optim.Adam(params=model.parameters(),lr=0.01)loss_values_all = []  # 创建一个列表来存储每个迭代的loss值# training the model:num_of_epochs=1000
for i in range(num_of_epochs):# give the input data to the architecurey_train_prediction=model(X_train)  # model initilizingloss_value=loss(y_train_prediction.squeeze(),y_train)   # find the loss function:optimizers.zero_grad() # make gradients zero for every iteration so next iteration it will be clearloss_value.backward()  # back propagationoptimizers.step()      # update weights in NNloss_values_all.append(loss_value.item())  # 将当前的loss值添加到列表中# print the loss in training part:if i % 10 == 0:print(f'[epoch:{i}]: The loss value for training part={loss_value}')

绘制 loss 曲线图
在这里插入图片描述
在测试数据集上的效果(test data)

with torch.no_grad():model.eval()   # make model in evaluation stagey_test_prediction=model(X_test)test_loss=loss(y_test_prediction.squeeze(),y_test)print(f'Test loss value : {test_loss.item():.4f}')

测试自己随机生成的数据

# Inference with own data:
pr = torch.tensor(torch.arange(1, 101).unsqueeze(dim=0), dtype=torch.float32).clone().detach()
print(pr)

保存训练好的模型

# save the torch model:from pathlib import Pathfilename=Path('models')
filename.mkdir(parents=True,exist_ok=True)model_name='linear_regression.pth' # model name# saving pathsaving_path=filename/model_name
print(saving_path)
torch.save(obj=model.state_dict(),f=saving_path)# we can load the saved model and do the inference again:load_model=linearRegression(input_dim) # creating an instance again for loaded model
load_model.load_state_dict(torch.load('./models/linear_regression.pth'))load_model.eval()   # make model in evaluation stage
with torch.no_grad():pred = load_model(torch.tensor([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,  12.,13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,  48.,49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,  60.,61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,  71.,  72.,73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,97.,  98.,  99., 100.]]))print(f'prediction value : {pred.item()}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/852971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【git】 OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443

修改/etc/hosts文件,删掉以下 192.30.253.113 github.com 192.30.253.113 github.com 192.30.253.118 gist.github.com 192.30.253.119 gist.github.com #172.24.132.179 gerrit.sdp.nd 140.82.112.25 alive.github.com 140.82.114.6 api.github.com 185.199.110.15…

哪里有海量的短视频素材,以及短视频制作教程?

在当下,短视频已成为最火爆的内容形式之一,尤其是在抖音上。但很多创作者都面临一个问题:视频素材从哪里来?怎么拍摄才能吸引更多观众?别担心,今天我将为大家推荐几个宝藏网站,确保你素材多到用…

【Kafka】Kafka生产者-04

【Kafka】Kafka生产者-04 1. 生产者发送消息流程1.1 发送原理 2. 相关文档 1. 生产者发送消息流程 1.1 发送原理 在消息发送的过程中,涉及到了两个线程——main 线程和 Sender 线程。 在 main 线程中创建了一个双端队列 RecordAccumulator。 main 线程将消息发送给…

数据库的权限管理和安全策略

数据库的权限管理和安全策略是确保数据库安全、可靠和稳定运行的关键措施。以下是对数据库权限管理和安全策略的详细解释: 数据库权限管理 1. 权限定义 数据库权限是指用户对数据库中的数据和操作所拥有的执行权利。这些权限决定了用户可以访问哪些数据、可以对数…

【CSP】202312-2 因子简化

2023年 第32次CCF计算机软件能力认证 202312-2 因子化简 原题链接:CSP32-因子简化 时间限制: 2.0 秒 空间限制: 512 MiB 目录 题目背景 题目描述 输入格式 输出格式 样例输入 样例输出 样例解释 子任务 解题思路 AC代码 题目背…

STM32 MDK Keil5软件调试功能使用(无需连接硬件)

MDK Keil5 在线仿真STM32(无需连接硬件) 首先点击工具栏的魔术棒配置一下:(记得选择自己的STM32芯片类型) 开启调试 使用逻辑分析仪查看IO输出 会打开这个界面,点击左边的setup按钮 会打开这个窗口&am…

182.二叉树:二叉搜索树的最小绝对差(力扣)

代码解决 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}* Tre…

手把手教你入门vue+springboot开发(三)--登录功能后端

文章目录 前言一、redis安装二、后端代码1.修改application.yml文件2.增加utils文件3.增加Result类4.修改UserController类5.修改UserMapper类6.修改UserService和UserServiceImpl类7.增加LoginInterceptor类8.增加WebConfig类9.修改pom.xml文件 前言 前两篇我们用vuespringbo…

去除数组重复成员的方法

方法 1 扩展运算符和 Set 结构相结合,就可以去除数组的重复成员 // 去除数组的重复成员 [...new Set([1, 2, 2, 3, 4, 5, 5])]; // [1, 2, 3, 4, 5] 方法 2 function dedupe(array) {return Array.from(new Set(array)); } dedupe([1, 1, 2, 3]); // [1, 2, 3] …

FPGA中复位电路的设计

复位电路也是数字逻辑设计中常用的电路,不管是 FPGA 还是 ASIC 设计,都会涉及到复位,一般 FPGA或者 ASIC 的复位需要我们自己设计复位方案。复位指的是将寄存器恢复到默认值。一般复位功能包括同步复位和异步复位。复位一般由硬件开关触发引起…

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[2]-简洁部署版

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[2]-简洁部署版 1.环境要求 1.1 软件要求 要顺利运行本代码,请按照以下系统要求进行配置 已经测试过的系统 Linux Ubuntu 22.04.5 kernel version 6.7其他系统可能出现系统兼容性问题。 最低要求 该要求仅针对标准模…

oracle中执行select ... for update需要什么权限?

oracle中执行select … for update需要什么权限? 问题 在oracle中,一个用户执行select … for update需要什么权限? 分析测试 用户1: test_0614 用户2:test 目标表:test.t_0614 执行语句:se…

MySQL用户密码插件mysql_native_password和caching_sha2_password的区别

MySQL用户密码插件mysql_native_password和caching_sha2_password有几个关键的区别,主要集中在安全性、性能和兼容性方面: 1. 算法和安全性 mysql_native_password: 使用的是基于SHA-1的密码散列算法。SHA-1算法已被认为不再足够安全,存在一…

【深度学习基础】理解 PyTorch 中的 logits 和交叉熵损失函数

在深度学习中,理解损失函数是训练模型的关键一步。在分类任务中,交叉熵损失函数是最常用的损失函数之一。本文将详细解释 PyTorch 中的 logits、交叉熵损失函数的工作原理,并展示如何调整张量的形状以确保计算正确的损失。 什么是 logits&am…

论人工智能与真实性

论人工智能与真实性 这让我们都感到不安:不是因为人工智能已经足够好,可以准确地预测某人可能会如何回答(包括猫的名字、表情符号的使用、汤的参考以及对“精神动物”的随意参考),而是因为提供这些反应菜单的模式首先代表了对这些互动功能的误解。即使回…

59.指向指针的指针(二级指针)

目录 一.什么是指向指针的指针 二.扩展 三.视频教程 一.什么是指向指针的指针 我们先看回顾一下指针&#xff1a; #include <stdio.h>int main(void) {int a 100;int *p &a;printf("*p is %d\n",*p);return 0;} 解析&#xff1a; 所以printf输出的结…

TCP/IP协议,三次握手,四次挥手

IP - 网际协议 IP 负责计算机之间的通信。 IP 负责在因特网上发送和接收数据包。 HTTP - 超文本传输协议 HTTP 负责 web 服务器与 web 浏览器之间的通信。 HTTP 用于从 web 客户端&#xff08;浏览器&#xff09;向 web 服务器发送请求&#xff0c;并从 web 服务器向 web …

Java 网站开发入门指南:如何用java写一个网站

Java 网站开发入门指南&#xff1a;如何用java写一个网站 Java 作为一门强大的编程语言&#xff0c;在网站开发领域也占据着重要地位。虽然现在 Python、JavaScript 等语言在网站开发中越来越流行&#xff0c;但 Java 凭借其稳定性、可扩展性和丰富的生态系统&#xff0c;仍然…

【CS.AL】算法必学之贪心算法:从入门到进阶 —— 关键概念和代码示例

文章目录 1. 概述2. 适用场景3. 设计步骤4. 优缺点5. 典型应用6. 题目和代码示例6.1 简单题目&#xff1a;找零问题6.2 中等题目&#xff1a;区间调度问题6.3 困难题目&#xff1a;分数背包问题 7. 题目和思路表格8. 总结References 1000.1.CS.AL.1.4-核心-GreedyAlgorithm-Cre…

李永乐线代笔记

线性方程组 解方程组的变换就是矩阵初等行变换 三秩相等 方程组系数矩阵的行秩列秩&#xff0c;线性相关的问题应求列秩&#xff0c;但求行秩方便 齐次线性方程组 对应向量组的线性相关&#xff0c;所以回顾下线性相关的知识&#xff1a; 其中k是x&#xff0c;所以用向…