苏州相城区网站建设渭南市工程建设项目审批网上办事大厅
news/
2025/9/29 10:46:13/
文章来源:
苏州相城区网站建设,渭南市工程建设项目审批网上办事大厅,做关于什么内容的网站,台前网站建设

✅作者简介：人工智能专业本科在读，喜欢计算机与编程，写博客记录自己的学习历程。
🍎个人主页：小嗷犬的个人主页
🍊个人网站：小嗷犬的技术小站
🥭个人信条：为天地立心，为生民立命，为往圣继绝学，为万世开太平。

本文目录：数据集与 Notebook｜环境准备｜数据集｜可视化｜模型｜预测｜Loss 与评价指标

数据集与 Notebook
数据集：70 Dog Breeds-Image Data Set
Notebook：「MobileNet V3」70 Dog Breeds-Image Classification

环境准备
import subprocess
import sys
import warnings

# Disable warnings so they do not clutter the notebook output.
# NOTE(review): this blanket filter also hides library warnings that may point at
# real problems; consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

# Install PyTorch Lightning quietly. Running pip through the current interpreter
# works both inside a notebook and as a plain Python script.
subprocess.run([sys.executable, "-m", "pip", "install", "lightning", "--quiet"], check=True)
# Common libraries.
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch.
import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

# PyTorch Lightning.
import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Plot style: dark grid, larger fonts, SimHei font (for CJK labels), and a plain
# hyphen instead of the Unicode minus sign for negative numbers.
sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus": False})
# Fix the random seed for reproducibility. pl.seed_everything seeds Python's
# `random`, NumPy and PyTorch in one call; workers=True additionally gives
# DataLoader workers reproducible per-worker seeds.
seed = 1
pl.seed_everything(seed, workers=True)

# ===== Dataset =====
# Batch size shared by all DataLoaders.
batch_size = 64

# Training preprocessing: resize to 224x224 and apply a random horizontal flip
# as light augmentation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation/test preprocessing: deterministic resize only, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# NOTE(review): the model loads ImageNet-pretrained weights, which are normally
# paired with transforms.Normalize(ImageNet mean/std). It is not applied here;
# adding it would also require un-normalizing images before plotting them.
# Read the three splits of the Kaggle dataset; each split is its own folder.
data_root = "/kaggle/input/70-dog-breedsimage-data-set"
train_dataset = datasets.ImageFolder(root=f"{data_root}/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root=f"{data_root}/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root=f"{data_root}/test", transform=test_transform)

# Build the DataLoaders. Only training data is shuffled; shuffling the
# validation/test splits is unnecessary for evaluation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ===== Visualization =====
from collections import Counter

# ----- Class distribution of the training set -----
class_names = train_dataset.classes
# Count every label in a single pass over the integer targets.
label_counts = Counter(train_dataset.targets)
class_count = [label_counts[i] for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})

plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

# ----- Eight training samples (with the training augmentation applied) -----
plt.figure(figsize=(12, 7), dpi=100)
images, labels = next(iter(train_loader))
for idx in range(8):
    plt.subplot(2, 4, idx + 1)
    # Image tensors are channel-first; imshow expects (H, W, C).
    plt.imshow(images[idx].permute(1, 2, 0).numpy())
    plt.title(class_names[labels[idx].item()])
    plt.axis("off")
plt.tight_layout()
plt.show()

# ===== Model =====
class LitModel(pl.LightningModule):
    """MobileNet V3 (large) image classifier wrapped as a Lightning module.

    The backbone is loaded with torchvision's ``IMAGENET1K_V2`` weights and the
    last classifier layer is replaced by a new ``nn.Linear`` with ``num_classes``
    outputs.

    Logged metric names: ``train_loss``, ``val_loss`` and
    ``{train,val,test}_{acc,prec,recall,f1score}``. Precision, recall and F1 are
    macro-averaged over classes. Train metrics are per-step batch values;
    val/test metrics are computed over the whole epoch.

    Args:
        num_classes: Number of output classes.
        freeze_backbone: If True, freeze all pretrained parameters so that only
            the new final classifier layer is trained.
    """

    def __init__(self, num_classes=1000, freeze_backbone=False):
        super().__init__()
        self.model = models.mobilenet_v3_large(weights="IMAGENET1K_V2")
        if freeze_backbone:
            # Must run before the head is replaced so the new layer stays trainable.
            for param in self.model.parameters():
                param.requires_grad = False
        self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)

        metrics = torchmetrics.MetricCollection(
            {
                "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes),
                "prec": torchmetrics.Precision(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
                "recall": torchmetrics.Recall(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
                # Macro like precision/recall; the multiclass default ("micro")
                # would make F1 identical to accuracy.
                "f1score": torchmetrics.F1Score(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
            }
        )
        # Independent copies per stage so train/val/test states never mix.
        # The prefixes produce the logged names, e.g. "val_acc".
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam with a small L2 weight decay over all parameters."""
        return optim.Adam(
            self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss; logs loss and metrics for this batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        # Calling the collection returns this batch's values (keys "train_acc", ...).
        self.log_dict(self.train_metrics(y_hat, y), on_step=True, on_epoch=False, logger=True)
        return loss

    def on_train_epoch_end(self):
        # Only batch values are logged for training; drop the accumulated state.
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        """Log the epoch-averaged val_loss and accumulate validation metrics."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
        self.val_metrics.update(y_hat, y)

    def on_validation_epoch_end(self):
        # Metrics over the whole validation epoch, not a mean of per-batch scores.
        self.log_dict(self.val_metrics.compute(), logger=True)
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        """Accumulate test metrics; they are logged at the end of the epoch."""
        x, y = batch
        self.test_metrics.update(self(x), y)

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the predicted class index for each image in the batch."""
        x, y = batch
        y_hat = self(x)
        return torch.argmax(y_hat, dim=1)
# Build the model with one output per dog breed.
num_classes = len(class_names)
model = LitModel(num_classes=num_classes)

# Log metrics to a CSV file under the current directory.
logger = CSVLogger("./")

# Stop training once val_loss has not improved for 5 validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)

trainer = pl.Trainer(
    max_epochs=20,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
    deterministic=True,
)

# Train the model.
trainer.fit(model, train_loader, val_loader)

# Evaluate on the held-out test split. The validation split already decided
# early stopping, so scoring on it would give an optimistically biased estimate.
trainer.test(model, test_loader)

# ===== Prediction =====
# Predict a class index for every test image (one tensor per batch), then map
# the indices to breed names.
batch_preds = trainer.predict(model, test_loader)
pred_idx = torch.cat(batch_preds, dim=0).cpu().numpy()
pred = pd.DataFrame({"Class": [class_names[i] for i in pred_idx]})

# Plot how many test images were assigned to each class. order=class_names keeps
# the classes in class_names order and keeps classes that were never predicted
# on the axis.
plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred, order=class_names)
plt.tight_layout()
plt.show()

# ===== Loss and evaluation metrics =====
import os

# Read back the metrics written by the CSVLogger during training and testing.
log_path = os.path.join(logger.log_dir, "metrics.csv")
metrics = pd.read_csv(log_path)
x_name = "epoch"

# Loss curves. train_loss is logged per step, so seaborn aggregates several
# rows per epoch; val_loss has one row per epoch.
plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss",
             linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss",
             linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

# One subplot per evaluation metric: (logged-name suffix, axis label).
metric_panels = [
    ("acc", "Accuracy"),
    ("prec", "Precision"),
    ("recall", "Recall"),
    ("f1score", "F1-Score"),
]
plt.figure(figsize=(14, 12), dpi=100)
for panel_idx, (key, title) in enumerate(metric_panels, start=1):
    plt.subplot(2, 2, panel_idx)
    sns.lineplot(x=x_name, y=f"train_{key}", data=metrics, label=f"Train {title}",
                 linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y=f"val_{key}", data=metrics, label=f"Valid {title}",
                 linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel(title)
plt.tight_layout()
plt.show()
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/921708.shtml
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!