深度学习 - PyTorch基本流程 (代码)

直接上代码

import torch 
import matplotlib.pyplot as plt 
from torch import nn# 创建data
print("**** Create Data ****")
weight = 0.3
bias = 0.9
X = torch.arange(0,1,0.01).unsqueeze(dim = 1)
y = weight * X + bias
print(f"Number of X samples: {len(X)}")
print(f"Number of y samples: {len(y)}")
print(f"First 10 X & y sample: \n X: {X[:10]}\n y: {y[:10]}")
print("\n")# 将data拆分成training 和 testing
print("**** Splitting data ****")
train_split = int(len(X) * 0.8)
X_train = X[:train_split]
y_train = y[:train_split]
X_test = X[train_split:]
y_test = y[train_split:]
print(f"The length of X train: {len(X_train)}")
print(f"The length of y train: {len(y_train)}")
print(f"The length of X test: {len(X_test)}")
print(f"The length of y test: {len(y_test)}\n")# 显示 training 和 testing 数据
def plot_predictions(train_data = X_train,train_labels = y_train,test_data = X_test,test_labels = y_test,predictions = None):plt.figure(figsize = (10,7))plt.scatter(train_data, train_labels, c = 'b', s = 4, label = "Training data")plt.scatter(test_data, test_labels, c = 'g', label="Test data")if predictions is not None:plt.scatter(test_data, predictions, c = 'r', s = 4, label = "Predictions")plt.legend(prop = {"size": 14})
plot_predictions()# 创建线性回归
print("**** Create PyTorch linear regression model by subclassing nn.Module ****")
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.weight = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))self.bias = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))def forward(self, x):return self.weight * x + self.biastorch.manual_seed(42)
model_1 = LinearRegressionModel()
print(model_1)
print(model_1.state_dict())
print("\n")# 初始化模型并放到目标机里
print("*** Instantiate the model ***")
print(list(model_1.parameters()))
print("\\n")# 创建一个loss函数并优化
print("*** Create and Loss function and optimizer ***")
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params = model_1.parameters(),lr = 0.01)
print(f"loss_fn: {loss_fn}")
print(f"optimizer: {optimizer}\n")# 训练
print("*** Training Loop ***")
torch.manual_seed(42)
epochs = 300
for epoch in range(epochs):# 将模型加载到训练模型里model_1.train()# 做 Forwardy_pred = model_1(X_train)# 计算 Lossloss = loss_fn(y_pred, y_train)# 零梯度optimizer.zero_grad()# 反向传播loss.backward()# 步骤优化optimizer.step()### 做测试if epoch % 20 == 0:# 将模型放到评估模型并设置上下文model_1.eval()with torch.inference_mode():# 做 Forwardy_preds = model_1(X_test)# 计算测试 losstest_loss = loss_fn(y_preds, y_test)# 输出测试结果print(f"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}")# 在测试集上对训练模型做预测
print("\n")
print("*** Make predictions with the trained model on the test data. ***")
model_1.eval()
with torch.inference_mode():y_preds = model_1(X_test)
print(f"y_preds:\n {y_preds}")
## 画图
plot_predictions(predictions = y_preds) # 保存训练好的模型
print("\n")
print("*** Save the trained model ***")
from pathlib import Path 
## 创建模型的文件夹
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
## 创建模型的位置
MODEL_NAME = "trained model"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME 
## 保存模型到刚创建好的文件夹
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj = model_1.state_dict(), f = MODEL_SAVE_PATH)
## 创建模型的新类型
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f = MODEL_SAVE_PATH))
## 做预测,并跟之前的做预测
y_preds_new = loaded_model(X_test)
print(y_preds == y_preds_new)

结果如下

**** Create Data ****
Number of X samples: 100
Number of y samples: 100
First 10 X & y sample: X: tensor([[0.0000],[0.0100],[0.0200],[0.0300],[0.0400],[0.0500],[0.0600],[0.0700],[0.0800],[0.0900]])y: tensor([[0.9000],[0.9030],[0.9060],[0.9090],[0.9120],[0.9150],[0.9180],[0.9210],[0.9240],[0.9270]])**** Splitting data ****
The length of X train: 80
The length of y train: 80
The length of X test: 20
The length of y test: 20**** Create PyTorch linear regression model by subclassing nn.Module ****
LinearRegressionModel()
OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))])*** Instantiate the model ***
[Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]*** Create and Loss function and optimizer ***
loss_fn: L1Loss()
optimizer: SGD (
Parameter Group 0dampening: 0differentiable: Falseforeach: Nonelr: 0.01maximize: Falsemomentum: 0nesterov: Falseweight_decay: 0
)*** Training Loop ***
Epoch: 0 | Train loss: 0.757 | Test loss: 0.725
Epoch: 20 | Train loss: 0.525 | Test loss: 0.454
Epoch: 40 | Train loss: 0.294 | Test loss: 0.183
Epoch: 60 | Train loss: 0.077 | Test loss: 0.073
Epoch: 80 | Train loss: 0.053 | Test loss: 0.116
Epoch: 100 | Train loss: 0.046 | Test loss: 0.105
Epoch: 120 | Train loss: 0.039 | Test loss: 0.089
Epoch: 140 | Train loss: 0.032 | Test loss: 0.074
Epoch: 160 | Train loss: 0.025 | Test loss: 0.058
Epoch: 180 | Train loss: 0.018 | Test loss: 0.042
Epoch: 200 | Train loss: 0.011 | Test loss: 0.026
Epoch: 220 | Train loss: 0.004 | Test loss: 0.009
Epoch: 240 | Train loss: 0.004 | Test loss: 0.006
Epoch: 260 | Train loss: 0.004 | Test loss: 0.006
Epoch: 280 | Train loss: 0.004 | Test loss: 0.006*** Make predictions wit the trained model on the test data. ***
y_preds:tensor([[1.1464],[1.1495],[1.1525],[1.1556],[1.1587],[1.1617],[1.1648],[1.1679],[1.1709],[1.1740],[1.1771],[1.1801],[1.1832],[1.1863],[1.1893],[1.1924],[1.1955],[1.1985],[1.2016],[1.2047]])*** Save the trained model ***
Saving model to models/trained model
tensor([[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True]])

第一个结果图
第二个结果图

点个赞支持一下咯~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778659.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于资源的约束委派(下)

webclient http self relay Web 分布式创作和版本控制 (WebDAV) 是超文本传输协议 (HTTP) 的扩展,它定义了如何使用 HTTP ( docs.microsoft.com )执行复 制、移动、删除和创建等基本文件功能 需要启用 WebClient 服务才能使基于 WebDAV 的程序和功能正常工作。事实…

全国中学基础信息 API 数据接口

全国中学基础信息 API 数据接口 基础数据,高校高考,提供全国初级高级中学基础数据,定时更新,多维度筛选。 1. 产品功能 2024 年数据已更新;提供最新全国中学学校基本信息;包含全国初级中学与高等中学&…

虚拟现实(VR)项目的开发工具

虚拟现实(VR)项目的开发涉及到多种工具,这些工具可以帮助开发者从建模、编程到最终内容的发布。以下是一些被广泛认可的VR开发工具,它们覆盖了从3D建模到交互设计等多个方面。北京木奇移动技术有限公司,专业的软件外包…

为什么我的微信小程序 窗口背景色backgroundColor设置参数 无效的问题处理记录!

当我们在微信小程序 json 中设置 backgroundColor 时,实际在电脑的模拟器中根本看不到效果。 这是因为 backgroundColor 指的窗体背景颜色,而不是页面的背景颜色,即窗体下拉刷新或上拉加载时露出的背景。在电脑的模拟器中是看不到这个动作的…

Acwing_795前缀和 【一维前缀和】+【模板】二维前缀和

Acwing_795前缀和 【一维前缀和】 题目&#xff1a; 代码&#xff1a; #include <bits/stdc.h> #define int long long #define INF 0X3f3f3f3f #define endl \n using namespace std; const int N 100010; int arr[N];int n,m; int l,r; signed main(){std::ios::s…

Gitlab 实现仓库完全迁移,包括所有提交记录、分支、标签

1 方案一&#xff1a;命令 cd <项目目录> git fetch --all git fetch --tags git remote rename origin old-origin #可以不保留 git remote add origin http://***(项目的新仓库地址) #git remote set-url origin <项目的新仓库地址> git push origin --all git…

项目管理:项目进度管理的五大关键步骤

作为项目经理&#xff0c;要想做好项目进度管理&#xff0c;可以遵循以下五个关键步骤&#xff1a; 一、确定范围和分解目标 1、明确项目目标&#xff1a;首先&#xff0c;要清晰地定义项目的总体目标和预期成果。 2、范围界定&#xff1a;详细列出项目所需完成的所有任务和…

LDL^H分解求逆矩阵与MATLAB仿真(Right-Looking)

通过分解将对称正定厄米特矩阵分解成下三角矩阵L和对角矩阵D来求其逆矩阵 目录 前言 一、LDL^H基本算法 二、LDL^H Right-Looking算法 三、D矩阵求逆 四、L矩阵求逆 五、A矩阵求逆 六、计算量分析 七、MATLAB仿真 八、参考资料 总结 前言 在线性代数中&#xff0c;LDL…

HarmonyOS入门--配置环境 + IDE汉化

文章目录 下载安装DevEco Studio配置环境先认识DevEco Studio界面工程目录工程级目录模块级目录 app.json5module.json5main_pages.json通知栏预览区 运行模拟器IED汉化 下载安装DevEco Studio 去官网下载DevEco Studio完了安装 配置环境 打开已安装的DevEco Studio快捷方式…

Java中有哪些容器(集合类)?

Java中的集合类主要由Collection和Map这两个接口派生而出&#xff0c;其中Collection接口又派生出三个子接 口&#xff0c;分别是Set、List、Queue。所有的Java集合类&#xff0c;都是Set、List、Queue、Map这四个接口的实现 类&#xff0c;这四个接口将集合分成了四大类&#…

蓝桥杯 - 小明的背包1(01背包)

解题思路&#xff1a; 本题属于01背包问题&#xff0c;使用动态规划 dp[ j ]表示容量为 j 的背包的最大价值 注意&#xff1a; 需要时刻提醒自己dp[ j ]代表的含义&#xff0c;不然容易晕头转向 注意越界问题&#xff0c;且 j 需要倒序遍历 如果正序遍历 dp[1] dp[1 - vo…

Android应用程序的概念性描述

1.概述 Android 应用程序包含了工程文件、代码和各种资源&#xff0c;主要由 Java 语言编写&#xff0c;每一个应用程序将被编译成Android 的一个 Java 应用程序包&#xff08;*.apk&#xff09;。 由于 Android 系统本身是基于 Linux 操作系统运行的&#xff0c;因此 …

SpringBoot Redis 之Lettuce 驱动

一、前言 一直以为SpringBoot中 spring-boot-starter-data-redis使用的是Jredis连接池&#xff0c;直到昨天在部署报价系统生产环境时&#xff0c;因为端口配置错误造成无法连接&#xff0c;发现报错信息如下&#xff1a; 一了解才知道在SpringBoot2.X以后默认是使用Lettuce作…

蓝桥杯 2022 省A 选数异或

一种比较无脑暴力点的方法&#xff0c;时间复杂度是(nm)。 (注意的优先级比^高&#xff0c;记得加括号(a[i]^a[j])x&#xff09; #include <iostream> #include <vector> #include <bits/stdc.h> // 包含一些 C 标准库中未包含的特定实现的函数的头文件 usi…

成都市酷客焕学新媒体科技有限公司:实现品牌的更大价值!

成都市酷客焕学新媒体科技有限公司专注于短视频营销&#xff0c;深知短视频在社交媒体中的巨大影响力。该公司巧妙地将品牌信息融入富有创意和趣味性的内容中&#xff0c;使观众在轻松愉悦的氛围中接受并传播这些信息。凭借独特的创意和精准的营销策略&#xff0c;成都市酷客焕…

第二证券|打新股有风险吗?

打新股有危险&#xff0c;其主要危险是破发&#xff0c;其间呈现以下状况&#xff0c;新股可能会破发&#xff1a; 1、估值过高 新股的估值过高&#xff0c;与其价值不相契合&#xff0c;其泡沫性较大&#xff0c;然后导致个股在上市之后&#xff0c;稳健投资者以及主力大量地…

10个替代Sketch的软件大盘点!第一款震撼来袭!

Sketch是Mac平台上专门为用户界面设计的矢量图形绘制工具。Sketch简单的界面背后有优秀的矢量绘制能力和丰富的插件库。但遗憾的是&#xff0c;Sketch只能在Mac平台上使用和浏览&#xff0c;而且是本地化的工具&#xff0c;云共享功能并不完善。在本文中&#xff0c;我们评估了…

金三银四面试题(一):JVM类加载与垃圾回收

面试过程中最经典的一题&#xff1a; 请你讲讲在JVM中类的加载过程以及垃圾回收&#xff1f; 加载过程 当Java虚拟机&#xff08;JVM&#xff09;启动时&#xff0c;它会通过类加载器&#xff08;ClassLoader&#xff09;加载Java类到内存中。类加载是Java程序运行的重要组成…

python(一)网络爬取

在爬取网页信息时&#xff0c;需要注意网页爬虫规范文件robots.txt eg:csdn的爬虫规范文件 csdn.net/robots.txt User-agent: 下面的Disallow规则适用于所有爬虫&#xff08;即所有用户代理&#xff09;。星号*是一个通配符&#xff0c;表示“所有”。 Disallow&…

Scikit-Learn K近邻分类

Scikit-Learn K近邻分类 1、K近邻分类1.1、K近邻分类及原理1.2、超参数K1.3、K近邻分类的优缺点2、Scikit-Learn K近邻分类2.1、Scikit-Learn K近邻分类API2.2、K近邻分类实践(鸢尾花分类)2.3、交叉验证寻找最佳K2.4、K近邻分类与Pipeline1、K近邻分类 K近邻是一种常用的分类…