基于pytorch卷积神经网络的汉字识别系统

news/2025/11/7 16:46:38/文章来源:https://www.cnblogs.com/hxz1/p/19200193

基于pytorch卷积神经网络的汉字识别系统

源代码如下(pycharm//附运行结果):

import os
import shutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score
import warnings
from tqdm import tqdm # 进度条显示

warnings.filterwarnings('ignore')


# ======================== 1. 配置参数 ========================
class Config:
# 数据路径配置
TXT_PATH = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/train.txt"
RAW_PNG_DIR = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/test_data"
OUTPUT_DATASET_ROOT = "C:/Users/33946/Downloads/hd_chinese/hd_chinese/dataset"

# 训练参数配置
IMAGE_SIZE = (64, 64)
BATCH_SIZE = 64 # GPU可用时用64,CPU用32
EPOCHS = 100
LR = 1e-4
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SAVE_DIR = "saved_models"
SAVE_INTERVAL = 10
ROTATION_DEGREES = 5
TRANSLATE = (0.05, 0.05)


# 创建必要目录
os.makedirs(Config.SAVE_DIR, exist_ok=True)


# ======================== 2. 数据集处理 ========================
def process_train_txt_and_generate_dataset():
print("===== 开始处理数据集 =====")
for split in ['train', 'val', 'test']:
os.makedirs(os.path.join(Config.OUTPUT_DATASET_ROOT, split), exist_ok=True)

data = []
with open(Config.TXT_PATH, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
png_rel_path, text = line.split('\t', 1)
png_filename = os.path.basename(png_rel_path)
if text:
first_char = text[0]
data.append((png_filename, first_char))
else:
print(f"⚠️ 跳过空文本:{png_rel_path}")

char_groups = {}
for png_filename, first_char in data:
if first_char not in char_groups:
char_groups[first_char] = []
char_groups[first_char].append(png_filename)

total_images = 0
for char, png_list in char_groups.items():
random.shuffle(png_list)
total = len(png_list)
total_images += total
train_num = int(total * 0.7)
val_num = int(total * 0.2)

for i, png_filename in enumerate(png_list):
src_path = os.path.join(Config.RAW_PNG_DIR, png_filename)
if not os.path.exists(src_path):
print(f"⚠️ 跳过不存在的文件:{src_path}")
continue

if i < train_num:
split = 'train'
elif i < train_num + val_num:
split = 'val'
else:
split = 'test'

dst_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, split, char)
os.makedirs(dst_dir, exist_ok=True)
shutil.copy(src_path, os.path.join(dst_dir, png_filename))

print(f"✅ 数据集处理完成!共处理 {total_images} 张图像,{len(char_groups)} 个汉字类别")
print(f" 数据集目录:{Config.OUTPUT_DATASET_ROOT}")
return char_groups


# 仅首次运行时处理数据集,后续可注释
char_groups = process_train_txt_and_generate_dataset()

# ======================== 3. 数据加载 ========================
CHINESE_CHARS = sorted(char_groups.keys())
CHAR_TO_IDX = {char: idx for idx, char in enumerate(CHINESE_CHARS)}
IDX_TO_CHAR = {idx: char for idx, char in enumerate(CHINESE_CHARS)}
NUM_CLASSES = len(CHINESE_CHARS)
print(f"\n===== 模型配置 =====")
print(f" 识别类别数:{NUM_CLASSES},示例汉字:{CHINESE_CHARS[:10]}...")


class ChineseCharDataset(Dataset):
def __init__(self, data_dir, char_to_idx, transform=None):
self.data_dir = data_dir
self.char_to_idx = char_to_idx
self.transform = transform
self.image_paths = []
self.labels = []

for char in os.listdir(data_dir):
char_dir = os.path.join(data_dir, char)
if not os.path.isdir(char_dir) or char not in char_to_idx:
continue
for img_name in os.listdir(char_dir):
if img_name.endswith(".png"):
self.image_paths.append(os.path.join(char_dir, img_name))
self.labels.append(char_to_idx[char])

def __len__(self):
return len(self.image_paths)

def __getitem__(self, idx):
img = Image.open(self.image_paths[idx]).convert("L")
label = self.labels[idx]
if self.transform:
img = self.transform(img)
return img, torch.tensor(label, dtype=torch.long)


def get_transforms():
train_transform = transforms.Compose([
transforms.Resize(Config.IMAGE_SIZE),
transforms.RandomRotation(Config.ROTATION_DEGREES),
transforms.RandomAffine(0, translate=Config.TRANSLATE),
transforms.RandomResizedCrop(Config.IMAGE_SIZE, scale=(0.9, 1.0)),
transforms.ToTensor(),
transforms.RandomErasing(p=0.1, scale=(0.02, 0.05)),
transforms.Normalize(mean=[0.5], std=[0.5])
])
val_test_transform = transforms.Compose([
transforms.Resize(Config.IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
return train_transform, val_test_transform


train_transform, val_test_transform = get_transforms()
train_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "train")
val_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "val")
test_dir = os.path.join(Config.OUTPUT_DATASET_ROOT, "test")

train_dataset = ChineseCharDataset(train_dir, CHAR_TO_IDX, train_transform)
val_dataset = ChineseCharDataset(val_dir, CHAR_TO_IDX, val_test_transform)
test_dataset = ChineseCharDataset(test_dir, CHAR_TO_IDX, val_test_transform)

# Windows系统禁用多进程(解决路径问题)
train_loader = DataLoader(
train_dataset,
batch_size=Config.BATCH_SIZE,
shuffle=True,
num_workers=0,
pin_memory=True if Config.DEVICE.type == 'cuda' else False
)
val_loader = DataLoader(
val_dataset,
batch_size=Config.BATCH_SIZE,
shuffle=False,
num_workers=0,
pin_memory=True if Config.DEVICE.type == 'cuda' else False
)
test_loader = DataLoader(
test_dataset,
batch_size=Config.BATCH_SIZE,
shuffle=False,
num_workers=0,
pin_memory=True if Config.DEVICE.type == 'cuda' else False
)

print(f"\n===== 数据集加载 =====")
print(f" 训练集:{len(train_dataset)} 张图像")
print(f" 验证集:{len(val_dataset)} 张图像")
print(f" 测试集:{len(test_dataset)} 张图像")


# ======================== 4. 模型定义 ========================
class ImprovedChineseCharCNN(nn.Module):
def __init__(self, num_classes):
super(ImprovedChineseCharCNN, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout(0.05),

nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout(0.05),

nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout(0.05)
)

dummy = torch.randn(1, 1, Config.IMAGE_SIZE[0], Config.IMAGE_SIZE[1])
self.fc_input_dim = self.conv_layers(dummy).view(1, -1).size(1)

self.fc_layers = nn.Sequential(
nn.Linear(self.fc_input_dim, 1024),
nn.ReLU(inplace=True),
nn.BatchNorm1d(1024),
nn.Dropout(0.2),

nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(0.2),

nn.Linear(512, num_classes)
)

def forward(self, x):
x = self.conv_layers(x)
x = x.view(x.size(0), -1)
x = self.fc_layers(x)
return x


model = ImprovedChineseCharCNN(NUM_CLASSES).to(Config.DEVICE)
print(f"\n===== 模型信息 =====")
print(f" 设备:{Config.DEVICE}")
print(f" 模型结构:{model}")


# ======================== 5. 训练与评估函数 ========================
def train_one_epoch(model, train_loader, criterion, optimizer, device):
model.train()
total_loss, all_preds, all_labels = 0.0, [], []
for images, labels in tqdm(train_loader, desc="训练中", leave=False):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

total_loss += loss.item() * images.size(0)
all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
all_labels.extend(labels.cpu().numpy())

avg_loss = total_loss / len(train_loader.dataset)
acc = accuracy_score(all_labels, all_preds)
return avg_loss, acc


def evaluate(model, dataloader, criterion, device, split="验证"):
model.eval()
total_loss, all_preds, all_labels = 0.0, [], []
with torch.no_grad():
for images, labels in tqdm(dataloader, desc=f"{split}中", leave=False):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)

total_loss += loss.item() * images.size(0)
all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
all_labels.extend(labels.cpu().numpy())

avg_loss = total_loss / len(dataloader.dataset)
acc = accuracy_score(all_labels, all_preds)
return avg_loss, acc


# ======================== 新增:输出识别文字结果 ========================
def print_recognition_results(model, dataloader, device, idx_to_char, num_samples=5):
"""随机打印指定数量样本的识别结果(预测文字 vs 真实文字)"""
model.eval()
samples_shown = 0
# 随机打乱数据顺序,避免每次打印相同样本
random_indices = random.sample(range(len(dataloader.dataset)), min(num_samples, len(dataloader.dataset)))

with torch.no_grad():
for idx in random_indices:
# 获取单个样本
image, label = dataloader.dataset[idx]
image = image.unsqueeze(0).to(device) # 增加批次维度
output = model(image)
pred_idx = torch.argmax(output, 1).cpu().item() # 预测索引
true_idx = label.item() # 真实索引

# 转换为文字
pred_char = idx_to_char[pred_idx]
true_char = idx_to_char[true_idx]

# 打印结果
print(f"样本 {samples_shown + 1}:预测='{pred_char}',真实='{true_char}',"
f"{'✅' if pred_char == true_char else '❌'}")
samples_shown += 1
if samples_shown >= num_samples:
break


# ======================== 6. 主训练函数(支持断点续训) ========================
def main_train(load_from_checkpoint=True, checkpoint_path="saved_models/best_model.pth"):
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
model.parameters(),
lr=Config.LR,
weight_decay=1e-4
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='max',
patience=3,
factor=0.5
)

best_val_acc = 0.0
start_epoch = 1

if load_from_checkpoint and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model_state_dict"])
if "optimizer_state_dict" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if "val_acc" in checkpoint:
best_val_acc = checkpoint["val_acc"]
if "epoch" in checkpoint:
start_epoch = checkpoint["epoch"] + 1
print(f"📌 已加载历史模型,从第{start_epoch}轮继续训练(历史最佳准确率:{best_val_acc:.4f})")

print(f"\n===== 开始训练 =====")
for epoch in range(start_epoch, Config.EPOCHS + 1):
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, Config.DEVICE)
val_loss, val_acc = evaluate(model, val_loader, criterion, Config.DEVICE, split="验证")

scheduler.step(val_acc)

print(f"Epoch [{epoch:3d}/{Config.EPOCHS}] | "
f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
f"LR: {optimizer.param_groups[0]['lr']:.6f}")

if epoch % Config.SAVE_INTERVAL == 0:
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"val_acc": val_acc
}, os.path.join(Config.SAVE_DIR, f"model_epoch_{epoch}.pth"))
print(f"💾 已保存第{epoch}轮模型")

if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"val_acc": best_val_acc,
"char_to_idx": CHAR_TO_IDX,
"idx_to_char": IDX_TO_CHAR
}, os.path.join(Config.SAVE_DIR, "best_model.pth"))
print(f"🌟 最佳模型更新(Val Acc: {best_val_acc:.4f})")

# 训练完成后测试并输出识别结果
best_model_path = os.path.join(Config.SAVE_DIR, "best_model.pth")
if os.path.exists(best_model_path):
best_model = torch.load(best_model_path)
model.load_state_dict(best_model["model_state_dict"])
test_loss, test_acc = evaluate(model, test_loader, criterion, Config.DEVICE, split="测试")
print(f"\n===== 训练完成 =====")
print(f" 测试集准确率:{test_acc:.4f}")

# 调用新增函数,输出5个样本的识别文字
print(f"\n===== 随机抽取5个测试样本的识别结果 =====")
print_recognition_results(model, test_loader, Config.DEVICE, IDX_TO_CHAR, num_samples=5)
else:
print("\n⚠️ 未找到最佳模型文件")


# ======================== 启动训练 == ======================
if __name__ == "__main__":
main_train(load_from_checkpoint=True)

////准确率达90%以上////

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/958996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

制图-学习日志

lth 开始时间:2025-11-07 更新时间:2025-11-07 QGIS\Aerialod{{image.png(uploading...)}}

2025年热门成人自考机构推荐

摘要 2025年,成人自考行业持续蓬勃发展,随着职场竞争加剧和终身学习理念普及,越来越多成年人选择通过自考提升学历。本文基于行业数据和用户口碑,为您推荐2025年热门成人自考机构TOP5排行,并附上详细评测,帮助您…

实用指南:手写MyBatis第95弹:调试追踪MyBatis SQL执行流程的终极指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

SOCKS5代理:通用性与协议覆盖

核心优势: 协议通用性:标准化转发能力,支持 TCP 与 UDP,适配混合协议场景[1] 客户端兼容性:对浏览器、数据库客户端、消息队列、实时业务等多样化客户端友好 连接灵活性:长连接与会话保持策略灵活,适合持续链路…

口碑好的成人自考机构2025年推荐榜单

摘要 2025年,成人自考行业持续蓬勃发展,随着职场竞争加剧和终身学习理念普及,越来越多在职人士选择通过自考提升学历。行业数据显示,中国成人自考市场规模年增长率超15%,需求主要集中在灵活学习、可靠服务和高效拿…

2025年国内成人自考机构口碑推荐排行榜单:选择指南与深度解析

摘要 2025年成人自考行业持续增长,越来越多在职人士选择自考提升学历,以应对职场竞争。本文基于权威数据和用户口碑,为您推荐top5成人自考机构,重点介绍排名第一的机构优势,并提供表单参考,助您高效选择。行业发…

2025 年 11 月除锈剂厂家推荐排行榜,钢铁除锈剂,金属除锈剂,钢材除锈剂,不锈钢除锈剂,螺丝除锈剂,弹簧除锈剂,铝型材除锈剂公司推荐

在金属加工制造领域,除锈剂作为表面处理的关键材料,其性能直接影响产品质量和生产效率。随着工业技术迭代升级,除锈剂产品已从基础防锈功能发展为具备多功能特性的专业化学品,针对不同金属材质和应用场景的需求差异…

CANopen转Profinet是一种构建于控制局域网设备之上的协议网关

CANopen转Profinet是一种构建于控制局域网设备之上的协议网关 CANopen作为构建于控制局域网(Controller Area Network, CAN)之上的高层通信协议,其体系架构包含通信子协议与设备子协议。此协议在嵌入式系统领域获得…

2025 年 11 月喷头漏墨维修厂家推荐排行榜,理光喷头漏墨,京瓷喷头漏墨,精工喷头漏墨,喷绘机喷头漏墨维修公司推荐

在工业喷墨打印领域,喷头漏墨是影响生产效率和打印质量的关键问题。随着喷墨技术在陶瓷装饰、广告喷绘、工业标识等行业的广泛应用,喷头漏墨故障已成为设备维护中的常见挑战。不同品牌的喷头,如理光、京瓷、精工等,…

Cohen‘s Kappa系数:衡量分类一致性的黄金标准及其在NLP中的应用 - 实践

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年国内成人自考机构口碑推荐榜单:如何选择靠谱的学历提升平台

摘要 随着终身学习理念的深入,2025年成人自考市场呈现快速增长态势,学历提升需求持续旺盛。本文基于行业数据和用户口碑,为您精选国内优质的成人自考机构,并提供详细的对比分析。本文还包含机构推荐表单,供有需要…

Spring Cloud Alibaba + Sentinel

Sentinel 在微服务世界里,每个服务就像一个小摊位,生意火爆时,人流汹涌,如果没有保护措施,小摊很容易被“压垮”。这时候,你就需要 Sentinel——微服务界的“护身符”,帮你抵御流量暴击、保护系统稳定运行。 本…

2025年11月星光喷头厂家推荐排行榜:专业选购与维护指南

在工业喷墨打印领域,星光喷头作为核心部件,其性能稳定性与使用寿命直接影响生产效率和产品质量。随着陶瓷、纺织、包装等行业的快速发展,对星光喷头1024、1024MC、1024SC、1024LA、1024MA、SA、XSA、XSC、600DPI等型…

德鲁克管理哲学:管理是知行统一的实践创新 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025 年 11 月食堂承包公司推荐排行榜,食堂承包商,食堂承包方案,大型食堂承包,专业餐饮服务与高效运营管理口碑之选

在当今快节奏的社会环境中,食堂承包服务已成为企业、学校、医院等机构后勤保障的重要组成部分。专业的食堂承包公司不仅能够提供多样化的餐饮方案,还能通过科学的管理体系确保食品安全与运营效率。随着行业标准的不断…

2025年双组份喷涂泵定做厂家权威推荐榜单:双组份喷漆机专用喷枪/无气喷涂机/高压无气喷涂泵专用喷枪源头厂家精选

在工业涂装领域,双组份喷涂泵作为精密涂覆的核心设备,其定制化能力与稳定性直接影响涂层质量与生产成本。行业数据显示,2025年全球双组份涂装设备市场规模增长率预计达12%,其中定制化泵组在汽车、航空航天等高端制…

智能充气泵方案:充气泵电机怎么选?怎么适配

这个问题切得很准,直接命中充气泵核心动力单元的选型关键!充气泵电机选型核心是“匹配充气需求+适配PCBA驱动”,需先按场景定电机类型,再通过参数匹配、驱动适配实现稳定运行。一、电机选型:先定类型,再挑参数1.…

智能家居产品品牌推荐排行2025:权威榜单揭晓

文章摘要 智能家居行业在2025年持续高速发展,全球市场规模预计突破1500亿美元,中国品牌凭借技术创新和成本优势占据重要地位。本文基于行业数据、用户口碑和技术评测,为您呈现2025年智能家居产品品牌推荐排行前十榜…

2025 年 11 月电弧故障保护器厂家推荐排行榜,断路器/检测断路器,并联/串联电弧故障保护器,防火限流式保护器,故障电弧探测器公司推荐

一、行业背景与发展现状随着现代电力系统复杂度的不断提升,电气火灾防护已成为工业安全领域的重要课题。电弧故障保护器作为电气防火系统的核心组件,通过检测线路中的异常电弧现象,及时切断故障电路,有效预防电气火…

2025 年 11 月食堂送菜平台推荐排行榜,送菜上门,食堂送菜公司,饭堂送菜平台,专业高效与新鲜直达服务口碑之选

随着现代餐饮服务行业的快速发展,食堂送菜平台作为连接农产品供应链与终端消费的重要环节,正发挥着日益关键的作用。近年来,随着食品安全意识的提升和配送效率要求的提高,专业送菜上门服务已成为企业食堂、学校机构…