import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Allow duplicate OpenMP runtimes to coexist instead of aborting ("OMP conflict" workaround).
# NOTE(review): torch is imported above this line; confirm the variable is still picked up
# early enough in your environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = torch.device('cpu')

# Configuration
CHARS = ['一', '二', '三', '十', '人', '口', '手', '日', '月', '水']  # class labels, in index order
TRAIN_NUM = 200      # generated training images per character
TEST_NUM = 50        # generated test images per character
IMG_SIZE = 64        # side length (pixels) of the square grayscale images
DATA_SAVE_DIR = 'hanzi_data'
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.005
class HanziDatasetGenerator:
    """Render synthetic grayscale images of the characters in CHARS and save them
    under DATA_SAVE_DIR/<split>/<char>/<i>.png."""

    # TrueType fonts tried in order when looking for CJK glyphs. 'simsun.ttc' is the
    # font the original code tried; the others are common CJK fonts on other systems.
    # Fonts that cannot be opened are simply skipped.
    _FONT_CANDIDATES = (
        'simsun.ttc',
        'msyh.ttc',
        'NotoSansCJK-Regular.ttc',
        'wqy-microhei.ttc',
    )
    # Point size used with a TrueType CJK font (kept large so strokes stay legible).
    _FONT_SIZE = 40
    # Hand-tuned top-left drawing positions so each glyph sits roughly centred.
    _CHAR_OFFSETS = {
        '一': (5, 25), '二': (5, 15), '三': (5, 10),
        '十': (20, 15), '人': (10, 20), '口': (15, 15),
        '手': (5, 10), '日': (15, 15), '月': (10, 15), '水': (5, 10),
    }
    # Position used for characters that have no tuned offset above.
    _DEFAULT_OFFSET = (10, 10)

    def __init__(self):
        # Pillow's built-in font; used only when no CJK TrueType font is available.
        self.font = ImageFont.load_default()
        # Resolve the CJK font once; previously ImageFont.truetype was retried
        # (including its font-directory search) for every generated image.
        self.cjk_font = self._load_cjk_font(self._FONT_SIZE)
        if self.cjk_font is None:
            # NOTE(review): the default font may lack CJK glyphs, in which case the
            # classes could differ only by position — confirm the images are usable.
            print("提示:使用默认字体生成汉字(可能显示较简单,但能保证运行)")

    @classmethod
    def _load_cjk_font(cls, size):
        """Return the first loadable font from _FONT_CANDIDATES, or None if none load."""
        for name in cls._FONT_CANDIDATES:
            try:
                return ImageFont.truetype(name, size=size)
            except (OSError, ImportError):
                # OSError: font file not found/unreadable; ImportError: no FreeType support.
                continue
        return None

    def _generate_single_img(self, char):
        """Return a IMG_SIZE x IMG_SIZE grayscale image of ``char`` (black on white),
        randomly rotated by -10..10 degrees."""
        img = Image.new('L', (IMG_SIZE, IMG_SIZE), color=255)  # white background
        draw = ImageDraw.Draw(img)
        x, y = self._CHAR_OFFSETS.get(char, self._DEFAULT_OFFSET)

        if self.cjk_font is not None:
            draw.text((x, y), char, font=self.cjk_font, fill=0, stroke_width=2)
        else:
            # Fallback font: draw twice, shifted by one pixel, to thicken thin strokes.
            draw.text((x, y), char, font=self.font, fill=0, stroke_width=3)
            draw.text((x + 1, y), char, font=self.font, fill=0, stroke_width=2)

        # Slight random rotation adds variation between samples.
        rotation = random.randint(-10, 10)
        return img.rotate(rotation, expand=False, fillcolor=255)

    def generate_dataset(self):
        """(Re)create DATA_SAVE_DIR and fill train/test folders with rendered images.

        Any existing DATA_SAVE_DIR is deleted first.
        """
        import shutil

        if os.path.exists(DATA_SAVE_DIR):
            shutil.rmtree(DATA_SAVE_DIR)

        splits = (('train', TRAIN_NUM), ('test', TEST_NUM))
        for split, _ in splits:
            for char in CHARS:
                os.makedirs(os.path.join(DATA_SAVE_DIR, split, char), exist_ok=True)

        print("生成数据集...")
        # Per character: all train images first, then all test images.
        for char in CHARS:
            for split, count in splits:
                for i in range(count):
                    img = self._generate_single_img(char)
                    img.save(os.path.join(DATA_SAVE_DIR, split, char, f'{i}.png'))
        print(f"数据集生成完成:{os.path.abspath(DATA_SAVE_DIR)}")
class HanziDataset(Dataset):
    """Dataset of ``(image_tensor, label_index)`` pairs read from
    DATA_SAVE_DIR/<split>/<char>/ for every character in CHARS."""

    # File extensions treated as images; anything else (e.g. '.DS_Store') is skipped.
    _IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, split='train'):
        self.split = split
        self.data_dir = os.path.join(DATA_SAVE_DIR, split)
        self.char_list = CHARS
        self.char2idx = {c: i for i, c in enumerate(self.char_list)}
        self.images, self.labels = self._load_data()
        self.transform = transforms.ToTensor()

    def _load_data(self):
        """Collect image paths and their label indices, one sub-directory per character.

        Files are visited in sorted order so the dataset order is deterministic.
        """
        images = []
        labels = []
        for char in self.char_list:
            char_dir = os.path.join(self.data_dir, char)
            for img_name in sorted(os.listdir(char_dir)):
                if not img_name.lower().endswith(self._IMAGE_EXTS):
                    continue
                images.append(os.path.join(char_dir, img_name))
                labels.append(self.char2idx[char])
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Close the file handle promptly; convert() returns a fully loaded copy.
        with Image.open(self.images[idx]) as img:
            gray = img.convert('L')
        return self.transform(gray), self.labels[idx]
class FeatureCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel square images.

    Args:
        num_classes: number of output logits.
        img_size: input side length in pixels; must be divisible by 4 because of
            the two 2x2 max-pools (default 64, matching IMG_SIZE).
    """

    def __init__(self, num_classes=10, img_size=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # img_size -> img_size / 2
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # img_size / 2 -> img_size / 4
        )
        feat_side = img_size // 4
        self.classifier = nn.Linear(16 * feat_side * feat_side, num_classes)

    def forward(self, x):
        """Map a ``(batch, 1, img_size, img_size)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.features(x)
        # flatten(x, 1) keeps the batch dimension; a wrongly sized input now fails
        # loudly in the Linear layer instead of being silently re-batched by view(-1, ...).
        x = torch.flatten(x, 1)
        return self.classifier(x)
def _evaluate(model, loader):
    """Return the accuracy (percent) of ``model`` on ``loader``; 0.0 for an empty loader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return 100 * correct / total if total else 0.0


def _recognize_interactively(model):
    """Repeatedly prompt for an image path and print the predicted character until 'q'."""
    model.eval()
    to_tensor = transforms.ToTensor()
    while True:
        # Strip whitespace and surrounding quotes (e.g. from drag-and-dropped paths).
        path = input("\n输入图片路径(q退出):").strip().strip('"\'')
        if path.lower() == 'q':
            break
        if not os.path.exists(path):
            print("路径错误")
            continue
        try:
            with Image.open(path) as raw:
                img = raw.convert('L').resize((IMG_SIZE, IMG_SIZE))
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(img_tensor)
            pred_char = CHARS[torch.argmax(output).item()]
            confidence = torch.softmax(output, dim=1).max().item() * 100
            print(f"识别结果:{pred_char} | 可信度:{confidence:.2f}%")
        except Exception as e:
            # Interactive boundary: report the problem and keep the prompt alive.
            print(f"错误:{e}")


def main():
    """Generate the dataset, train until 85% test accuracy or EPOCHS, then run an
    interactive recognition loop using the best checkpoint."""
    # Generate the dataset (images are rendered even without a CJK system font).
    generator = HanziDatasetGenerator()
    generator.generate_dataset()

    # Load data.
    train_dataset = HanziDataset('train')
    test_dataset = HanziDataset('test')
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model and optimizer; output size follows the configured character list.
    model = FeatureCNN(num_classes=len(CHARS), img_size=IMG_SIZE).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training.
    print("\n开始训练...")
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint; with 0.0 an all-zero run never saved and the load below failed
    # (or silently loaded a stale file from a previous run).
    best_acc = -1.0
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * imgs.size(0)
        avg_loss = total_loss / max(len(train_dataset), 1)

        acc = _evaluate(model, test_loader)
        print(f"轮次{epoch + 1:2d} | 损失:{avg_loss:.4f} | 准确率:{acc:.2f}%")
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_model.pth')
        if acc >= 85:
            print(f"达标!准确率:{acc:.2f}%")
            break

    # Recognition with the best checkpoint from this run (only if one was saved).
    if best_acc >= 0:
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
        print(f"\n最优准确率:{best_acc:.2f}%")
    _recognize_interactively(model)
# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()