百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Imageimport matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224],                     # 输入图片的shape"class_dim": 12,                                 # 分类数"src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径"target_path":"/home/aistudio/data/dataset",     # 要解压的路径 "train_list_path": "./train.txt",                # train_data.txt路径"eval_list_path": "./eval.txt",                  # eval_data.txt路径"label_dict":{},                                 # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6,                                 # 训练轮数"train_batch_size": 16,                          # 批次的大小"learning_strategy": {                           # 优化函数相关的配置"lr": 0.0005                                  # 超参数学习率} 
}scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)

 解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

 这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练 

if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')

 使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))

设置eval模式,使用批处理测试准确率 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第二证券:市场情绪或逐步修复 十月行情值得期待

第二证券指出,周一A股商场探底回升、小幅轰动收拾,沪指全天底子出现先抑后扬的运转特征。其时上证综指与创业板指数的平均市盈率分别为12.46倍、33.94倍,处于近三年中位数以下水平,商场估值仍然处于较低区域,合适中长期…

ubuntu22.04设置中文

安装了中文语言包。 sudo apt-get install language-pack-zh-hans将系统的默认语言设置为中文 sudo update-locale LANGzh_CN.UTF-8添加环境 /etc/profile 最后中添加 export LANGzh_CN.utf8 export LC_CTYPE"zh_CN.utf8"可以在~/.bashrc文件后面也加上

华为OD机考B卷 | 100分】阿里巴巴找黄金宝箱(JAVA题解——也许是全网最详)

前言 本人是算法小白,甚至也没有做过Leetcode。所以,我相信【同为菜鸡的我更能理解作为菜鸡的你们的痛点】。 题干 1. 题目描述 一贫如洗的樵夫阿里巴巴在去砍柴的路上,无意中发现了强盗集团的藏宝地,藏宝地有编号从0~N的箱子&…

【C++设计模式之责任链模式:行为型】分析及示例

简介 责任链模式是一种行为型设计模式,它允许将请求沿着处理链传递,直到有一个处理器能够处理该请求。这种模式将请求的发送者和接收者解耦,同时提供了更高的灵活性和可扩展性。 描述 责任链模式由多个处理器组成一个处理链,每…

香橙派、树莓派、核桃派、鲁班猫安装jupyter notebook【ubuntu、Debian开发板操作类似】

文章目录 前言一、安装环境二、使用方法总结 前言 香橙派树莓派鲁班猫安装一下调试代码还是比较方便的。 一、安装环境 假设已经安装好了miniconda3。如果还没安装可以参考我另外一篇博文,有写怎么安装。 pip install jupyter notebook # 生成Jupyter Notebook的…

算法题:K 次取反后最大化的数组和(典型的贪心算法问题)

这道题没有看题解,直接提交,成绩超越99.5%,说明思路是优的。就是考虑的情况里面弯弯绕比较多,需要考虑全面一点。(本题完整题目附在了最后面) 具体思路如下: 1、首先排序,然后从最…

智能合约漏洞,价值 5200 万美元的 Vyper 漏洞攻击原理分析

智能合约漏洞,价值 5200 万美元的 Vyper 漏洞攻击原理分析 7 月 30 日,因为 Vyper 部分版本中的漏洞,导致 Curve、JPEG’d 等项目陆续受到攻击,损失总计超过 5200 万美元。 Safful 对此事件第一时间进行了技术分析,并…

2023 IDC中国数字金融论坛丨中电金信向行业分享“源启+应用重构”新范式

9月8日,IDC主办的“2023 IDC中国数字金融论坛”在北京召开。中电金信受邀参会,并带来了深度数字化转型趋势之下关于应用重构的分享与洞见。 论坛重点关注金融科技创新发展趋势与数字化转型之路,中电金信副总经理、研究院院长况文川带来了“创…

nSoftware IPWorks IoT 2022 Java 22.0.8 Crack

物联网库,使用这个轻量级组件库,可以在任何平台上的应用程序中轻松实现物联网 (IoT) 通信协议。 nSoftware IPWorks IoT 最新的 IPWorks IoT 现已推出!最新版本的 IPWorks IoT 具有现代化和简化的体验,包括 .NET 中的异步和跨平台…

[开源]MIT协议,开源论坛程序,拥有友好的用户界面和操作体验

一、开源项目简介 尤得一物是一个开源论坛程序,提供丰富的功能,可以作为管理或分享文章的论坛博客,也可以在此基础上进行自定义开发。 二、开源协议 使用MIT开源协议 三、界面展示 四、功能概述 尤得一物是一个开源论坛程序,…

vue-7-vuex

一、Vuex 概述 目标:明确Vuex是什么,应用场景以及优势 1.是什么 Vuex 是一个 Vue 的 状态管理工具,状态就是数据。 大白话:Vuex 是一个插件,可以帮我们管理 Vue 通用的数据 (多组件共享的数据)。例如:购…

Spring Boot与Kubernetes结合:构建高可靠、高性能的微服务架构

Spring Boot和Kubernetes(K8s)是当今非常热门的技术,它们的结合可以帮助开发者更高效地构建、部署和管理应用程序。本文将详细介绍Spring Boot和Kubernetes的主要特点,以及它们结合使用的优势。 一、Spring Boot的特点 Spring B…

arcgis地形分析全流程

主要内容:DEM的获取与处理、高程分析、坡度分析、坡向分析、地形起伏度分析、地表粗糙度分析、地表曲率分析; 主要工具:镶嵌至新栅格、按掩膜提取、投影栅格、坡度、坡向、焦点统计 一 DEM的获取与处理 1.1 DEM是什么? DEM(D…

安全与隐私:直播购物App开发中的重要考虑因素

随着直播购物App的崭露头角,开发者需要特别关注安全性和隐私问题。本文将介绍在直播购物App开发中的一些重要安全和隐私考虑因素,并提供相关的代码示例。 1. 数据加密 在直播购物App中,用户的个人信息和支付信息是极为敏感的数据。为了保护…

Linux文件与目录的增删改查

一、增 1、mkdir命令 作用: 创建一个新目录。格式: mkdir [选项] 要创建的目录 常用参数: -p:创建目录结构中指定的每一个目录,如果目录不存在则创建,如果目录已存在也不会被覆盖。用法示例: 1、mkdir directory:创建单个目录 这个命令会在当前目录下创建一个名为…

国外网站国内镜像

国外网站国内镜像 1. huggingface 1. huggingface huggingface—>互链高科

简单好用的CHM文件阅读器 CHM Viewer Star最新 for mac

CHM Viewer Star 是一款适用于 Mac 平台的 CHM 文件阅读器软件,支持本地和远程 CHM 文件的打开和查看。它提供了直观易用的界面设计,支持多种浏览模式,如书籍模式、缩略图模式和文本模式等,并提供了丰富的功能和工具,如…

亚马逊流量攻略:如何将流量转化为销售,测评实现销售飙升!

在电商领域,流量获取一直是一个核心议题。对于任何希望增加订单量的商家而言,将流量引导至自身店铺并成功转化为销售至关重要。对于初入电商领域或规模较小的卖家来说,亚马逊内部的流量获取通常可带来显著的销售业绩。那么,如何利…

python—如何提取word中指定内容

假设有一个Word,该Word中存在 “联系人” 关键字,如何将该Word中的联系人所对应的内容提取出来呢? 该Word内容如下所示: 要在给定的Word文档中提取出与"联系人"关键字对应的内容,可以使用Python的py…

【分享】xpath的属性表达式

在XPath中,要选择HTML文档中具有特定类的元素,您通常需要使用属性选择器 [attribute-nameattribute-value] 来选择元素,其中 attribute-name 是属性名称,attribute-value 是要匹配的属性值。对于HTML元素的类选择器,您…