Day40 早停策略和模型权重的保存

@浙大疏锦行

作业:对信贷数据集进行训练后保持权重,后继续训练50次,采取早停策略

import torch import torch.nn as nn import torch.optim as optim from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.preprocessing import MinMaxScaler import time import matplotlib.pyplot as plt from tqdm import tqdm import warnings warnings.filterwarnings("ignore") # 检查GPU是否可用,优先使用GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 若有多个GPU,可指定具体GPU,例如cuda:1 # 验证GPU是否真的在使用(可选) if torch.cuda.is_available(): print(f"GPU名称: {torch.cuda.get_device_name(0)}") torch.cuda.empty_cache() # 清空GPU缓存 # 加载信贷数据集 iris = load_iris() X = iris.data # 特征数据 y = iris.target # 标签数据 # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 归一化数据 scaler = MinMaxScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) # 转换为PyTorch张量并强制移至指定设备(GPU/CPU) X_train = torch.FloatTensor(X_train).to(device, non_blocking=True) y_train = torch.LongTensor(y_train).to(device, non_blocking=True) X_test = torch.FloatTensor(X_test).to(device, non_blocking=True) y_test = torch.LongTensor(y_test).to(device, non_blocking=True) class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.fc1 = nn.Linear(4, 10) # 输入层(信贷数据集需修改输入维度) self.relu = nn.ReLU() self.fc2 = nn.Linear(10, 3) # 输出层(信贷数据集需修改输出维度) def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out # 实例化模型并移至GPU model = MLP().to(device) criterion = nn.CrossEntropyLoss() # 分类损失函数 optimizer = optim.SGD(model.parameters(), lr=0.01) # 优化器 # 首次训练参数 first_train_epochs = 20000 train_losses = [] # 首次训练损失 test_losses = [] epochs = [] # 早停参数(首次训练和继续训练共用相同策略) best_test_loss = float('inf') best_epoch = 0 patience = 50 counter = 0 early_stopped = False print("\n===== 开始首次训练 =====") start_time = time.time() with tqdm(total=first_train_epochs, desc="首次训练进度", unit="epoch") as pbar: for epoch in range(first_train_epochs): model.train() # 前向传播 outputs = model(X_train) train_loss = criterion(outputs, y_train) # 反向传播和优化 optimizer.zero_grad() train_loss.backward() optimizer.step() # 每200轮记录损失并检查早停 if (epoch + 1) % 200 == 0: model.eval() with torch.no_grad(): test_outputs = model(X_test) test_loss = criterion(test_outputs, y_test) train_losses.append(train_loss.item()) test_losses.append(test_loss.item()) epochs.append(epoch + 1) # 更新进度条 pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'}) # 早停逻辑 if test_loss.item() < best_test_loss: best_test_loss = test_loss.item() best_epoch = epoch + 1 counter = 0 # 保存最佳模型 torch.save(model.state_dict(), 'best_model.pth') else: counter += 1 if counter >= patience: print(f"\n首次训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。") print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}") early_stopped = True break # 更新进度条 if (epoch + 1) % 1000 == 0: pbar.update(1000) # 补全进度条 if pbar.n < first_train_epochs: pbar.update(first_train_epochs - pbar.n) # 保存首次训练结束后的模型权重(核心修改点1) torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch + 1, 'best_loss': best_test_loss }, 'trained_model.pth') print(f"\n首次训练完成,权重已保存至 trained_model.pth") print(f"首次训练总耗时: {time.time() - start_time:.2f} 秒") print("\n===== 加载权重并开始继续训练 =====") # 加载保存的权重(核心修改点2) checkpoint = torch.load('trained_model.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) print(f"加载了首次训练至第{checkpoint['epoch']}轮的权重,最佳损失: {checkpoint['best_loss']:.4f}") # 重新初始化优化器(核心修改点3:继续训练必须重置优化器) optimizer = optim.SGD(model.parameters(), lr=0.01) # 若需要延续优化器状态,可取消下面注释(视场景选择) # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 继续训练的参数 continue_train_epochs = 50 # 目标继续训练50轮 continue_train_losses = [] # 继续训练损失 continue_test_losses = [] continue_epochs = [] # 重置早停参数(针对继续训练) continue_best_loss = checkpoint['best_loss'] continue_counter = 0 continue_early_stop = False start_continue_time = time.time() with tqdm(total=continue_train_epochs, desc="继续训练进度", unit="epoch") as pbar: for epoch in range(continue_train_epochs): model.train() # 前向传播 outputs = model(X_train) train_loss = criterion(outputs, y_train) # 反向传播和优化 optimizer.zero_grad() train_loss.backward() optimizer.step() # 每1轮就检查损失和早停(继续训练轮数少,无需间隔) model.eval() with torch.no_grad(): test_outputs = model(X_test) test_loss = criterion(test_outputs, y_test) continue_train_losses.append(train_loss.item()) continue_test_losses.append(test_loss.item()) continue_epochs.append(epoch + 1) # 更新进度条 pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'}) pbar.update(1) # 继续训练的早停逻辑 if test_loss.item() < continue_best_loss: continue_best_loss = test_loss.item() continue_counter = 0 # 保存继续训练后的最佳模型 torch.save(model.state_dict(), 'continue_best_model.pth') else: continue_counter += 1 if continue_counter >= patience: print(f"\n继续训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。") print(f"继续训练最佳损失: {continue_best_loss:.4f}") continue_early_stop = True break print(f"继续训练完成,总耗时: {time.time() - start_continue_time:.2f} 秒") print(f"继续训练实际轮数: {len(continue_epochs)} 轮(早停触发则少于50轮)") print("\n===== 最终模型评估 =====") model.load_state_dict(torch.load('continue_best_model.pth', map_location=device)) model.eval() with torch.no_grad(): outputs = model(X_test) _, predicted = torch.max(outputs, 1) correct = (predicted == y_test).sum().item() accuracy = correct / y_test.size(0) print(f'测试集最终准确率: {accuracy * 100:.2f}%') # ====================== 8. 可视化 ====================== plt.figure(figsize=(12, 6)) # 绘制首次训练损失 plt.subplot(1, 2, 1) plt.plot(epochs, train_losses, label='Train Loss') plt.plot(epochs, test_losses, label='Test Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.title('首次训练损失曲线') plt.legend() plt.grid(True) # 绘制继续训练损失 plt.subplot(1, 2, 2) plt.plot(continue_epochs, continue_train_losses, label='Train Loss') plt.plot(continue_epochs, continue_test_losses, label='Test Loss') plt.xlabel('Continue Epoch') plt.ylabel('Loss') plt.title('继续训练50轮损失曲线') plt.legend() plt.grid(True) plt.tight_layout() plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1198647.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI说话人拆分实战:基于Speech Seaco的多角色语音处理

AI说话人拆分实战&#xff1a;基于Speech Seaco的多角色语音处理 在日常工作中&#xff0c;我们经常会遇到包含多个发言者的会议录音、访谈记录或课堂讲解。如果需要将不同人的讲话内容区分开来&#xff0c;传统方式是人工听写后手动标注&#xff0c;效率极低且容易出错。有没…

如何验证MinerU安装成功?test.pdf运行结果查看指南

如何验证MinerU安装成功&#xff1f;test.pdf运行结果查看指南 1. 确认MinerU镜像已正确加载 你拿到的是一个专为PDF内容提取优化的深度学习环境——MinerU 2.5-1.2B 深度学习 PDF 提取镜像。这个镜像不是普通的工具包&#xff0c;而是一个完整封装了模型、依赖和测试文件的“…

BERT填空AI生产环境落地:稳定性与兼容性实测报告

BERT填空AI生产环境落地&#xff1a;稳定性与兼容性实测报告 1. 引言&#xff1a;当BERT走进真实业务场景 你有没有遇到过这样的情况&#xff1a;写文案时卡在一个词上&#xff0c;翻来覆去总觉得不够贴切&#xff1f;或者校对文档时&#xff0c;明明感觉某句话“怪怪的”&am…

从零部署DeepSeek OCR模型|WebUI镜像简化流程,支持单卡推理

从零部署DeepSeek OCR模型&#xff5c;WebUI镜像简化流程&#xff0c;支持单卡推理 1. 为什么选择 DeepSeek OCR&#xff1f; 你有没有遇到过这样的场景&#xff1a;一堆纸质发票、合同、身份证需要录入系统&#xff0c;手动打字不仅慢&#xff0c;还容易出错&#xff1f;或者…

3步搞定Llama3部署:Open-WebUI可视化界面教程

3步搞定Llama3部署&#xff1a;Open-WebUI可视化界面教程 1. 为什么选Meta-Llama-3-8B-Instruct&#xff1f;轻量、强指令、真可用 你是不是也遇到过这些情况&#xff1a;想本地跑个大模型&#xff0c;结果显存不够卡在半路&#xff1b;好不容易加载成功&#xff0c;命令行交…

GPEN教育场景应用:学生证件照自动美化系统搭建

GPEN教育场景应用&#xff1a;学生证件照自动美化系统搭建 在校园管理数字化转型的进程中&#xff0c;学生证件照作为学籍档案、一卡通、考试系统等核心业务的基础数据&#xff0c;其质量直接影响到人脸识别准确率和整体管理效率。然而&#xff0c;传统拍摄方式存在诸多痛点&a…

为什么要学数字滤波器与C语言实现

嵌入式开发中&#xff0c;你大概率遇到过这类问题&#xff1a;温度传感器数据跳变导致温控误动作、电机电流信号含高频噪声引发抖动、工业仪表测量值不稳定。这些均源于信号噪声干扰&#xff0c;而数字滤波器是解决这类问题的实用工具。 有同学会问&#xff0c;直接用现成滤波库…

YOLO26镜像功能全测评:目标检测新标杆

YOLO26镜像功能全测评&#xff1a;目标检测新标杆 近年来&#xff0c;目标检测技术在工业、安防、自动驾驶等领域持续发挥关键作用。YOLO系列作为实时检测的代表&#xff0c;不断迭代进化。最新发布的 YOLO26 在精度与速度之间实现了新的平衡&#xff0c;而基于其官方代码库构…

Z-Image-Turbo推理延迟高?9步生成优化技巧实战分享

Z-Image-Turbo推理延迟高&#xff1f;9步生成优化技巧实战分享 你是不是也遇到过这种情况&#xff1a;明明用的是RTX 4090D这种顶级显卡&#xff0c;跑Z-Image-Turbo文生图模型时&#xff0c;推理时间却迟迟下不来&#xff1f;生成一张10241024的高清图动辄几十秒&#xff0c;…

创建型模式:简单工厂模式(C语言实现)

作为C语言开发者&#xff0c;我们每天都在和各种“对象”打交道——传感器、外设、缓冲区、任务控制块……尤其是做嵌入式开发时&#xff0c;经常要写一堆类似的初始化代码&#xff1a;温度传感器要初始化I2C接口&#xff0c;光照传感器要配置SPI时序&#xff0c;湿度传感器又要…

语音社交App创新:用SenseVoiceSmall增加情感互动反馈

语音社交App创新&#xff1a;用SenseVoiceSmall增加情感互动反馈 1. 让语音社交更有“温度”&#xff1a;为什么需要情感识别&#xff1f; 你有没有这样的经历&#xff1f;在语音聊天室里&#xff0c;朋友说了一句“我还好”&#xff0c;语气却明显低落。但文字消息看不到表情…

Glyph启动失败?常见错误代码排查步骤详解教程

Glyph启动失败&#xff1f;常见错误代码排查步骤详解教程 1. 引言&#xff1a;你遇到的Glyph问题&#xff0c;可能比想象中更容易解决 你是不是也遇到了这种情况——满怀期待地部署了Glyph模型&#xff0c;点击运行后却卡在启动界面&#xff0c;或者直接弹出一串看不懂的错误…

对比实测:自己搭环境 vs 使用预置镜像微调效率差异

对比实测&#xff1a;自己搭环境 vs 使用预置镜像微调效率差异 你是否也曾经被“大模型微调”这个词吓退&#xff1f;总觉得需要庞大的算力、复杂的配置、动辄几天的调试时间&#xff1f;其实&#xff0c;随着工具链的成熟和生态的完善&#xff0c;一次完整的 LoRA 微调&#…

语音标注预处理:FSMN-VAD辅助人工标注实战案例

语音标注预处理&#xff1a;FSMN-VAD辅助人工标注实战案例 1. FSMN-VAD 离线语音端点检测控制台 在语音识别、语音合成或语音标注项目中&#xff0c;一个常见但耗时的环节是从长段录音中手动截取有效语音片段。传统的人工听辨方式不仅效率低下&#xff0c;还容易因疲劳导致漏…

效果展示:Qwen3-Reranker-4B打造的智能文档排序案例

效果展示&#xff1a;Qwen3-Reranker-4B打造的智能文档排序案例 在信息爆炸的时代&#xff0c;如何从海量文档中快速找到最相关的内容&#xff0c;是搜索、推荐和知识管理系统的共同挑战。传统检索系统往往依赖关键词匹配&#xff0c;容易忽略语义层面的相关性&#xff0c;导致…

Z-Image-Turbo生成动漫角色全过程分享

Z-Image-Turbo生成动漫角色全过程分享 1. 引言&#xff1a;为什么选择Z-Image-Turbo来创作动漫角色&#xff1f; 你有没有想过&#xff0c;只需一段文字描述&#xff0c;就能瞬间生成一张细节丰富、风格鲜明的动漫角色图&#xff1f;这不再是科幻场景。借助阿里通义实验室开源…

实时性要求高的场景:FSMN-VAD流式处理可能性分析

实时性要求高的场景&#xff1a;FSMN-VAD流式处理可能性分析 1. FSMN-VAD 离线语音端点检测控制台简介 在语音交互系统、自动转录服务和智能硬件设备中&#xff0c;语音端点检测&#xff08;Voice Activity Detection, VAD&#xff09;是不可或缺的前置环节。它负责从连续音频…

NewBie-image-Exp0.1内存泄漏?长时运行稳定性优化指南

NewBie-image-Exp0.1内存泄漏&#xff1f;长时运行稳定性优化指南 你是否在使用 NewBie-image-Exp0.1 镜像进行长时间动漫图像生成任务时&#xff0c;遇到了显存占用持续上升、系统变慢甚至进程崩溃的问题&#xff1f;这很可能是由潜在的内存泄漏或资源未及时释放导致的。虽然…

MinerU vs 其他PDF提取工具:多模态模型性能实战对比评测

MinerU vs 其他PDF提取工具&#xff1a;多模态模型性能实战对比评测 1. 引言&#xff1a;为什么PDF提取需要多模态模型&#xff1f; 你有没有遇到过这样的情况&#xff1a;一份科研论文PDF里夹着复杂的数学公式、三栏排版和嵌入式图表&#xff0c;用传统工具一转Markdown&…

科哥定制FunASR镜像实战|轻松实现语音识别与标点恢复

科哥定制FunASR镜像实战&#xff5c;轻松实现语音识别与标点恢复 1. 为什么你需要一个开箱即用的语音识别系统&#xff1f; 你有没有遇到过这样的场景&#xff1a;会议录音长达一小时&#xff0c;却要手动逐字整理成文字稿&#xff1f;或者做视频剪辑时&#xff0c;想自动生成…