Python打卡训练营day33——2025.05.22

知识点回顾:
PyTorch和cuda的安装
查看显卡信息的命令行命令(cmd中使用)
cuda的检查
简单神经网络的流程
数据预处理(归一化、转换成张量)
模型的定义
继承nn.Module类
定义每一个层
定义前向传播流程
定义损失函数和优化器
定义训练流程
可视化loss过程
预处理补充:

注意事项:

1. 分类任务中,若标签是整数(如 0/1/2 类别),需转为long类型(对应 PyTorch 的torch.long),否则交叉熵损失函数会报错。

2. 回归任务中,标签需转为float类型(如torch.float32)。

作业:今日的代码,要做到能够手敲。这已经是最简单最基础的版本了。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as npif torch.cuda.is_available():print("CUDA可用!")# 获取可用的CUDA设备数量device_count = torch.cuda.device_count()print(f"可用的CUDA设备数量: {device_count}")# 获取当前使用的CUDA设备索引current_device = torch.cuda.current_device()print(f"当前使用的CUDA设备索引: {current_device}")# 获取当前CUDA设备的名称device_name = torch.cuda.get_device_name(current_device)print(f"当前CUDA设备的名称: {device_name}")# 获取CUDA版本cuda_version = torch.version.cudaprint(f"CUDA版本: {cuda_version}")
else:print("CUDA不可用。")# 加载4特征,3分类的鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# # 打印下尺寸
# print(X_train.shape)
# print(y_train.shape)
# print(X_test.shape)
# print(y_test.shape)# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test) #确保训练集和测试集是相同的缩放# 转换为PyTorch张量
X_train=torch.FloatTensor(X_train) 
X_test=torch.FloatTensor(X_test) 
y_train=torch.LongTensor(y_train) 
y_test=torch.LongTensor(y_test) # print(X_train.shape)
# print(y_train.shape)
# print(X_test.shape)
# print(y_test.shape)class MLP(nn.Module): # 定义一个多层感知机MLP模型,继承nn.Module类def __init__(self):super(MLP,self).__init__() #调用父类的构造函数self.fc1=nn.Linear(4,10) #输入层到隐藏层,4个特征,10个神经元self.relu=nn.ReLU() #激活函数self.fc2=nn.Linear(10,3) #隐藏层到输出层,10个神经元,3个类别def forward(self,x): #前向传播out=self.fc1(x) #输入层到隐藏层out=self.relu(out) #激活函数out=self.fc2(out) #隐藏层到输出层return out #返回输出层的结果model=MLP() #实例化模型criterion=nn.CrossEntropyLoss() #定义损失函数,交叉熵损失函数,适用于多分类问题optimizer=optim.SGD(model.parameters(),lr=0.01) #定义优化器,随机梯度下降,学习率为0.01num_epochs=20000 #定义训练轮数
losses=[] #定义一个列表,用于存储损失值
for epoch in range(num_epochs):outputs=model.forward(X_train) #前向传播,得到输出层的结果loss=criterion(outputs,y_train) #计算损失值,y_train是真实标签,outputs是模型的预测值losses.append(loss.item()) #记录损失值optimizer.zero_grad() #清空梯度loss.backward() #反向传播,计算梯度optimizer.step() #更新参数if (epoch+1)%1000==0: #每10000轮输出一次损失值print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')import matplotlib.pyplot as plt
# 可视化损失曲线
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/81909.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app学习笔记九-vue3 v-for指令

v-for 指令基于一个数组来渲染一个列表。v-for 指令的值需要使用 item in items 形式的特殊语法&#xff0c;其中 items 是源数据的数组&#xff0c;而 item 是迭代项的别名&#xff1a; <template><view v-for"(item,index) in 10" :key"index"…

【C++算法】70.队列+宽搜_N 叉树的层序遍历

文章目录 题目链接&#xff1a;题目描述&#xff1a;解法C 算法代码&#xff1a; 题目链接&#xff1a; 429. N 叉树的层序遍历 题目描述&#xff1a; 解法 使用队列层序遍历就可以了。 先入根节点1。queue&#xff1a;1 然后出根节点1&#xff0c;入孩子节点2&#xff0c;3&a…

pycharm无法正常调试问题

pycharm无法正常调试问题 1.错误代码 已连接到 pydev 调试器(内部版本号 231.8109.197)Traceback (most recent call last):File "E:\Python\pycharm\PyCharm 2023.1\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_comm.py", line 304, in _on_runr r.deco…

【机器学习基础】机器学习入门核心算法:线性回归(Linear Regression)

机器学习入门核心算法&#xff1a;线性回归&#xff08;Linear Regression&#xff09; 1. 算法逻辑2. 算法原理与数学推导3. 评估指标4. 应用案例5. 面试题6. 扩展分析总结 1. 算法逻辑 核心思想 通过线性方程拟合数据&#xff0c;最小化预测值与真实值的误差平方和&#xff0…

手机打电话时由对方DTMF响应切换多级IVR语音菜单(话术脚本与实战)

手机打电话时由对方DTMF响应切换多级IVR语音菜单 &#xff08;话术脚本与实战&#xff09; --本地AI电话机器人 上一篇&#xff1a;手机打电话时由对方DTMF响应切换多级IVR语音应答&#xff08;二&#xff09; 下一篇&#xff1a;手机打电话时由对方DTMF响应切换多级IVR语音…

flutter dart class语法说明、示例

&#x1f539; Dart 中的 class 基本语法 class ClassName {// 属性&#xff08;字段&#xff09;数据类型 属性名;// 构造函数ClassName(this.属性名);// 方法返回类型 方法名() {// 方法体} }✅ 示例&#xff1a;创建一个简单的 Person 类 class Person {// 属性String name;…

Apollo10.0学习——planning模块(10)之依赖注入器injector_

好不好奇injector_是干什么用的&#xff1f;为什么planning每个模块都要初始化这个变量&#xff1f; 类功能概述 DependencyInjector&#xff08;依赖注入器&#xff09;是一个 集中管理规划模块关键数据和服务 的容器类。它通过提供统一的访问接口&#xff0c;解耦各个组件之…

关于vue彻底删除node_modules文件夹

Vue彻底删除node_modules的命令 vue的node_modules文件夹非常大,常规手段根本无法删除. 解决方法: 在node_modules文件夹所在的路径运行命令窗口,并执行下面的命令. npm install rimraf -g rimraf node_modules说明&#xff1a; npm install rimraf -g 该命令是安装 node…

MCTS-RAG:通过树搜索重塑小模型中的检索增强生成(RAG)

https://arxiv.org/pdf/2503.20757v1这篇论文提出了MCTS-RAG框架&#xff0c;用于解决小型语言模型在知识密集型任务上的推理能力不足问题。具体来说&#xff0c; ​​MCTS-RAG框架​​&#xff1a;MCTS-RAG通过迭代地精炼检索和推理过程来工作。给定一个查询&#xff0c;它探…

数据结构:绪论之时间复杂度与空间复杂度

作者主页 失踪人口回归&#xff0c;陆续回三中。 开辟文章新专栏——数据结构&#xff0c;恳请各位大佬批评指正&#xff01; 文章目录 作者主页 数据结构的基本知识数据&#xff1a;数据元素&#xff1a;数据对象&#xff1a;数据类型&#xff1a;数据结构&#xff1a;逻辑结…

位图算法——判断唯一字符

这道题有多种解法&#xff0c;可以创建hash数组建立映射关系判断&#xff0c;但不用新的数据结构会加分&#xff0c;因此我们有“加分”办法——用位图。 我们可以创建一个整型变量&#xff08;32位&#xff09;而一共才26个字母&#xff0c;所以我们只要用到0-25位即可&#…

深度学习之-目标检测算法汇总(超全面)

YOLO目标检测改进 YOLO V1- YOLO V10: 点这进入https://www.researchgate.net/publication/381470743_YOLOv1_to_YOLOv10_A_comprehensive_review_of_YOLO_variants_and_their_application_in_the_agricultural_domain YOLO V11: YOLO11 &#x1f680;Ultralytics YOLO11 &…

软考中级软件设计师——计算机网络篇

一、计算机网络体系结构 1.OSI七层模型 1. 物理层&#xff08;Physical Layer&#xff09; 功能&#xff1a;传输原始比特流&#xff08;0和1&#xff09;&#xff0c;定义物理介质&#xff08;如电缆、光纤&#xff09;的电气、机械特性。 关键设备&#xff1a;中继器&#…

高等数学-空间中的曲线与曲面

一、 向量的数量积&#xff1a; 直线与直线的夹角&#xff1a; 直线与平面的夹角&#xff1a; 平面与平面的夹角&#xff08;锐角&#xff09;&#xff1a; 方向余弦&#xff1a; 注&#xff1a;空间向量与坐标轴的夹角定义为向量与坐标轴正方向的夹角 例1: 二、 1、z0所…

使用计算机视觉实现目标分类和计数!!超详细入门教程

什么是物体计数和分类 在当今自动化和技术进步的时代&#xff0c;计算机视觉作为一项关键工具脱颖而出&#xff0c;在物体计数和分类任务中提供了卓越的功能。 无论是在制造、仓储、零售&#xff0c;还是在交通监控等日常应用中&#xff0c;计算机视觉系统都彻底改变了我们感知…

javaweb-html

1.交互流程&#xff1a; 浏览器向服务器发送http请求&#xff0c;服务器对浏览器进行回应&#xff0c;并发送字符串&#xff0c;浏览器能对这些字符串&#xff08;html代码&#xff09;进行解释&#xff1b; 三大web语言&#xff1a;&#xff08;1&#xff09;html&#xff1a…

图漾相机错误码解析

文章目录 1.相机错误码汇总2.常见报错码2.1 -1001报错2.1.1 没有找到相机2.1.2 SDK没有进行初始化2.1.3 相机不支持Histo属性 2.2 -1005报错2.2.1 跨网段打开相机2.2.2 旧版本SDK在软触发失败后提示的报错2.2.3 相机初始化上电时报错2.2.4 USB相机被占用 2.3 -1009报错2.3.1 相…

18. 结合Selenium和YAML对页面继承对象PO的改造

18. 结合Selenium和YAML对页面继承对象PO的改造 一、架构改造核心思路 1.1 改造前后对比 #mermaid-svg-ziagMhNLS5fIFWrx {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-ziagMhNLS5fIFWrx .error-icon{fill:#5522…

将VMware上的虚拟机和当前电脑上的Wifi网卡处在同一个局域网下,实现同一个局域网下实现共享

什么是桥接模式&#xff1a;桥接模式&#xff08;Bridging Mode&#xff09;是一种网络连接模式&#xff0c;常用于虚拟机环境中&#xff0c;将虚拟机的虚拟网络适配器直接连接到主机的物理网络适配器上&#xff0c;使虚拟机能够像独立的物理设备一样直接与物理网络通信 1.打开…

gitee错误处理总结

背景 如上图&#xff0c;根据图片中的 Git 错误提示&#xff0c;我们遇到的问题是 ​本地分支落后于远程分支&#xff0c;导致 git push 被拒绝。 ​问题原因​ 远程仓库的 master 分支有其他人推送的新提交&#xff0c;而您的本地 master 分支未同步这些更新&#xff08;即本…