文字识别系统

news/2025/11/20 21:26:47/文章来源:https://www.cnblogs.com/xxxjxx/p/19249609

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

自定义数据集类

class CustomDataset(Dataset):
def init(self, data_directory, transform=None):
self.data_directory = data_directory
self.transform = transform
self.classes = os.listdir(data_directory)
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
self.samples = []
for cls in self.classes:
cls_dir = os.path.join(data_directory, cls)
for img_name in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_name)
self.samples.append((img_path, self.class_to_idx[cls]))

def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path, label = self.samples[idx]image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)return image, label

改进的CNN模型(增加层数、添加BatchNorm和Dropout)

class ImprovedCNN(nn.Module):
def init(self, num_classes):
super(ImprovedCNN, self).init()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(128 * 4 * 4, 512)  # 假设输入图像是32x32,可根据实际调整self.dropout1 = nn.Dropout(0.5)self.relu4 = nn.ReLU()self.fc2 = nn.Linear(512, num_classes)def forward(self, x):x = self.pool1(self.relu1(self.bn1(self.conv1(x))))x = self.pool2(self.relu2(self.bn2(self.conv2(x))))x = self.pool3(self.relu3(self.bn3(self.conv3(x))))x = x.view(x.size(0), -1)x = self.relu4(self.dropout1(self.fc1(x)))x = self.fc2(x)return x

早停类(防止过拟合)

class EarlyStopping:
def init(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.path = path

def __call__(self, val_loss, model):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model)elif score < self.best_score + self.delta:self.counter += 1if self.verbose:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model)self.counter = 0def save_checkpoint(self, val_loss, model):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), self.path)self.val_loss_min = val_loss

训练、验证、测试函数

def train_validate_test(data_directory, epochs=100):
# 增强的数据预处理
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CustomDataset(os.path.join(data_directory, 'train'), transform=train_transform)
test_dataset = CustomDataset(os.path.join(data_directory, 'test'), transform=test_transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)num_classes = len(train_dataset.classes)
model = ImprovedCNN(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # 更换为Adam优化器,添加权重衰减
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # 余弦退火学习率调度
early_stopping = EarlyStopping(patience=10, verbose=True)  # 早停机制class_mapping = train_dataset.class_to_idx# 训练循环
for epoch in range(epochs):model.train()running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()train_loss = running_loss / len(train_loader)# 验证model.eval()val_loss = 0.0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)loss = criterion(outputs, labels)val_loss += loss.item()val_loss = val_loss / len(test_loader)scheduler.step()print(f'Epoch {epoch+1}, Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}, LR: {scheduler.get_last_lr()[0]:.6f}')# 早停判断early_stopping(val_loss, model)if early_stopping.early_stop:print("Early stopping")break# 加载最佳模型
model.load_state_dict(torch.load('checkpoint.pt'))# 测试
model.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')return class_mapping, test_accuracy

调用函数

data_directory = 'D:/pytorch/shuzi' # 替换为你的数据目录
class_mapping, test_accuracy = train_validate_test(data_directory, epochs=100)
print("Class Mapping:", class_mapping)
print("Test Accuracy:", test_accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/971445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025 门窗十大品牌精准选购指南:行业评估报告 + 白皮书护航,选窗不踩坑!

2025 年度门窗十大品牌的筛选工作,以中国建筑金属结构协会正式发布的《2025 年度建筑门窗行业发展评估报告》为核心根基,深度挖掘报告中关于行业技术迭代方向、品牌综合竞争力评级、产品核心性能基准参数等关键信息,…

写的都对_第二次软件工程作业

第二次软件工程作业 一、格式描述作业所属课程 软件工程 班级的链接 https://edu.cnblogs.com/campus/fzu/202501SoftwareEngineering作业要求 https://edu.cnblogs.com/campus/fzu/202501SoftwareEngineering/homewor…

深入解析:spark组件-spark core(批处理)-rdd血缘

深入解析:spark组件-spark core(批处理)-rdd血缘pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas"…

深入解析:开源 Linux 服务器与中间件(十二)FRP内网穿透应用

深入解析:开源 Linux 服务器与中间件(十二)FRP内网穿透应用2025-11-20 21:21 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !impor…

CF1542E1 Abnormal Permutation Pairs (easy version)

我们不妨想一个简单的问题,如何计算一个长度为 \(n\) 的排列且逆序对个数为 \(m\) 的方案数。 令 \(f_{i, j}\) 为长度为 \(i\) 的排列逆序对个数为 \(j\) 的方案数。 我们转移的时候,本质上可以任选最后一个数到底增…

网络流建模

网络流建模 最大流 多源多汇 如果一道题中有多个可行的源点 \(s_1,\ldots,s_a\) 和多个可行的汇点 \(t_1,\ldots,t_b\),那么可以建立超级源汇 \(S,T\),从 \(S\) 向 \(s_i\) 连容量无穷的边,\(t_i\) 向 \(T\) 连容量…

实用指南:GLM 智能助力・Trae 跨端个人任务清单

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

AT_agc050 总结

久违地发一次考试总结。因为这次写的比较详细,勉强能拿出来看看。 A 第一反应是线段树。(其实按位考虑说不定对于某些题也是一种突破口) 正解是连 \((2*p)-1\bmod n+1)\) 和 \((2*p+1)-1\bmod n+1\) 然后发现对于每…

补 二分法与图

题目:洛谷p1462 只要某个性质具有单调性,就必然可以二分。 以最短路为判断条件,二分费用,只允许使用费用小于等于目前费用的节点,求最短路,看是否可行,再根据可行性二分费用,最后求出费用的最小值 K 越大,可行…

SpringSecurity 集成 CAS Client 处理单点登录 - Higurashi

推荐阅读:CAS 单点登录详细流程背景 当前业务系统基于 Spring Security,现在需要集成 CAS,当用户访问业务系统时,如果用户没有登录,则跳转到 CAS Server 统一登录页面完成登录。 而当用户从 CAS Server 退出登录后…

NOIP2025模拟赛12(炼石计划NOIP模拟赛第 19 套题目)

赤了这口魔拟赛的石!写在前面: 我艹了何意味啊何意味T1放依托定理的板子题然后我还没听过这个定理(虽然据说是数论基础四大定理之一,但是好像学习数论基础的时候根本没看到过这个定理也没做过相关的题😡😡😡…

[nanoGPT] GPT模型架构 | `LayerNorm` | `CausalSelfAttention` |`MLP` | `Block` - 实践

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

duckdb索引介绍

duckdb支持的索引类型 duckdb支持两种索引:Min-Max Index (Zonemap)和Adaptive Radix Tree (ART)。前者所有通用数据类型(general-purpose data types,也就是常用的数据类型如varchar/integer/date等,非通用类型指的…

25.11.20 最长不升序列LNIS和最长升序列LIS

LNIS 1.处理一个数时: 如果这个数小于等于当前序列的最后一个数,则直接接在后面,ct++ 反之,从序列头开始寻找第一个比这个数小的数并且替代他,目的:使这个序列更容易接后面的数 2.代码模板 int LNIS(vector&…

2025.11.20 B 题解

感觉其实今天 \(B\) 是最有趣的,难度估在上位紫吧。一眼数学,两眼不是数学,三眼发现可以让 \(x\) 向 \((dx+t)\bmod n\ (t\in[L,R])\) 连边,然后从每个 \(x\) 找到到根最短路径。对于每个给出的 \(x\),它所覆盖的…

重组干扰素蛋白的结构特点与分子性质综述

一、干扰素的类别与基础结构特征 干扰素(interferon,IFN)是一类具有典型结构模式的小分子蛋白,在哺乳动物中广泛表达,其最显著的特征是以折叠紧凑的 α 螺旋结构或二聚体结构实现分子稳定性。按照分子结构、序列特…

2025 门窗十大品牌权威榜单:依托行业评估报告 + 选购白皮书,省心采购指南!

本次 2025 年门窗十大品牌筛选工作,以中国建筑金属结构协会重磅发布的《2025 年度建筑门窗行业发展评估报告》为核心数据支撑,深度拆解报告中关于行业技术趋势、品牌综合竞争力、产品性能核心指标等关键内容,同时整…

实用指南:OpenCV下载安装教程(非常详细)从零基础入门到精通,看完这一篇就够了(附安装包)

实用指南:OpenCV下载安装教程(非常详细)从零基础入门到精通,看完这一篇就够了(附安装包)2025-11-20 21:08 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !impo…

详解 DPO

DPO 隐式地优化了与现有 RLHF 算法(基于 KL 散度约束的奖励最大化)相同的目标函数。然而,与传统 RLHF 方法(需要首先训练一个独立的奖励模型,然后通过强化学习来优化策略)不同,DPO 推导并提出了一种直接利用人类…