保存模型:
# Write a checkpoint dict to `datadir`. Despite the name, torch.save expects a
# file path (or file-like object) here, not a directory. The dict holds the
# epoch counter, the model weights and the optimizer state. The counter is
# stored as epoch + 1, presumably the epoch to resume from; confirm with the
# training loop.
torch.save(
    {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    },
    datadir,
)
加载模型:
# Rebuild the model architecture first. load_state_dict only fills in
# parameter/buffer values; it does not construct the module.
model = model_class(num_classes=num_classes)  # define the model
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine and avoids allocating GPU memory just to read it.
# load_state_dict copies the values into the model's existing parameters, so
# the device the tensors are loaded onto does not affect the result.
state = torch.load(datadir, map_location="cpu")
model.load_state_dict(state['state_dict'])
# NOTE(review): if this checkpoint also has 'epoch' and 'optimizer' entries,
# restore them when resuming training, e.g.
# optimizer.load_state_dict(state['optimizer']).