自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

​ONES 春节假期服务通知

ONES 春节假期服务通知 灵蛇贺岁,瑞气盈门。感谢大家一直以来对 ONES 的认可与支持,祝您春节快乐! 「2025年1月28日 ~ 2025年2月4日」春节假期期间,我们的值班人员将为您提供如下服务 : 紧急问题 若有紧急问…

python:洛伦兹变换

洛伦兹变换(Lorentz transformations)是相对论中的一个重要概念,特别是在讨论时空的变换时非常重要。在四维时空的背景下,洛伦兹变换描述了在不同惯性参考系之间如何变换时间和空间坐标。在狭义相对论中,洛伦兹变换通常…

Solon Cloud Gateway 开发:Route 的配置与注册方式

路由的配置与注册有三种方式:手动配置;自动发现配置;代码注册。 1、手动配置方式 solon.cloud.gateway:routes: #!必选- id: demotarget: "http://localhost:8080" # 或 "lb://user-service"predicates: #?可选- &quo…

LangChain:使用表达式语言优化提示词链

在 LangChain 里,LCEL 即 LangChain Expression Language(LangChain 表达式语言),本文为你详细介绍它的定义、作用、优势并举例说明,从简单示例到复杂组合示例,让你快速掌握LCEL表达式语言使用技巧。 定义 …

Julia DataFrames.jl:深入理解和使用

随着数据科学和机器学习的发展,数据框架广泛应用于数据处理与分析工作中。在 Julia 语言中,DataFrames.jl 是一个强大且灵活的数据框库,为数据操作提供了丰富的功能。本文旨在系统地介绍 DataFrames.jl 的基础概念、使用方法、常见实践和最佳…

Kotlin判空辅助工具

1)?.操作符 //执行逻辑 if (person ! null) {person.doSomething() } //表达式 person?.doSomething() 2)?:操作符 //执行逻辑 val c if (a ! null) {a } else {b } //表达式 val c a ?: b 3)!!表达式 var message: String? &qu…

unity学习20:time相关基础 Time.time 和 Time.deltaTime

目录 1 unity里的几种基本时间 1.1 time 相关测试脚本 1.2 游戏开始到现在所用的时间 Time.time 1.3 时间缩放值 Time.timeScale 1.4 固定时间间隔 Time.fixedDeltaTime 1.5 两次响应时间之间的间隔:Time.deltaTime 1.6 对应测试代码 1.7 需要关注的2个基本…

Vue.js组件开发-实现下载时暂停恢复下载

在 Vue 中实现下载时暂停和恢复功能,通常可以借助 XMLHttpRequest 对象来控制下载过程。XMLHttpRequest 允许在下载过程中暂停和继续请求。 实现步骤 创建 Vue 组件:创建一个 Vue 组件,包含下载、暂停和恢复按钮。初始化 XMLHttpRequest 对…

【llm对话系统】大模型 RAG 之回答生成:融合检索信息,生成精准答案

今天,我们将深入 RAG 流程的最后一步,也是至关重要的一步:回答生成 (Answer Generation)。 在这一步,LLM 将融合用户问题和检索到的文档片段,生成最终的答案。这个过程不仅仅是简单的文本拼接,更需要 LLM …

赚钱的究极认识

1、赚钱的本质是提供了价值或者价值想象 价值: 比如小米手机靠什么?“性价比”,什么饥饿营销,创新,用户参与,生态供应链,品牌这些不能说不重要,但是加在一起都没有“性价比”这3字重…

世上本没有路,只有“场”et“Bravo”

楔子:电气本科“工程电磁场”电气研究生课程“高等电磁场分析”和“电磁兼容”自学”天线“、“通信原理”、“射频电路”、“微波理论”等课程 文章目录 前言零、学习历程一、Maxwells equations1.James Clerk Maxwell2.自由空间中传播的电磁波3.边界条件和有限时域…

electron typescript运行并设置eslint检测

目录 一、初始化package.json 二、安装依赖 三、项目结构 四、配置启动项 五、补充:ts转js别名问题 一、初始化package.json 我的:这里的"main"没太大影响,看后面的步骤。 {"name": "xloda-cloud-ui-pc"…

学习数据结构(3)顺序表

1.动态顺序表的实现 (1)初始化 (2)扩容 (3)头部插入 (4)尾部插入 (5)头部删除 (这里注意要保证有效数据个数不为0) (6&a…

PydanticAI应用实战

PydanticAI 是一个 Python Agent 框架,旨在简化使用生成式 AI 构建生产级应用程序的过程。 它由 Pydantic 团队构建,该团队也开发了 Pydantic —— 一个在许多 Python LLM 生态系统中广泛使用的验证库。PydanticAI 的目标是为生成式 AI 应用开发带来类似 FastAPI 的体验,它基…

deepseek R1的确不错,特别是深度思考模式

deepseek R1的确不错,特别是深度思考模式,每次都能自我反省改进。比如我让 它写文案: 【赛博朋克版程序员新春密码——2025我们来破局】 亲爱的代码骑士们: 当CtrlS的肌肉记忆遇上抢票插件,当Spring Boot的…

macbook安装go语言

通过brew来安装go语言 使用brew命令时,一般都会通过brew search看看有哪些版本 brew search go执行后,返回了一堆内容,最下方展示 If you meant "go" specifically: It was migrated from homebrew/cask to homebrew/core. Cas…

若依基本使用及改造记录

若依框架想必大家都了解得不少,不可否认这是一款及其简便易用的框架。 在某种情况下(比如私活)使用起来可谓是快得一匹。 在这里小兵结合自身实际使用情况,记录一下我对若依框架的使用和改造情况。 一、源码下载 前往码云进行…

Kafka 深入服务端 — 时间轮

Kafka中存在大量的延迟操作,比如延时生产、延时拉取和延时删除等。Kafka基于时间轮概念自定义实现了一个用于延时功能的定时器,来完成这些延迟操作。 1 时间轮 Kafka没有使用基于JDK自带的Timer或DelayQueue来实现延迟功能,因为它们的插入和…

数据分析系列--②RapidMiner导入数据和存储过程

一、下载数据 点击下载AssociationAnalysisData.xlsx数据集 二、导入数据 1. 在本地计算机中创建3个文件夹 2. 从本地选择.csv或.xlsx 三、界面说明 四、存储过程 将刚刚新建的过程存储到本地 Congratulations, you are done.

AIP-133 标准方法:Create

编号133原文链接AIP-133: Standard methods: Create状态批准创建日期2019-01-23更新日期2019-01-23 在REST API中,通常向集合URI(如 /v1/publishers/{publisher}/books )发出POST请求,在集合中创建新资源。 面向资源设计&#x…