保存权重文件时,最好使用copy.deepcopy,不然可能出现引用的问题,导致本应该保存best pth的变成保存最后一个epoch的pth。
/root/unified_nas/training/trainer.py
# 更新最佳模型 if val_metrics['accuracy'] > best_accuracy: best_accuracy = val_metrics['accuracy'] best_val_metrics = val_metrics best_model_state = { # 'model': self.model.state_dict(), # 'head': self.task_head.state_dict() 'model': copy.deepcopy(self.model.state_dict()), # ✅ 深拷贝 'head': copy.deepcopy(self.task_head.state_dict()) # ✅ 深拷贝 } # 保存最佳模型权重到文件 torch.save(best_model_state, save_path) # print(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%") self._output(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%")这部分的完整代码如下:
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import copy

# Fix the RNG seed so training runs are reproducible.
SEED = 42  # any integer works as the seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


class SingleTaskTrainer:
    """
    Trainer for a single classification dataset.

    Trains a feature-extractor backbone plus a freshly created linear task
    head, tracks the best validation accuracy across epochs, and checkpoints
    the best weights. The checkpointed state dicts are deep-copied so later
    optimizer steps cannot mutate them through shared tensor references
    (otherwise the "best" checkpoint would silently become the last epoch's
    weights).
    """

    def __init__(self, model, dataloaders, device='cuda', logger=None):
        """
        Initialize the trainer.

        Args:
            model: backbone mapping inputs to feature vectors; must expose
                an ``output_dim`` attribute (feature dimensionality).
            dataloaders: dict with 'train' and 'test' DataLoaders whose
                dataset exposes a ``classes`` sequence.
            device: training device ('cuda' or 'cpu').
            logger: optional logger; when None, output falls back to print.

        Raises:
            AttributeError: if ``model`` has no ``output_dim`` attribute.
        """
        self.model = model.to(device)
        self.dataloaders = dataloaders
        self.device = device
        self.logger = logger

        # The task head's input width comes from the backbone, so the
        # attribute must exist before we build the head.
        if not hasattr(model, 'output_dim'):
            raise AttributeError("Model must have 'output_dim' attribute")

        # Number of classes is read from the training dataset.
        self.num_classes = len(dataloaders['train'].dataset.classes)
        print(f"Number of classes: {self.num_classes}")

        # Linear classification head on top of the backbone features.
        self.task_head = nn.Linear(model.output_dim, self.num_classes).to(device)

        # Single optimizer over both the backbone and the head.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = Adam(
            list(model.parameters()) + list(self.task_head.parameters()),
            lr=1e-3
        )

    def _output(self, message):
        """Unified output: route through the logger when present, else print."""
        if self.logger:
            self.logger.info(message)
        else:
            print(message)

    def train_epoch(self):
        """
        Run one training epoch over the 'train' loader.

        Returns:
            dict with 'loss' (mean batch loss) and 'accuracy' (percent).
        """
        self.model.train()
        self.task_head.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(self.dataloaders['train'], desc="Training"):
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            features = self.model(inputs)
            outputs = self.task_head(features)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        metrics = {
            'loss': running_loss / len(self.dataloaders['train']),
            'accuracy': 100. * correct / total
        }
        return metrics

    def evaluate(self):
        """
        Evaluate on the 'test' loader (no gradient tracking).

        Returns:
            dict with 'loss' (mean batch loss) and 'accuracy' (percent).
        """
        self.model.eval()
        self.task_head.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in tqdm(self.dataloaders['test'], desc="Evaluating"):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                features = self.model(inputs)
                outputs = self.task_head(features)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        metrics = {
            'loss': running_loss / len(self.dataloaders['test']),
            'accuracy': 100. * correct / total
        }
        return metrics

    def train(self, epochs=10, save_path='best_model.pth'):
        """
        Train the model and save the best weights.

        Args:
            epochs: number of training epochs.
            save_path: file path for the best checkpoint.

        Returns:
            best_accuracy: best validation accuracy observed.
            best_val_metrics: validation metrics at the best epoch.
            history: per-epoch list of {'train': ..., 'val': ...} metrics.
            best_model_state: state-dict snapshot of the best model, or
                None if no epoch improved on 0.0 accuracy.
        """
        best_accuracy = 0.0
        best_val_metrics = None   # validation metrics at the best epoch
        history = []
        best_model_state = None   # snapshot of the best weights

        for epoch in range(epochs):
            self._output(f"\nEpoch {epoch + 1}/{epochs}")

            # Training phase
            train_metrics = self.train_epoch()

            # Validation phase
            val_metrics = self.evaluate()

            history.append({
                'train': train_metrics,
                'val': val_metrics
            })

            self._output(f"\nValidation Accuracy: {val_metrics['accuracy']:.2f}%")

            # Track the best model so far.
            if val_metrics['accuracy'] > best_accuracy:
                best_accuracy = val_metrics['accuracy']
                best_val_metrics = val_metrics
                # Deep-copy the state dicts: state_dict() returns references
                # to the live parameter tensors, so without the copy later
                # optimizer steps would overwrite this "best" snapshot.
                best_model_state = {
                    'model': copy.deepcopy(self.model.state_dict()),
                    'head': copy.deepcopy(self.task_head.state_dict())
                }
                # Persist the best checkpoint immediately.
                torch.save(best_model_state, save_path)
                self._output(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%")

        return best_accuracy, best_val_metrics, history, best_model_state

# Note (from the original author): do not change the model architecture
# lightly — rather than restructuring the model to fix a problem, trust the
# current structure and keep going; the real cause is often unexpected.