基于PyTorch的图像分类特征提取与模型训练文档

概述

本代码实现了一个基于PyTorch的图像特征提取与分类模型训练流程。核心功能包括:

  1. 使用预训练ResNet18模型进行图像特征提取

  2. 将提取的特征保存为标准化格式

  3. 基于提取的特征训练分类模型

代码结构详解 

1. 库导入

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
import numpy as np
import os
from ml.model_trainer import ModelTrainer
  • 关键库说明

    • torch:PyTorch核心库

    • torch.nn:神经网络模块

    • torchvision:计算机视觉专用模块

    • numpy:数值计算库

    • os:文件系统操作

    • ModelTrainer:自定义模型训练类(需另行实现)

2. 特征提取器类(FeatureExtractor)

初始化方法 __init__
def __init__(self):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model = torchvision.models.resnet18(weights='IMAGENET1K_V1')self.model = nn.Sequential(*list(self.model.children())[:-1])self.model = self.model.to(self.device).eval()self.transform = transforms.Compose([...])
  • 功能说明

    • 设备检测:自动选择GPU/CPU

    • 模型加载:使用ImageNet预训练的ResNet18

    • 模型修改:移除最后的全连接层(保留卷积特征提取器)

    • 预处理设置:标准化图像尺寸和颜色空间

特征提取方法 extract_features
def extract_features(self, data_dir):full_dataset = datasets.ImageFolder(...)loader = DataLoader(...)features = []labels = []with torch.no_grad():for inputs, targets in loader:inputs = inputs.to(self.device)outputs = self.model(inputs)features.append(outputs.squeeze().cpu().numpy())labels.append(targets.numpy())features = np.concatenate(...)labels = np.concatenate(...)return features, labels, full_dataset.classes
  • 关键参数

    • data_dir:包含分类子目录的图像数据集路径

    • batch_size=32:平衡内存使用与处理效率

    • num_workers=4:多线程数据加载

  • 处理流程

    1. 创建ImageFolder数据集

    2. 使用DataLoader批量加载

    3. 禁用梯度计算加速推理

    4. 特征维度压缩(squeeze)

    5. 设备间数据传输(GPU->CPU)

    6. 合并所有批次数据

3. 主执行流程

参数配置
DATA_DIR = "/home/.../data"  # 实际数据路径
SAVE_PATH = "./features.npz"  # 特征保存路径
特征提取与保存 
extractor = FeatureExtractor()
if not os.path.exists(SAVE_PATH):features, labels, classes = extractor.extract_features(DATA_DIR)np.savez(SAVE_PATH, features=features, labels=labels, classes=classes)
else:data = np.load(SAVE_PATH)features = data['features']labels = data['labels']
  • 文件结构

    • features: [N_samples, 512] 的特征矩阵

    • labels: [N_samples] 的标签数组

    • classes: 类别名称列表

模型训练与保存
X, y = features, labels
trainer = ModelTrainer()
model = trainer.train_model(X, y)
joblib.dump(model, 'pest_classifier.pkl')

 

  • 假设条件

    • ModelTrainer需实现训练逻辑(如SVM、随机森林等)

    • 默认使用全部数据进行训练(建议实际添加数据分割)

技术细节说明

1. 图像预处理流程

2. 特征维度分析

  • ResNet18最后层输出:512维特征向量

  • 假设1000张图像:

    • 原始图像:1000×3×224×224 (约150MB)

    • 提取特征:1000×512 (约2MB) → 显著降维

3. 性能优化策略

  • GPU加速:自动检测CUDA设备

  • 批量处理:32张/批平衡效率与内存

  • 缓存机制:避免重复特征提取

  • 梯度禁用:减少内存消耗

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/77716.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

写一个 Java 程序,用于将字符串中的指定子串替换为另一个子串

以下是一个 Java 程序,它可以将字符串中的指定子串替换为另一个子串。 public class SubstringReplacement {public static String replaceSubstring(String original, String oldSubstring, String newSubstring) {return original.replace(oldSubstring, newSubs…

Docker 容器双网卡访问物理雷达网络教程

作者: 陈梓洋 环境: ubuntu 22.04lts 时间: 2025年4月29日 Docker 容器双网卡访问物理雷达网络教程 这个教程适用于这样的场景:容器保留原有 ROS 通信网络(如 bridge 网络),同时需要访问一个物…

AWS创建多块盘并创建RAID0以及后增加空间

创建硬盘并挂载到EC2上,后查询如下 [rootip-127-0-0-1 ~]# lsblk NAME MAJ:MIN RM SIZE RO TYPE MOUNTPOINTS nvme0n1 259:0 0 40G 0 disk ├─nvme0n1p1 259:1 0 40G 0 part / ├─nvme0n1p127 259:2 0 1M 0 part └─nvme0n1p128 259:3 …

数据结构---单链表的增删查改

前言: 经过了几个月的漫长岁月,回头时年迈的小编发现,数据结构的内容还没有写博客,于是小编赶紧停下手头的活动,补上博客以洗清身上的罪孽 目录 前言 概念: 单链表的结构 我们设定一个哨兵位头节点给链…

XSS靶场实战(工作wuwuwu)

knoxss knoxss Single Reflection Using QUERY of URL ——01 测试标签 <script>alert(666666)</script>——02: " <h1>test</h1>没有反应&#xff0c;查看源码 现在需要闭合双引号&#xff0c;我计划还是先搞标签 "><h1>tes…

基于 BERT 微调一个意图识别(Intent Classification)模型

基于 BERT 微调一个意图识别&#xff08;Intent Classification&#xff09;模型&#xff0c;你的意图类别包括&#xff1a; 查询天气获取新闻咨询想听音乐想添加备忘查询备忘获取家政服务结束对话增加音量减小音量其他 具体实现步骤&#xff08;详细版&#xff09; 1. 准备你…

SSM书籍管理(环境搭建)

整合SSM&#xff1a;SpringSpringMVCMybatis 环境要求&#xff1a;IDEA、MySQL5、Tomcat9、Maven3 数据库搭建 数据库准备以下数据用于后续实验&#xff1a;创建一个ssmbuild数据库&#xff0c;表books&#xff0c;该表有4个字段&#xff0c;并且插入3条数据用于后续。 CRE…

API文档生成与测试工具推荐

在API开发过程中&#xff0c;文档的编写和维护是一项重要但繁琐的工作。为了提高效率&#xff0c;许多开发者会选择使用API文档自动生成工具或具备API文档生成功能的API门户产品。选择能导入API文档的工具生成测试脚本, 本文将全面梳理市面上符合OpenAPI 3.0规范的文档生成工具…

linux修改环境变量

添加环境变量注意事项。 vim ~/.bashrc 添加环境变量时&#xff0c;需要source ~/.bashrc后才能有效。同时只对当前shell窗口有效&#xff0c;当打开另外的shell窗口时&#xff0c;需要重新source才能起效。 1.修改bashrc文件后 2.source后打开另一个shell窗口则无效&#xff…

springboot项目中,MySQL数据库转达梦数据库

前言 前段时间&#xff0c;公司要求要把某几个项目的数据库换成达梦数据库&#xff0c;说是为了国产化。我就挺无语的&#xff0c;三四年的项目了&#xff0c;现在说要换数据库。我一开始以为这个达梦数据库应该是和TIDB差不多的。 我之前做的好几个项目部署到测试服、正式服…

【Quest开发】透视环境下抠出身体并能遮挡身体上的服装

软件&#xff1a;Unity 2022.3.51f1c1、vscode、Meta XR All in One SDK V72 硬件&#xff1a;Meta Quest3 仅针对urp管线 博主搞这个主要是想做现实里的人的变身功能&#xff0c;最后效果如下 可以看到虽然身体是半透明的&#xff0c;但是裙子依旧被完全遮挡了 原理是参考…

前端安全中的XSS(跨站脚本攻击)

XSS 类型 存储型 XSS 特征&#xff1a;恶意脚本存储在服务器&#xff08;如数据库&#xff09;&#xff0c;用户访问受感染页面时触发。场景&#xff1a;用户评论、论坛帖子等持久化内容。影响范围&#xff1a;所有访问该页面的用户。 反射型 XSS 特征&#xff1a;恶意脚本通过…

(第三篇)Springcloud之Ribbon负载均衡

一、简介 1、介绍 Spring Cloud Ribbon是Netflix发布的开源项目&#xff0c;是基于Netflix Ribbon实现的一套客户端负载均衡的工具。主要功能是提供客户端的软件负载均衡算法&#xff0c;将Netflix的中间层服务连接在一起。Ribbon客户端组件提供一系列完善的配置项如连接超时&…

大模型——使用coze搭建基于DeepSeek大模型的智能体实现智能客服问答

大模型——使用coze搭建基于DeepSeek大模型的智能体实现智能客服问答 本章实验完全依托于coze在线平台,不需要本地部署任何应用。 实验介绍 1.coze介绍 扣子(coze)是新一代 AI 应用开发平台。无论你是否有编程基础,都可以在扣子上快速搭建基于大模型的各类 AI 应用,并…

【计算机视觉】目标检测:深度解析YOLOv9:下一代实时目标检测架构的创新与实战

深度解析YOLOv9&#xff1a;下一代实时目标检测架构的创新与实战 架构演进与技术创新YOLOv9的设计哲学核心创新解析1. 可编程梯度信息&#xff08;PGI&#xff09;2. 广义高效层聚合网络&#xff08;GELAN&#xff09;3. 轻量级设计 环境配置与快速开始硬件需求建议详细安装步骤…

【SpringBoot】基于MybatisPlus的博客管理系统(1)

1.准备工作 1.1数据库 -- 建表SQL create database if not exists java_blog_spring charset utf8mb4;use java_blog_spring; -- 用户表 DROP TABLE IF EXISTS java_blog_spring.user_info; CREATE TABLE java_blog_spring.user_info(id INT NOT NULL AUTO_INCREMENT,user_na…

贵族运动项目有哪些·棒球1号位

10个具有代表性的贵族运动&#xff1a; 高尔夫 马术 网球 帆船 击剑 斯诺克 冰球 私人飞机驾驶 深海潜水 马球 贵族运动通常指具有较高参与成本、历史底蕴或社交属性的运动&#xff0c;而棒球作为一项大众化团队运动&#xff0c;与典型贵族运动的结合较为罕见。从以下几个角度探…

【Tauri2】035——sql和sqlx

前言 这篇就来看看插件sql SQL | Taurihttps://tauri.app/plugin/sql/ 正文 准备 添加依赖 tauri-plugin-sql {version "2.2.0",features ["sqlite"]} features可以是mysql、sqlite、postsql 进去features看看 sqlite ["sqlx/sqlite&quo…

全链路自动化AIGC内容工厂:构建企业级智能内容生产系统

一、工业化AIGC系统架构 1.1 生产流程设计 [需求输入] → [创意生成] → [多模态生产] → [质量审核] → [多平台分发] ↑ ↓ ↑ [用户反馈] ← [效果分析] ← [数据埋点] ← [内容投放] 1.2 技术指标要求 指标 标准值 实现方案 单日产能 1,000,000 分布式推理集群 内容合规率…

是否想要一个桌面哆啦A梦的宠物

是否想拥有一个在指定时间喊你的桌面宠物呢&#xff08;手动狗头&#xff09; 如果你有更好的想法&#xff0c;欢迎提出你的想法。 是否考虑过跟开发者一对一&#xff0c;提出你的建议&#xff08;狗头&#xff09;。 https://wwxc.lanzouo.com/idKnJ2uvq11c 密码:bbkm