摘要
植物病害对全球粮食安全构成严重威胁,而传统的人工检测方法效率低下且容易出错。本文介绍一个基于YOLO系列深度学习模型的植物病害智能检测系统:系统通过统一的检测器接口支持YOLOv5、YOLOv6、YOLOv7和YOLOv8四种算法,并给出核心模块的Python实现、PySide6图形界面和训练代码(文中给出YOLOv5与YOLOv8检测器的实现,YOLOv6/YOLOv7检测器需按相同接口自行补充)。系统能够实时检测多种植物病害,为农业病虫害防治提供智能化解决方案。
1. 引言
1.1 研究背景
全球范围内,植物病虫害每年造成约20%–40%的农作物产量损失,给农业生产带来巨大的经济损失。传统病害检测依赖农业专家人工观察,不仅效率低下,而且对专业知识和经验要求很高。随着计算机视觉技术的发展,基于深度学习的自动检测方法为植物病害识别提供了新的解决方案。
1.2 YOLO算法优势
YOLO(You Only Look Once)系列算法是当前目标检测领域的主流框架,具有以下优势:
实时性:单阶段检测架构,推理速度快
高精度:不断优化的网络结构确保检测精度
端到端训练:简化训练流程
多尺度检测:适应不同大小的病害特征
2. 系统架构设计
2.1 整体架构
本系统采用模块化设计,主要包括以下组件:
text
PlantDiseaseDetectionSystem/
│
├── core/                   # 核心检测模块
│   ├── detectors/          # YOLO系列检测器
│   ├── utils/              # 工具函数
│   └── configs/            # 配置文件
│
├── ui/                     # 图形界面
│   ├── main_window.py      # 主窗口
│   └── widgets/            # 界面组件
│
├── models/                 # 模型文件
│   ├── yolov5/             # YOLOv5官方仓库代码及权重
│   ├── yolov6/             # YOLOv6模型
│   ├── yolov7/             # YOLOv7模型
│   └── yolov8/             # YOLOv8模型
│
├── data/                   # 数据集
│   ├── preprocessing.py    # 数据预处理
│   ├── images/             # 图像数据
│   └── annotations/        # 标注文件
│
├── train/                  # 训练模块
│   ├── train.py            # 训练脚本
│   ├── yolov5_train.py     # YOLOv5训练脚本
│   ├── optimization.py     # 优化器与学习率策略
│   └── augmentation.py     # 数据增强
│
├── inference/              # 推理模块
│   ├── detect.py           # 检测脚本
│   └── evaluate.py         # 评估脚本
│
├── main.py                 # 程序入口
└── requirements.txt        # 依赖库
2.2 技术栈
深度学习框架:PyTorch 1.8+
图形界面:PySide6
图像处理:OpenCV, Pillow
科学计算:NumPy, SciPy
可视化:Matplotlib, Seaborn
3. 数据集准备与预处理
3.1 参考数据集
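我们使用以下数据集进行模型训练(前三个为公开数据集):
PlantVillage数据集:包含38个类别(14种作物的26种病害及健康叶片),约54,000张图像。注意:PlantVillage原本是分类数据集,每张图像仅含单片叶子且没有边界框标注,用于YOLO目标检测训练前需要先补充标注
PlantDoc数据集:包含13种植物的27个类别(含健康叶片),约2,600张真实场景图像,并提供目标检测标注
AI Challenger病虫害数据集:包含10种作物、27种病害,按病害严重程度细分为61个类别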
Custom Dataset:自行采集的本地病害图像
3.2 数据预处理流程
python
# data/preprocessing.py
import cv2
import numpy as np
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ImageNet channel statistics, shared by the training and validation pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class PlantDiseasePreprocessor:
    """Load images / YOLO labels and hold albumentations train/val pipelines.

    Args:
        img_size: side length (pixels) of the square images produced by both
            ``train_transform`` and ``val_transform``.
    """

    def __init__(self, img_size=640):
        self.img_size = img_size

        # Both pipelines carry YOLO-format boxes (normalized cx, cy, w, h); the
        # class ids must be passed as the ``class_labels`` keyword when called.
        bbox_kwargs = dict(format='yolo', label_fields=['class_labels'])

        self.train_transform = A.Compose([
            A.RandomResizedCrop(height=img_size, width=img_size, scale=(0.8, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.RandomBrightnessContrast(p=0.2),
            A.HueSaturationValue(p=0.2),
            A.Blur(blur_limit=3, p=0.1),
            A.CLAHE(p=0.1),
            A.RandomGamma(p=0.1),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(**bbox_kwargs))

        self.val_transform = A.Compose([
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(**bbox_kwargs))

    def load_image(self, image_path):
        """Read an image from disk and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
                reports a missing/unreadable file by returning ``None`` instead
                of raising, which would otherwise surface later as an opaque
                ``cv2.cvtColor`` error.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            raise FileNotFoundError(f"无法读取图像: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def load_annotations(self, anno_path):
        """Parse a YOLO-format label file.

        Each kept line is ``class_id x_center y_center width height``. Lines
        without exactly five whitespace-separated fields (blank lines,
        polygon labels, ...) are skipped.

        Returns:
            ``(bboxes, class_labels)``: index-aligned lists of
            ``[x_center, y_center, width, height]`` floats and ``int`` ids.
        """
        bboxes = []
        class_labels = []
        with open(anno_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                class_id = int(parts[0])
                box = [float(value) for value in parts[1:]]
                bboxes.append(box)
                class_labels.append(class_id)
        return bboxes, class_labels
4. YOLO模型实现
4.1 YOLOv5检测器
python
# core/detectors/yolov5_detector.py import torch import torch.nn as nn import numpy as np from pathlib import Path class YOLOv5Detector: def __init__(self, model_path, device='cuda' if torch.cuda.is_available() else 'cpu'): self.device = device self.model = self.load_model(model_path) self.img_size = 640 self.conf_thres = 0.25 self.iou_thres = 0.45 self.classes = None def load_model(self, model_path): """加载YOLOv5模型""" try: # 使用官方YOLOv5实现 import sys yolov5_path = Path(__file__).parent.parent.parent / 'models' / 'yolov5' sys.path.append(str(yolov5_path)) from models.experimental import attempt_load model = attempt_load(model_path, map_location=self.device) model.eval() return model except Exception as e: print(f"加载YOLOv5模型失败: {e}") return None def preprocess(self, image): """预处理图像""" from utils.augmentations import letterbox # 调整图像大小并填充 img = letterbox(image, new_shape=self.img_size)[0] # 转换格式 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB img = np.ascontiguousarray(img) # 转换为tensor img = torch.from_numpy(img).to(self.device) img = img.float() / 255.0 if img.ndimension() == 3: img = img.unsqueeze(0) return img def detect(self, image): """执行检测""" if self.model is None: return [] # 预处理 img = self.preprocess(image) # 推理 with torch.no_grad(): pred = self.model(img)[0] # 后处理 from utils.general import non_max_suppression pred = non_max_suppression(pred, self.conf_thres, self.iou_thres) # 解析结果 results = [] for det in pred: if det is not None and len(det): for *xyxy, conf, cls in det: result = { 'bbox': [float(x) for x in xyxy], 'confidence': float(conf), 'class_id': int(cls), 'class_name': self.get_class_name(int(cls)) } results.append(result) return results def get_class_name(self, class_id): """获取类别名称""" if self.classes is None: # 默认类别(可根据实际数据集修改) self.classes = [ 'Apple_scab', 'Apple_black_rot', 'Cedar_apple_rust', 'Cherry_powdery_mildew', 'Corn_cercospora_leaf_spot', 'Corn_common_rust', 'Grape_black_rot', 'Grape_esca', 'Grape_leaf_blight', 'Potato_early_blight', 
'Potato_late_blight', 'Tomato_bacterial_spot', 'Tomato_early_blight', 'Tomato_late_blight' ] if class_id < len(self.classes): return self.classes[class_id] return f'Class_{class_id}'4.2 YOLOv8检测器
python
# core/detectors/yolov8_detector.py from ultralytics import YOLO import cv2 import numpy as np class YOLOv8Detector: def __init__(self, model_path, device='cuda'): self.device = device self.model = YOLO(model_path) self.model.to(device) # 设置检测参数 self.conf_thres = 0.25 self.iou_thres = 0.7 # 类别名称映射 self.class_names = { 0: 'Apple_scab', 1: 'Apple_black_rot', 2: 'Cedar_apple_rust', 3: 'Cherry_powdery_mildew', 4: 'Corn_cercospora_leaf_spot', 5: 'Corn_common_rust', 6: 'Grape_black_rot', 7: 'Grape_esca', 8: 'Grape_leaf_blight', 9: 'Potato_early_blight', 10: 'Potato_late_blight', 11: 'Tomato_bacterial_spot', 12: 'Tomato_early_blight', 13: 'Tomato_late_blight' } def detect(self, image): """执行检测""" # 运行推理 results = self.model.predict( source=image, conf=self.conf_thres, iou=self.iou_thres, device=self.device, verbose=False ) # 解析结果 detections = [] for result in results: boxes = result.boxes if boxes is not None: for box in boxes: # 获取坐标和类别 xyxy = box.xyxy[0].cpu().numpy() conf = box.conf[0].cpu().numpy() cls = int(box.cls[0].cpu().numpy()) # 转换坐标格式 x1, y1, x2, y2 = map(float, xyxy) detection = { 'bbox': [x1, y1, x2, y2], 'confidence': float(conf), 'class_id': cls, 'class_name': self.class_names.get(cls, f'Class_{cls}') } detections.append(detection) return detections def batch_detect(self, images): """批量检测""" batch_results = [] for image in images: results = self.detect(image) batch_results.append(results) return batch_results4.3 统一检测接口
python
# core/detectors/detector_factory.py from enum import Enum class DetectorType(Enum): YOLOv5 = "yolov5" YOLOv6 = "yolov6" YOLOv7 = "yolov7" YOLOv8 = "yolov8" class DetectorFactory: @staticmethod def create_detector(detector_type, model_path, **kwargs): """创建检测器实例""" if detector_type == DetectorType.YOLOv5: from .yolov5_detector import YOLOv5Detector return YOLOv5Detector(model_path, **kwargs) elif detector_type == DetectorType.YOLOv6: from .yolov6_detector import YOLOv6Detector return YOLOv6Detector(model_path, **kwargs) elif detector_type == DetectorType.YOLOv7: from .yolov7_detector import YOLOv7Detector return YOLOv7Detector(model_path, **kwargs) elif detector_type == DetectorType.YOLOv8: from .yolov8_detector import YOLOv8Detector return YOLOv8Detector(model_path, **kwargs) else: raise ValueError(f"不支持的检测器类型: {detector_type}")5. 图形界面实现
5.1 主窗口设计
python
# ui/main_window.py import sys import cv2 import numpy as np from pathlib import Path from PySide6.QtWidgets import ( QMainWindow, QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog, QComboBox, QSlider, QGroupBox, QSpinBox, QDoubleSpinBox, QMessageBox, QListWidget, QTabWidget, QTextEdit, QProgressBar ) from PySide6.QtCore import Qt, QTimer, Signal, Slot, QThread from PySide6.QtGui import QImage, QPixmap, QFont from core.detectors.detector_factory import DetectorFactory, DetectorType class DetectionThread(QThread): """检测线程""" detection_finished = Signal(list) def __init__(self, detector, image): super().__init__() self.detector = detector self.image = image def run(self): results = self.detector.detect(self.image) self.detection_finished.emit(results) class MainWindow(QMainWindow): def __init__(self): super().__init__() self.setWindowTitle("植物病害智能检测系统 v1.0") self.setGeometry(100, 100, 1400, 900) self.current_image = None self.detector = None self.detection_thread = None self.init_ui() self.init_shortcuts() def init_ui(self): """初始化用户界面""" # 中央部件 central_widget = QWidget() self.setCentralWidget(central_widget) # 主布局 main_layout = QHBoxLayout(central_widget) # 左侧面板 - 图像显示 left_panel = QWidget() left_layout = QVBoxLayout(left_panel) # 图像显示区域 self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignCenter) self.image_label.setMinimumSize(800, 600) self.image_label.setStyleSheet("border: 2px solid #cccccc; background-color: #f0f0f0;") left_layout.addWidget(self.image_label) # 状态信息 self.status_label = QLabel("就绪") self.status_label.setStyleSheet("padding: 5px; background-color: #e0e0e0;") left_layout.addWidget(self.status_label) main_layout.addWidget(left_panel, stretch=3) # 右侧面板 - 控制面板 right_panel = QWidget() right_layout = QVBoxLayout(right_panel) # 模型选择 model_group = QGroupBox("模型设置") model_layout = QVBoxLayout() # 模型类型选择 model_type_layout = QHBoxLayout() model_type_layout.addWidget(QLabel("模型类型:")) self.model_combo = QComboBox() 
self.model_combo.addItems(["YOLOv5", "YOLOv6", "YOLOv7", "YOLOv8"]) model_type_layout.addWidget(self.model_combo) model_layout.addLayout(model_type_layout) # 模型文件选择 model_file_layout = QHBoxLayout() self.model_path_edit = QLabel("未选择模型文件") self.model_path_edit.setStyleSheet("border: 1px solid #cccccc; padding: 3px;") model_file_layout.addWidget(self.model_path_edit) self.browse_model_btn = QPushButton("浏览...") self.browse_model_btn.clicked.connect(self.browse_model_file) model_file_layout.addWidget(self.browse_model_btn) model_layout.addLayout(model_file_layout) # 加载模型按钮 self.load_model_btn = QPushButton("加载模型") self.load_model_btn.clicked.connect(self.load_model) self.load_model_btn.setStyleSheet(""" QPushButton { background-color: #4CAF50; color: white; padding: 8px; border-radius: 4px; } QPushButton:hover { background-color: #45a049; } """) model_layout.addWidget(self.load_model_btn) model_group.setLayout(model_layout) right_layout.addWidget(model_group) # 检测参数 param_group = QGroupBox("检测参数") param_layout = QVBoxLayout() # 置信度阈值 conf_layout = QHBoxLayout() conf_layout.addWidget(QLabel("置信度阈值:")) self.conf_spinbox = QDoubleSpinBox() self.conf_spinbox.setRange(0.0, 1.0) self.conf_spinbox.setValue(0.25) self.conf_spinbox.setSingleStep(0.05) conf_layout.addWidget(self.conf_spinbox) param_layout.addLayout(conf_layout) # IOU阈值 iou_layout = QHBoxLayout() iou_layout.addWidget(QLabel("IOU阈值:")) self.iou_spinbox = QDoubleSpinBox() self.iou_spinbox.setRange(0.0, 1.0) self.iou_spinbox.setValue(0.45) self.iou_spinbox.setSingleStep(0.05) iou_layout.addWidget(self.iou_spinbox) param_layout.addLayout(iou_layout) param_group.setLayout(param_layout) right_layout.addWidget(param_group) # 图像操作 image_group = QGroupBox("图像操作") image_layout = QVBoxLayout() # 打开图像按钮 self.open_image_btn = QPushButton("打开图像") self.open_image_btn.clicked.connect(self.open_image) self.open_image_btn.setStyleSheet(""" QPushButton { background-color: #2196F3; color: white; padding: 8px; border-radius: 4px; } 
QPushButton:hover { background-color: #0b7dda; } """) image_layout.addWidget(self.open_image_btn) # 打开文件夹按钮 self.open_folder_btn = QPushButton("打开文件夹") self.open_folder_btn.clicked.connect(self.open_folder) image_layout.addWidget(self.open_folder_btn) # 摄像头检测按钮 self.camera_btn = QPushButton("摄像头检测") self.camera_btn.clicked.connect(self.toggle_camera) image_layout.addWidget(self.camera_btn) image_group.setLayout(image_layout) right_layout.addWidget(image_group) # 检测操作 detect_group = QGroupBox("检测操作") detect_layout = QVBoxLayout() # 开始检测按钮 self.detect_btn = QPushButton("开始检测") self.detect_btn.clicked.connect(self.start_detection) self.detect_btn.setEnabled(False) self.detect_btn.setStyleSheet(""" QPushButton { background-color: #FF9800; color: white; padding: 10px; border-radius: 4px; font-weight: bold; } QPushButton:hover { background-color: #e68900; } QPushButton:disabled { background-color: #cccccc; } """) detect_layout.addWidget(self.detect_btn) # 批量检测按钮 self.batch_detect_btn = QPushButton("批量检测") self.batch_detect_btn.clicked.connect(self.batch_detection) self.batch_detect_btn.setEnabled(False) detect_layout.addWidget(self.batch_detect_btn) detect_group.setLayout(detect_layout) right_layout.addWidget(detect_group) # 检测结果 result_group = QGroupBox("检测结果") result_layout = QVBoxLayout() self.result_list = QListWidget() self.result_list.setMaximumHeight(200) result_layout.addWidget(self.result_list) # 结果统计 self.result_stats = QLabel("检测到: 0 个目标") result_layout.addWidget(self.result_stats) result_group.setLayout(result_layout) right_layout.addWidget(result_group) # 进度条 self.progress_bar = QProgressBar() self.progress_bar.setVisible(False) right_layout.addWidget(self.progress_bar) right_layout.addStretch() main_layout.addWidget(right_panel, stretch=1) def init_shortcuts(self): """初始化快捷键""" pass @Slot() def browse_model_file(self): """浏览模型文件""" file_path, _ = QFileDialog.getOpenFileName( self, "选择模型文件", "", "模型文件 (*.pt *.pth);;所有文件 (*.*)" ) if file_path: 
self.model_path_edit.setText(file_path) @Slot() def load_model(self): """加载模型""" model_path = self.model_path_edit.text() if not model_path or model_path == "未选择模型文件": QMessageBox.warning(self, "警告", "请先选择模型文件") return try: model_type = DetectorType(self.model_combo.currentText().lower()) self.detector = DetectorFactory.create_detector( model_type, model_path, device='cuda' if torch.cuda.is_available() else 'cpu' ) # 更新检测参数 if hasattr(self.detector, 'conf_thres'): self.detector.conf_thres = self.conf_spinbox.value() if hasattr(self.detector, 'iou_thres'): self.detector.iou_thres = self.iou_spinbox.value() QMessageBox.information(self, "成功", f"{self.model_combo.currentText()}模型加载成功!") self.detect_btn.setEnabled(True) self.batch_detect_btn.setEnabled(True) self.status_label.setText("模型加载成功") except Exception as e: QMessageBox.critical(self, "错误", f"模型加载失败: {str(e)}") self.status_label.setText("模型加载失败") @Slot() def open_image(self): """打开图像文件""" file_path, _ = QFileDialog.getOpenFileName( self, "选择图像", "", "图像文件 (*.jpg *.jpeg *.png *.bmp);;所有文件 (*.*)" ) if file_path: self.load_image(file_path) def load_image(self, image_path): """加载图像到界面""" try: # 读取图像 image = cv2.imread(image_path) if image is None: raise ValueError("无法读取图像文件") self.current_image = image.copy() # 转换为RGB显示 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 调整大小以适应显示区域 h, w, ch = image_rgb.shape bytes_per_line = ch * w # 创建QImage并显示 qimage = QImage(image_rgb.data, w, h, bytes_per_line, QImage.Format_RGB888) pixmap = QPixmap.fromImage(qimage) # 缩放以适应标签 scaled_pixmap = pixmap.scaled( self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation ) self.image_label.setPixmap(scaled_pixmap) self.image_label.setAlignment(Qt.AlignCenter) self.status_label.setText(f"已加载图像: {Path(image_path).name}") self.result_list.clear() self.result_stats.setText("检测到: 0 个目标") except Exception as e: QMessageBox.critical(self, "错误", f"加载图像失败: {str(e)}") @Slot() def start_detection(self): """开始检测""" if 
self.current_image is None: QMessageBox.warning(self, "警告", "请先加载图像") return if self.detector is None: QMessageBox.warning(self, "警告", "请先加载模型") return # 禁用按钮 self.detect_btn.setEnabled(False) self.progress_bar.setVisible(True) self.progress_bar.setRange(0, 0) # 不确定进度 # 创建检测线程 self.detection_thread = DetectionThread(self.detector, self.current_image) self.detection_thread.detection_finished.connect(self.on_detection_finished) self.detection_thread.start() self.status_label.setText("检测中...") @Slot(list) def on_detection_finished(self, results): """检测完成回调""" # 启用按钮 self.detect_btn.setEnabled(True) self.progress_bar.setVisible(False) # 显示结果 self.display_results(results) self.status_label.setText("检测完成") def display_results(self, results): """显示检测结果""" # 清空结果列表 self.result_list.clear() # 绘制检测框 image_with_boxes = self.current_image.copy() # 颜色映射 colors = [ (0, 255, 0), # 绿色 (255, 0, 0), # 蓝色 (0, 0, 255), # 红色 (255, 255, 0), # 青色 (255, 0, 255), # 紫色 (0, 255, 255), # 黄色 ] for i, result in enumerate(results): bbox = result['bbox'] confidence = result['confidence'] class_name = result['class_name'] # 绘制边界框 x1, y1, x2, y2 = map(int, bbox) color = colors[i % len(colors)] cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2) # 添加标签 label = f"{class_name}: {confidence:.2f}" label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] cv2.rectangle( image_with_boxes, (x1, y1 - label_size[1] - 10), (x1 + label_size[0], y1), color, -1 ) cv2.putText( image_with_boxes, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1 ) # 添加到结果列表 item_text = f"{class_name}: 置信度 {confidence:.3f}, 位置 [{x1}, {y1}, {x2}, {y2}]" self.result_list.addItem(item_text) # 更新图像显示 image_rgb = cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB) h, w, ch = image_rgb.shape bytes_per_line = ch * w qimage = QImage(image_rgb.data, w, h, bytes_per_line, QImage.Format_RGB888) pixmap = QPixmap.fromImage(qimage) scaled_pixmap = pixmap.scaled( self.image_label.size(), 
Qt.KeepAspectRatio, Qt.SmoothTransformation ) self.image_label.setPixmap(scaled_pixmap) # 更新统计信息 self.result_stats.setText(f"检测到: {len(results)} 个目标")5.2 应用程序入口
python
# main.py import sys import os from pathlib import Path # 添加项目路径 project_root = Path(__file__).parent sys.path.append(str(project_root)) from PySide6.QtWidgets import QApplication from ui.main_window import MainWindow def main(): """主函数""" # 创建应用 app = QApplication(sys.argv) app.setApplicationName("植物病害检测系统") # 设置样式 app.setStyle('Fusion') # 创建主窗口 window = MainWindow() window.show() # 运行应用 sys.exit(app.exec()) if __name__ == "__main__": main()6. 模型训练
6.1 YOLOv5训练脚本
python
# train/yolov5_train.py import torch import yaml import argparse from pathlib import Path def train_yolov5(args): """训练YOLOv5模型""" import sys yolov5_path = Path(__file__).parent.parent / 'models' / 'yolov5' sys.path.append(str(yolov5_path)) from models.common import DetectMultiBackend from utils.dataloaders import create_dataloader from utils.general import check_dataset, colorstr from utils.torch_utils import select_device # 检查数据集 data = check_dataset(args.data) # 选择设备 device = select_device(args.device, batch_size=args.batch_size) # 创建数据加载器 train_loader = create_dataloader( data['train'], args.img_size, args.batch_size, args.workers, hyp=args.hyp, augment=True, rect=args.rect, cache=args.cache )[0] val_loader = create_dataloader( data['val'], args.img_size, args.batch_size, args.workers * 2, hyp=args.hyp, rect=True, cache=args.cache )[0] # 加载模型 model = DetectMultiBackend(args.weights, device=device) # 训练配置 cfg = { 'epochs': args.epochs, 'batch_size': args.batch_size, 'img_size': args.img_size, 'device': device, 'workers': args.workers, 'data': data, 'hyp': args.hyp, 'rect': args.rect, 'cache': args.cache, 'save_dir': args.save_dir, 'name': args.name } # 开始训练 print(f"{colorstr('Training:')} {args.weights}") # 这里需要调用YOLOv5的训练函数 # 实际实现会调用train.py中的train函数 return model def main(): parser = argparse.ArgumentParser(description='训练YOLOv5模型') parser.add_argument('--weights', type=str, default='yolov5s.pt', help='初始权重路径') parser.add_argument('--data', type=str, required=True, help='数据集配置文件') parser.add_argument('--epochs', type=int, default=100, help='训练轮数') parser.add_argument('--batch-size', type=int, default=16, help='批次大小') parser.add_argument('--img-size', type=int, default=640, help='图像尺寸') parser.add_argument('--device', default='', help='cuda设备, i.e. 
0 or 0,1,2,3 or cpu') parser.add_argument('--workers', type=int, default=8, help='数据加载线程数') parser.add_argument('--save-dir', type=str, default='runs/train', help='保存目录') parser.add_argument('--name', type=str, default='exp', help='实验名称') parser.add_argument('--rect', action='store_true', help='矩形训练') parser.add_argument('--cache', type=str, default=None, help='缓存类型') args = parser.parse_args() # 训练模型 model = train_yolov5(args) print("训练完成!") if __name__ == '__main__': main()6.2 数据增强策略
python
# train/augmentation.py
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_train_transform(img_size=640):
    """Build the training augmentation pipeline for YOLO-format boxes.

    Args:
        img_size: side length of the square crop produced by
            ``RandomResizedCrop``; also sets the maximum dropout patch size.

    Returns:
        ``A.Compose`` to be called with ``image=``, ``bboxes=`` (YOLO
        format) and ``class_labels=``; configured with ``min_visibility=0.3``.
    """
    hole = img_size // 10  # maximum side of each dropout patch

    return A.Compose([
        # Geometric
        A.RandomResizedCrop(height=img_size, width=img_size, scale=(0.5, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Rotate(limit=30, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.5),

        # Color
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.3),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),

        # Noise and blur
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.Blur(blur_limit=3, p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.MotionBlur(blur_limit=3, p=0.1),

        # Weather
        A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, alpha_coef=0.08, p=0.1),
        A.RandomSunFlare(src_radius=100, num_flare_circles_lower=1, num_flare_circles_upper=2,
                         src_color=(255, 255, 255), p=0.1),

        # Occlusion: a single CoarseDropout. A second ``A.Cutout`` with the same
        # settings duplicated it and is deprecated in albumentations.
        A.CoarseDropout(max_holes=8, max_height=hole, max_width=hole, p=0.3),

        # Normalization and tensor conversion
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(
        format='yolo',
        label_fields=['class_labels'],
        min_visibility=0.3,
    ))
7. 模型评估与优化
7.1 评估指标计算
python
# inference/evaluate.py import torch import numpy as np from pathlib import Path import json from collections import defaultdict class ModelEvaluator: def __init__(self, model, test_loader, device='cuda'): self.model = model self.test_loader = test_loader self.device = device self.results = defaultdict(list) def evaluate(self): """评估模型性能""" self.model.eval() all_predictions = [] all_targets = [] with torch.no_grad(): for batch_idx, (images, targets) in enumerate(self.test_loader): images = images.to(self.device) # 推理 outputs = self.model(images) # 解析结果 predictions = self.parse_predictions(outputs) all_predictions.extend(predictions) all_targets.extend(targets) if batch_idx % 10 == 0: print(f'处理批次 {batch_idx}/{len(self.test_loader)}') # 计算指标 metrics = self.calculate_metrics(all_predictions, all_targets) return metrics def calculate_metrics(self, predictions, targets): """计算评估指标""" metrics = { 'precision': [], 'recall': [], 'f1_score': [], 'ap': [], 'map': 0.0 } # 按类别计算 for class_id in range(self.num_classes): class_preds = [p for p in predictions if p['class_id'] == class_id] class_targets = [t for t in targets if t['class_id'] == class_id] if len(class_targets) == 0: continue # 计算AP ap = self.calculate_ap(class_preds, class_targets) metrics['ap'].append(ap) # 计算精确率、召回率、F1 precision, recall, f1 = self.calculate_prf(class_preds, class_targets) metrics['precision'].append(precision) metrics['recall'].append(recall) metrics['f1_score'].append(f1) # 计算mAP if metrics['ap']: metrics['map'] = np.mean(metrics['ap']) return metrics def calculate_ap(self, predictions, targets): """计算平均精度""" # 实现AP计算逻辑 # 这里简化为示例 return 0.85 def calculate_prf(self, predictions, targets): """计算精确率、召回率、F1分数""" # 实现PRF计算逻辑 # 这里简化为示例 return 0.9, 0.85, 0.8757.2 模型优化策略
python
# train/optimization.py import torch import torch.nn as nn import torch.optim as optim from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts class ModelOptimizer: def __init__(self, model, config): self.model = model self.config = config self.optimizer = None self.scheduler = None self.scaler = torch.cuda.amp.GradScaler() def setup_optimizer(self): """设置优化器""" optimizer_name = self.config.get('optimizer', 'adamw').lower() lr = self.config.get('lr', 0.001) weight_decay = self.config.get('weight_decay', 0.0005) if optimizer_name == 'sgd': self.optimizer = optim.SGD( self.model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True ) elif optimizer_name == 'adam': self.optimizer = optim.Adam( self.model.parameters(), lr=lr, weight_decay=weight_decay ) elif optimizer_name == 'adamw': self.optimizer = optim.AdamW( self.model.parameters(), lr=lr, weight_decay=weight_decay ) elif optimizer_name == 'rmsprop': self.optimizer = optim.RMSprop( self.model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9 ) else: raise ValueError(f"未知优化器: {optimizer_name}") return self.optimizer def setup_scheduler(self, num_epochs): """设置学习率调度器""" scheduler_name = self.config.get('scheduler', 'cosine').lower() if scheduler_name == 'reduce_on_plateau': self.scheduler = ReduceLROnPlateau( self.optimizer, mode='min', factor=0.1, patience=5, verbose=True ) elif scheduler_name == 'cosine': self.scheduler = CosineAnnealingWarmRestarts( self.optimizer, T_0=num_epochs // 4, T_mult=2, eta_min=1e-6 ) elif scheduler_name == 'step': self.scheduler = optim.lr_scheduler.StepLR( self.optimizer, step_size=30, gamma=0.1 ) elif scheduler_name == 'multi_step': milestones = [num_epochs // 2, num_epochs * 3 // 4] self.scheduler = optim.lr_scheduler.MultiStepLR( self.optimizer, milestones=milestones, gamma=0.1 ) else: self.scheduler = None return self.scheduler def apply_mixed_precision(self, loss): """应用混合精度训练""" self.scaler.scale(loss).backward() 
self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() def apply_gradient_clipping(self, max_norm=1.0): """应用梯度裁剪""" torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)8. 系统部署与使用
8.1 环境配置
text
# requirements.txt

# Core
torch>=1.8.0
torchvision>=0.9.0
numpy>=1.19.5
opencv-python>=4.5.1
pillow>=8.3.1

# GUI
PySide6>=6.2.0

# Data processing
# Capped below 2.0: the augmentation code uses the 1.x API
# (RandomResizedCrop(height=, width=), GaussNoise(var_limit=), ...).
albumentations>=1.0.3,<2.0
pandas>=1.3.0
scipy>=1.7.0

# Visualization
matplotlib>=3.4.0
seaborn>=0.11.0

# YOLO
ultralytics>=8.0.0  # YOLOv8
# YOLOv5/v6/v7 must be cloned from their official repositories

# Misc
tqdm>=4.62.0
pyyaml>=5.4.1
8.2 安装与运行
bash
# 1. Clone the project
git clone https://github.com/yourusername/plant-disease-detection.git
cd plant-disease-detection

# 2. Install dependencies
pip install -r requirements.txt

# 3. Get the YOLO code and weights
# The YOLOv5 detector and training script import the official repo from models/yolov5
# (the directory must be empty or absent before cloning).
git clone https://github.com/ultralytics/yolov5.git models/yolov5
pip install -r models/yolov5/requirements.txt
# Download YOLOv5/v6/v7/v8 pretrained weights from the official repositories

# 4. Prepare the dataset
# Organize it in YOLO format and describe it in data/plant_disease.yaml

# 5. Train
python train/yolov5_train.py --data data/plant_disease.yaml --weights yolov5s.pt --epochs 100

# 6. Run the application
python main.py