手写汉字识别

news/2025/10/30 21:28:15/文章来源:https://www.cnblogs.com/xiaoguo1111/p/19178168

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW, SGD
import torchvision.models as models
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings("ignore", message=".weights_only.")

# Seed all torch RNGs so experiments are reproducible.

def set_seed(seed=42):
    """Seed the CPU and CUDA torch generators and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to both CPU and CUDA RNGs.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Deterministic conv algorithms; benchmark autotuning is disabled because
    # it can pick different kernels between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Build a manifest file listing one image path per line for the first
# `num_class` class sub-directories under `root`.

def classes_txt(root, out_path, num_class=None):
    """Append image paths for the first `num_class` classes to `out_path`.

    Args:
        root: Directory whose immediate sub-directories are one class each
            (directory names are numeric class ids, e.g. "00007").
        out_path: Manifest text file; one image path is written per line.
        num_class: Number of classes to index; defaults to all sub-dirs.
    """
    dirs = os.listdir(root)
    if not num_class:
        num_class = len(dirs)
    # Create the manifest if missing so it can be opened in 'r+' mode.
    if not os.path.exists(out_path):
        with open(out_path, 'w') as f:
            pass
    with open(out_path, 'r+') as f:
        # Try to resume from the last line already in the manifest; an empty
        # or unparsable file restarts from class 0.
        try:
            end = int(f.readlines()[-1].split(' ')[-1]) + 1
        except (IndexError, ValueError):
            end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for class_dir in dirs:  # renamed from `dir` (shadowed builtin)
                entries = os.listdir(os.path.join(root, class_dir))
                for entry in entries:
                    f.write(os.path.join(root, class_dir, entry) + '\n')

class MyDataset(Dataset):
    """Dataset backed by a manifest file of image paths.

    Each path's parent directory name is parsed as an integer class label
    (e.g. ".../00007/img.png" -> 7). The manifest is assumed to be sorted by
    class, so reading stops at the first label >= num_class.
    """

    def __init__(self, txt_path, num_class, transforms=None):
        super().__init__()
        images = []
        labels = []
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip('\n')
                # Label = numeric name of the image's parent directory.
                # (Portable replacement for splitting on a hard-coded '\\'.)
                label = int(os.path.basename(os.path.dirname(line)))
                if label >= num_class:
                    break  # manifest is class-sorted; nothing more to read
                images.append(line)
                labels.append(label)
        self.images = images      # list[str]: image file paths
        self.labels = labels      # list[int]: class ids, aligned with images
        self.transforms = transforms

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.labels)

# ========== Data augmentation pipelines ==========

# Training-time augmentation: geometric + photometric jitter, then collapse
# to one normalized grayscale channel.
# NOTE(review): horizontal/vertical flips change the identity of handwritten
# Chinese characters — confirm these flips are intentional for this task.
train_transform = transforms.Compose([
    transforms.Resize((128, 128)),    # larger working resolution
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.05),   # rare vertical flip
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.2),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # map [0,1] -> [-1,1]
])

# Validation/test pipeline: deterministic resize + grayscale + the same
# normalization as training.
val_test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# ========== Model ==========

class AdvancedNet(nn.Module):
    """VGG-style CNN for single-channel character images.

    Four conv blocks (64 -> 128 -> 256 -> 512 channels), each with two
    3x3 convs + BatchNorm + ReLU, pooling, and 2D dropout. The final block
    uses AdaptiveAvgPool2d((4, 4)), so inputs of varying spatial size are
    accepted. A three-layer fully-connected head produces the class logits.
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        """
        Args:
            num_classes: Size of the output logit vector.
            dropout_rate: Dropout probability for the FC head; conv blocks
                use half this rate.
        """
        super(AdvancedNet, self).__init__()

        # Block 1: 1 -> 64 channels
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(dropout_rate / 2)
        # Block 2: 64 -> 128 channels
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout2d(dropout_rate / 2)
        # Block 3: 128 -> 256 channels
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout2d(dropout_rate / 2)
        # Block 4: 256 -> 512 channels, adaptive pool to a fixed 4x4 map
        self.conv7 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        self.pool4 = nn.AdaptiveAvgPool2d((4, 4))
        self.dropout4 = nn.Dropout2d(dropout_rate / 2)
        # Classifier head: 512*4*4 -> 1024 -> 512 -> num_classes
        self.fc1 = nn.Linear(512 * 4 * 4, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout5 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(1024, 512)
        self.bn10 = nn.BatchNorm1d(512)
        self.dropout6 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw class logits for input of shape (N, 1, H, W)."""
        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = self.dropout1(x)
        # Block 2
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        x = self.dropout2(x)
        # Block 3
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.pool3(x)
        x = self.dropout3(x)
        # Block 4
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))
        x = self.pool4(x)
        x = self.dropout4(x)
        # Head
        x = x.view(-1, 512 * 4 * 4)
        x = F.relu(self.bn9(self.fc1(x)))
        x = self.dropout5(x)
        x = F.relu(self.bn10(self.fc2(x)))
        x = self.dropout6(x)
        x = self.fc3(x)
        return x

# ========== Training ==========

def train_advanced_model(model, train_loader, val_loader, num_epochs=50):
    """Train `model` with AdamW, dual LR schedules, and early stopping.

    Uses label-smoothed cross entropy, gradient clipping, a cosine LR decay
    stepped every epoch plus a plateau scheduler driven by validation loss.
    The best weights (by validation accuracy) are saved to 'best_model.pkl'.

    Args:
        model: nn.Module mapping a batch of images to class logits.
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: DataLoader yielding (images, labels) validation batches.
        num_epochs: Maximum number of epochs (early stopping may end sooner).

    Returns:
        Best validation accuracy seen, in percent.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    # Two schedulers combined: cosine decay each epoch + halve the LR when
    # validation loss plateaus. (`verbose=` dropped: deprecated in PyTorch.)
    scheduler1 = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
    scheduler2 = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    class LabelSmoothingCrossEntropy(nn.Module):
        """Cross entropy with uniform label smoothing."""

        def __init__(self, smoothing=0.1):
            super(LabelSmoothingCrossEntropy, self).__init__()
            self.smoothing = smoothing

        def forward(self, x, target):
            confidence = 1. - self.smoothing
            logprobs = F.log_softmax(x, dim=-1)
            nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
            nll_loss = nll_loss.squeeze(1)
            smooth_loss = -logprobs.mean(dim=-1)
            loss = confidence * nll_loss + self.smoothing * smooth_loss
            return loss.mean()

    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

    best_acc = 0.0
    patience_counter = 0
    patience = 10  # early-stopping patience, in epochs without improvement

    for epoch in range(num_epochs):
        # ---- training phase ----
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            if i % 50 == 49:  # log every 50 steps
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/50:.4f}')
                running_loss = 0.0
        train_accuracy = 100 * correct_train / total_train

        # ---- validation phase ----
        model.eval()
        correct_val = 0
        total_val = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
        val_accuracy = 100 * correct_val / total_val
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Acc: {train_accuracy:.2f}%, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

        # LR schedules: cosine per epoch; plateau on validation loss.
        scheduler1.step()
        scheduler2.step(avg_val_loss)

        # Checkpoint the best model and track early-stopping patience.
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model.state_dict(), 'best_model.pkl')
            print(f'New best model saved with validation accuracy: {best_acc:.2f}%')
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    return best_acc

# Dataset construction with a held-out validation split.

def create_datasets(train_txt_path, test_txt_path, num_classes=10):
    """Build (train, val, test) datasets from the two manifest files.

    The training manifest is split 85%/15% into train and validation
    subsets via random_split.

    NOTE(review): both subsets share the underlying dataset and therefore
    the training-time augmentation transform — the validation subset sees
    augmented images. Consider a separate dataset with val_test_transform.

    Args:
        train_txt_path: Manifest file for the training images.
        test_txt_path: Manifest file for the test images.
        num_classes: Number of classes to load from each manifest.

    Returns:
        Tuple (train_set, val_set, test_set).
    """
    full_train_set = MyDataset(train_txt_path, num_classes, transforms=train_transform)

    # 85% train / 15% validation.
    train_size = int(0.85 * len(full_train_set))
    val_size = len(full_train_set) - train_size
    train_set, val_set = random_split(full_train_set, [train_size, val_size])

    test_set = MyDataset(test_txt_path, num_classes, transforms=val_test_transform)
    return train_set, val_set, test_set

# Test-time augmentation (TTA).

def test_time_augmentation(model, dataloader, device, num_augmentations=5):
    """Evaluate `model` by averaging softmax outputs over augmented copies.

    For each batch, predictions on the original images and on up to
    `num_augmentations - 1` augmented variants are averaged before argmax.

    NOTE(review): the augmented branch converts already-normalized tensors
    back to PIL images and re-applies ToTensor + Normalize, which clamps and
    double-normalizes the pixel values — confirm this matches the intended
    pipeline. Only 3 augmentations are defined, so values of
    num_augmentations above 4 have no additional effect.

    Args:
        model: Trained classifier producing logits.
        dataloader: DataLoader yielding (images, labels) batches.
        device: torch.device to run inference on.
        num_augmentations: Total number of views per image (original + N-1).

    Returns:
        Accuracy over the dataloader, in percent.
    """
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            batch_predictions = []

            # Prediction on the unmodified batch.
            outputs = model(images)
            batch_predictions.append(F.softmax(outputs, dim=1))

            # Augmented variants.
            augmentations = [
                transforms.Compose([
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], std=[0.5]),
                ]),
                transforms.Compose([
                    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], std=[0.5]),
                ]),
                transforms.Compose([
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], std=[0.5]),
                ]),
            ]
            for aug in augmentations[:num_augmentations - 1]:
                augmented_images = torch.stack(
                    [aug(transforms.ToPILImage()(img)) for img in images.cpu()])
                augmented_images = augmented_images.to(device)
                outputs = model(augmented_images)
                batch_predictions.append(F.softmax(outputs, dim=1))

            # Average the per-view softmax outputs, then take the argmax.
            avg_predictions = torch.stack(batch_predictions).mean(dim=0)
            _, predicted = torch.max(avg_predictions, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100 * np.sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels)
    return accuracy

# Entry point.

def main():
    """End-to-end pipeline: build manifests, train, evaluate, save weights.

    Paths are hard-coded for a local Windows layout; adjust `root` and
    `model_save_dir` for other environments.
    """
    root = 'D:/深度学习/手写汉字识别/data'
    model_save_dir = 'D:/深度学习/手写汉字识别/tmp'

    # Generate the manifest TXT files first.
    classes_txt(root + '/train', root + '/train.txt', num_class=10)
    classes_txt(root + '/test', root + '/test.txt', num_class=10)

    # Datasets and loaders (small batch size, single worker).
    train_set, val_set, test_set = create_datasets(
        root + '/train.txt', root + '/test.txt', num_classes=10)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=16, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=0)
    print(f"训练集大小: {len(train_set)}")
    print(f"验证集大小: {len(val_set)}")
    print(f"测试集大小: {len(test_set)}")

    # Build the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AdvancedNet(num_classes=10, dropout_rate=0.5).to(device)
    print("高级模型已创建,开始训练...")

    # Train (checkpoints best weights to 'best_model.pkl').
    best_val_acc = train_advanced_model(model, train_loader, val_loader, num_epochs=100)

    # Reload the best checkpoint before evaluation.
    model.load_state_dict(torch.load('best_model.pkl', weights_only=True))

    # Plain test-set evaluation (no TTA).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_accuracy = 100 * correct / total
    print(f'标准测试集准确率: {test_accuracy:.2f}%')

    # TTA evaluation.
    tta_accuracy = test_time_augmentation(model, test_loader, device)
    print(f'TTA测试集准确率: {tta_accuracy:.2f}%')

    # Save the final weights.
    os.makedirs(model_save_dir, exist_ok=True)
    model_save_path = os.path.join(model_save_dir, 'advanced_model.pkl')
    torch.save(model.state_dict(), model_save_path)
    print(f'高级模型已保存到: {model_save_path}')

    # Single-image prediction example (best-effort; failures are reported,
    # not raised, so a missing sample image does not abort the run).
    test_img_path = 'D:/深度学习/手写汉字识别/data/test/00008/12313.png'
    try:
        img = Image.open(test_img_path).convert('RGB')
        img_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        img_tensor = img_transform(img)
        img_tensor = img_tensor.unsqueeze(0).to(device)
        model.eval()
        with torch.no_grad():
            output = model(img_tensor)
            _, prediction = torch.max(output, 1)
            prediction = prediction.cpu().numpy()[0]
            print(f'单张图像预测结果: {prediction}')
    except Exception as e:
        print(f'预测时出错: {e}')


if __name__ == '__main__':
    main()

屏幕截图 2025-10-30 212018

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/951142.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Keil仿真条件断点10.30

Keil5软件仿真中可以使用“CTRL+B”打开断点设置界面。条件断点需要仿真器支持,已知ST-LINK可以,以下变量均为全局变量 条件断点1:设置执行多少次后停下。操作如下图所示定义好后如下图访问断点2:读/写变量时停下 …

CSP-S 复赛游记

\(\mathcal Day\ -1\) 2025/10/30 在机房打算写一个分块的模板代码笔记,但是写到一半发现自己爆炸了,好饿。 所以毅然决然地选择偷看HHY在干嘛,发现他在写复赛Day -2游记,所以我说我也要写,然后后面FBT就说今天其…

P6149 [USACO20FEB] Triangles S 总结

P6149 [USACO20FEB] Triangles S 总结P6149 [USACO20FEB] Triangles S 总结 思路历程 这一题还是相当有趣的,首先我们不难发现,题目要求的就是一个两个直角边平行于 \(x\) 和 \(y\) 的直角三角形。 此时我们想到,这…

10.30 程序员的修炼之道:从小工到专家第三章 基本工具 - GENGAR

第三章 “基本工具” 强调程序员需跳出单一 IDE,掌握多元基础工具。第 14 节指出纯文本由可打印字符构成,虽曾因算力存储受限不占优,但如今具备不过时(自描述性可明确信息含义,如标注 SSNO 的社会保障号)、有杠杆…

数据预处理

inputs.fillna(inputs.mean()) mean() 方法只能用于数值型数据,而如果你的 DataFrame 中包含字符串列,就会出现类型不兼容的错误。 解决方法是只对数值型列进行均值填充,忽略字符串列。可以这样修改: 这个错误的原…

学校机房电脑进阶操作

为了能做学校机房内获得更好的体验,本文章可以帮助你更好地改善你在机房使用的电脑。本文章适用于对电脑能熟练使用的人,对电脑基础操作不熟悉的不建议使用。 本内容将持续更新,若有建议,欢迎提出!(由于本文章刚…

AH2022 钥匙

钥匙 洛谷 P8339 钦定当有很多把钥匙能打开开宝箱时使用最后拿到的一把(应该要想想用第一把打开,实际不好做。) 每种颜色 \(col\) 的钥匙和宝箱是互相独立的,可以对每种颜色建出虚树。对于一把钥匙 \(u\),以它为根…

在国内体验 Claude Code 编程助手的可行方案 —— 我的 Evol AI 工作空间实践分享

前言 一直以来,我都在寻找一个能真正提升开发效率的 AI 编程助手。 我曾用过 GitHub Copilot,补全效果不错,但在处理复杂需求、跨文件逻辑时能力不足。后来了解到 Anthropic 的 Claude Code——支持超长上下文(200…

应用安全 --- vmp 之 代码虚拟化

应用安全 --- vmp 之 代码虚拟化所谓代码虚拟化就是用汇编指令模拟cpu的运行方式实现了一套软件虚拟机处理引擎和需要执行的虚拟化字节码 有点类似java和net的实现原理。 他们不同点就是代码虚拟化vmp是为了保护代码的…

Flask 入门:轻量级 Python Web 框架的快速上手 - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Java第二天

十类型转换 自动类型转换:小范围可以自动转换为大类型表达式的类型转换:不同的类型相加结果要用最大的那个类型来接收。、强制类型转换:大范围不能直接赋值给小范围。但是通过强制转换是可以的,但是会失真,例如一…

八、认识for循环

1.for循环的概念 循环结构是指在程序中需要反复执行某个功能而设置的一种程序结构。如果条件满足,则重复执行相应的语句,当条件不满足,退出循环。 2.for循环的基本格式 for( 循环变量的初值; 循环条件; 循环变量的增…

CCUT应用OJ——小龙的字符串函数

题目简介题源:1073 - 小龙的字符串函数 | CCUT OJ 题意:给定 \(n\) 个等长字符串,定义函数 \(f(s_i,s_j)\) 表示字符串 \(s_i\) 与 \(s_j\) 中位置和字符相同的总数。输出 \(\sum f(s_i,s_j)\) ( 其中 \(i<j\) )…

OceanBase系列---【oceanbase的oracle模式新增分区表】

OceanBase系列---【oceanbase的oracle模式新增分区表】TIPS分区选择建议 按天分区: 适用于数据量极大(每天千万级以上)、需要频繁删除历史数据的场景 按月分区: 适用于数据量中等(每月百万到千万级)、最常用的分区方式…

cursor 数据路径 防止试用账号误删数据

C:\Users\xxx\.cursor C:\Users\xxx\AppData\Local\Programs\cursor C:\Users\xxx\AppData\Roaming\Cursor C:\Users\xxx\AppData\Roaming\Cursor\User 备份这个路径就行

Bettercap(中间人攻击神器)

Bettercap(中间人攻击神器)https://github.com/bettercap/bettercap/releaseshttps://github.com/bettercap/bettercap/releases/download/v2.41.4/bettercap_windows_amd64.zip 安装完运行会提示缺少 libusb.dll ht…

PHP代码加密方法

1. 新建一个 待加密的php文件:/routes_plain.php 注意不要带“<?php” var_dump(666);2.新建运行加密的文件: /jiami.php $plain = file_get_contents(__DIR__ . /routes_plain.php); // 压缩 + base64 $payload…

why is making friends, love bad

any relationship will let one be unreasonable.

DP题解

[P6772 [NOI2020] 美食家] (https://www.luogu.com.cn/problem/P6772) ZAK解题思路 蒟蒻语 wtcl, 只会最简单的题目 这道题目与 [P6569 NOI Online #3 提高组]魔法值(民间数据) 类似, 都是倍增优化矩阵乘法。 蒟蒻解…

逆序对略解

逆序对 定义 在一个数列中,如果前面的数字大于后面的数字,那么这两个数字就构成了一个逆序对 求逆序对 有3种方法:暴力,归并排序,线段树 1.暴力算法 枚举i和j(i<j),并判断是否满足a[j]<a[i] for(int i=1;i…