CRNN(Convolutional Recurrent Neural Network)是一种用于图像序列识别的端到端可训练神经网络,特别擅长处理场景文本识别任务
。CRNN 的核心架构包括三个主要部分:卷积层(CNN)、循环层(RNN)和转录层(Transcription Layer),结合 CTC(Connectionist Temporal Classification)损失函数实现端到端训练
“端到端(End-to-End)” 是深度学习和人工智能领域中的一个重要概念,指的是从输入数据直接映射到输出结果的模型或系统,无需人工干预或手动设计特征提取等中间步骤。换句话说,端到端模型能够自动学习从原始输入到最终输出的映射关系
CRNN 的端到端序列识别实现
1. 卷积层(CNN)
-
功能:卷积层负责从输入图像中提取局部特征。通常使用经典的 CNN 结构(如 AlexNet、ResNet 等)来实现。
-
输出:卷积层将二维图像特征图转换为一维特征序列,为后续的 RNN 处理做准备。
2. 循环层(RNN)
-
功能:循环层在卷积特征的基础上继续提取序列特征,捕捉字符间的上下文关系。通常使用双向 LSTM(Long Short-Term Memory)或 GRU(Gated Recurrent Unit)等循环神经网络结构。
-
输出:循环层输出的是一系列标签的概率分布。
3. 转录层(Transcription Layer)
-
功能:转录层将循环层输出的标签分布通过去重整合等操作转换成最终的识别结果。使用 CTC 损失函数作为条件概率,允许模型在训练期间处理不定长的序列输出。
-
输出:最终的文本序列,可以直接用于识别任务。
4. CTC 损失函数
-
功能:CTC 损失函数允许模型在训练期间处理不定长的序列输出。它通过扩展标签集并引入空白符,将 RNN 输出的不定长序列映射到标签序列。
-
优势:CTC 损失函数使得整个网络可以端到端地训练,无需复杂的预处理或多步骤处理。
CRNN 的优势
-
端到端训练:CRNN 能够直接从输入图像到输出序列标签进行端到端的训练,无需复杂的预处理或多步骤处理。
-
任意长度序列处理:CRNN 可以处理任意长度的序列,不依赖于字符分割或水平尺度归一化。
-
高效性:CRNN 实现了从图像到字符序列的端到端识别,无需进行复杂的字符分割和定位。
-
准确性:通过结合 CNN 和 RNN 的优势,CRNN 能够有效捕捉图像中的局部特征和字符间的上下文关系,从而提升了识别的准确性。
-
广泛应用:CRNN 不仅在场景文本识别上表现出色,还能够应用于其他图像序列识别任务,如音乐符号识别等。
实现示例(PyTorch)
以下是一个基于 PyTorch 的 CRNN 模型实现示例:
Python
import torch
import torch.nn as nnclass CRNN(nn.Module):def __init__(self, imgH, nc, nclass, nh):super(CRNN, self).__init__()assert imgH % 16 == 0, 'imgH must be a multiple of 16'# CNN部分 (简化版)self.cnn = nn.Sequential(nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(), nn.MaxPool2d((2, 2), (2, 1), (0, 1)),)# RNN部分self.rnn = nn.Sequential(BidirectionalLSTM(256, nh, nh),BidirectionalLSTM(nh, nh, nclass))def forward(self, input):# CNN前向传播conv = self.cnn(input)b, c, h, w = conv.size()assert h == 1, "the height of conv must be 1"conv = conv.squeeze(2) # [b, c, w]conv = conv.permute(2, 0, 1) # [w, b, c]# RNN前向传播output = self.rnn(conv)return outputclass BidirectionalLSTM(nn.Module):def __init__(self, nIn, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)self.embedding = nn.Linear(nHidden *