基于CRNN的文本识别

  • Post author:
  • Post category:其他




0. 前言

CRNN网络的原理细节这里不再赘述(网上已有很多相关介绍),本文直接讲解代码的实现流程。



1. 数据集准备

CRNN是用于识别文本的网络,所以首先需要构建数据集。字符集由26个小写字母和0到9十个数字组成,共36个字符;每张图片从这36个字符中随机选取4到9个字符(可重复)。需要说明的是,网上很多CRNN训练集中每张图片的字符个数都是固定的,这有很大的局限性,所以本文让每张图片的字符个数在4到9之间随机变化。

生成数据集代码如下:

import cv2
import numpy as np
import random
import imgaug.augmenters as iaa

def get_img(min_len=4, max_len=9):
    """Render a random string of lowercase letters/digits and return (image, labels).

    ``random.randint(min_len, max_len)`` characters are drawn with replacement
    from the 36-character set a-z, 0-9. They are written in black on a white
    50x250 canvas, randomly augmented, and resized to height 32, width 200.

    Args:
        min_len: minimum number of characters (inclusive). Default 4.
        max_len: maximum number of characters (inclusive). Default 9.
            NOTE(review): the canvas is only 250 px wide at font scale 1;
            confirm long strings still fit before raising this.

    Returns:
        (src, lab): ``src`` is the augmented numpy image (not normalized);
        ``lab`` is the list of indices into the character set, one per
        rendered character.
    """
    zfu = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
           'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    k = random.randint(min_len, max_len)
    select = random.choices(zfu, k=k)
    lab = [zfu.index(ch) for ch in select]

    select = "".join(select)
    font = cv2.FONT_HERSHEY_COMPLEX
    src = np.ones(shape=(50, 250, 3)).astype('uint8') * 255  # white canvas
    src = cv2.putText(src, select, (20, 27), font, 1, (0, 0, 0), 2)
    seq = iaa.Sequential([
        iaa.Multiply((0.5, 1.5)),  # random brightness scaling
        iaa.GaussianBlur(sigma=(0, 1.0)),  # blur, sigma drawn from [0, 1]
        iaa.Crop(percent=(0, 0.06)),
        iaa.Grayscale(alpha=(0, 1)),
        iaa.Affine(
            scale=(0.95, 1.05),  # random rescaling
            mode=iaa.ia.ALL,  # random border-fill mode
            cval=(100, 255)  # fill value for pixels outside the image
        ),
        iaa.Resize({"height": 32, "width": 200})
    ])
    # imgaug works on a batch of images, so add and then drop a batch axis.
    src = np.expand_dims(src, axis=0)
    src = seq(images=src)[0]
    return src, lab

import os


def _write_split(list_path, img_dir, n_images):
    """Generate ``n_images`` samples into ``img_dir`` and index them in ``list_path``.

    Each line of the index file is ``"<image path> <idx> <idx> ..."``: the
    image path followed by the space-separated character indices that
    get_img returned for that image.
    """
    os.makedirs(img_dir, exist_ok=True)
    with open(list_path, 'w') as f:
        for i in range(n_images):
            img, lab = get_img()
            path = img_dir + '/' + str(i) + '.jpg'
            # cv2.imwrite signals failure through its boolean return value
            # instead of raising; don't index an image that was never written.
            if not cv2.imwrite(path, img):
                raise IOError('failed to write image: ' + path)
            f.write(path + ' ' + ' '.join(str(c) for c in lab) + '\n')
            print(i)


_write_split('train.txt', 'train_data', 10000)
_write_split('val.txt', 'val_data', 1000)



运行上述代码之前首先需要手动新建两个空文件夹用于存放训练图像和验证图像,文件夹名字分别是:train_data和val_data。运行完上述代码以后会在train_data文件夹中保存10000张训练图像,在val_data文件夹中保存1000张验证图像。此外还会生成两个txt文件,分别为train.txt和val.txt。

txt文件的每一行依次为图片路径和图中各字符的类别索引(即字符在字符集中的下标),以空格分隔,如下所示:

[图:train.txt 内容示例]

部分训练图像如下所示:

[图:部分训练图像示例]



2.构建网络

构建CRNN网络的代码如下所示(卷积部分使用深度可分离卷积,即逐通道的depthwise卷积加1×1的pointwise卷积):

# crnn.py
import argparse, os
import torch
import torch.nn as nn


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Maps a (seq_len, batch, nInput_size) sequence (nn.LSTM's default,
    non-batch-first layout) to (seq_len, batch, nOut).
    """

    def __init__(self, nInput_size, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated: 2 * nHidden features.
        self.linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        seq_out, _ = self.lstm(input)
        steps, batch, feat = seq_out.size()
        # Project every (time step, sample) pair independently, then restore the layout.
        projected = self.linear(seq_out.view(steps * batch, feat))
        return projected.view(steps, batch, -1)


class CNN(nn.Module):
    """Depthwise-separable convolutional backbone of the CRNN.

    Seven stages. Each is a depthwise conv (3x3, or 2x2 without padding in the
    last stage), a 1x1 pointwise conv, optional BatchNorm, ReLU and optional
    max-pooling. For a 32-pixel-high input the output is (B, 512, 1, W'),
    the shape CRNN expects before its LSTMs.
    """

    # One row per stage, in module-registration order:
    # (index, out_channels, depthwise kernel, depthwise padding, BatchNorm?, pool kind)
    # pool kind 'square': 2x2, stride 2 (halves height and width).
    # pool kind 'rect':   2x2, stride (2, 1), padding (0, 1) (halves height,
    #                     width grows by one, so horizontal resolution is kept).
    _STAGES = (
        (0, 64, 3, 1, False, 'square'),
        (1, 128, 3, 1, False, 'square'),
        (2, 256, 3, 1, True, None),
        (3, 256, 3, 1, False, 'rect'),
        (4, 512, 3, 1, True, None),
        (5, 512, 3, 1, False, 'rect'),
        (6, 512, 2, 0, True, None),
    )

    def __init__(self, imageHeight, nChannel):
        super(CNN, self).__init__()
        assert imageHeight % 32 == 0, 'image Height has to be a multiple of 32'

        # Registration order (and so parameter-init order and state_dict keys)
        # is depth_conv, point_conv, batchNorm, relu, pool for each stage.
        in_ch = nChannel
        for idx, out_ch, kernel, pad, use_bn, pool in self._STAGES:
            setattr(self, 'depth_conv%d' % idx,
                    nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=kernel,
                              stride=1, padding=pad, groups=in_ch))
            setattr(self, 'point_conv%d' % idx,
                    nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1,
                              stride=1, padding=0, groups=1))
            if use_bn:
                setattr(self, 'batchNorm%d' % idx, nn.BatchNorm2d(out_ch))
            setattr(self, 'relu%d' % idx, nn.ReLU(inplace=True))
            if pool == 'square':
                setattr(self, 'pool%d' % idx, nn.MaxPool2d(kernel_size=2, stride=2))
            elif pool == 'rect':
                setattr(self, 'pool%d' % idx,
                        nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))
            in_ch = out_ch

    def forward(self, input):
        x = input
        for idx, _, _, _, use_bn, pool in self._STAGES:
            x = getattr(self, 'depth_conv%d' % idx)(x)
            x = getattr(self, 'point_conv%d' % idx)(x)
            if use_bn:
                x = getattr(self, 'batchNorm%d' % idx)(x)
            x = getattr(self, 'relu%d' % idx)(x)
            if pool is not None:
                x = getattr(self, 'pool%d' % idx)(x)
        return x


class CRNN(nn.Module):
    """CNN feature extractor followed by two stacked bidirectional LSTMs.

    Takes an image batch (B, nChannel, imgHeight, W). Returns per-time-step
    class scores of shape (T, B, nClass), time-major, where T is the width
    of the CNN feature map.
    """

    def __init__(self, imgHeight, nChannel, nClass, nHidden):
        super(CRNN, self).__init__()
        # Wrapped in nn.Sequential, so state_dict keys look like "cnn.0.*";
        # keep this layout so saved model.pth checkpoints still load.
        backbone = CNN(imgHeight, nChannel)
        self.cnn = nn.Sequential(backbone)
        encoder = BidirectionalLSTM(512, nHidden, nHidden)
        classifier = BidirectionalLSTM(nHidden, nHidden, nClass)
        self.lstm = nn.Sequential(encoder, classifier)

    def forward(self, input):
        feature_map = self.cnn(input)
        # PyTorch convolutions output (B, C, H, W); the CNN must reduce H to 1.
        _, _, height, _ = feature_map.size()
        assert height == 1, "the output height must be 1."
        # (B, C, 1, W) -> (B, C, W) -> (W, B, C), the LSTM's (seq, batch, feature) layout.
        sequence = feature_map.squeeze(dim=2).permute(2, 0, 1)
        return self.lstm(sequence)


if __name__ == "__main__":
    # Smoke test: push one random 1x32x100 grayscale image through an
    # 11-class model and show the (T, batch, classes) output shape.
    dummy_input = torch.rand(1, 1, 32, 100)
    crnn = CRNN(imgHeight=32, nChannel=1, nClass=11, nHidden=256)
    scores = crnn(dummy_input)
    print(scores.shape)



3.数据读取

读取训练数据的代码如下所示:

import os
import torch
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import imgaug.augmenters as iaa

class CRNNDataSet(Dataset):
    """Dataset of (grayscale text-line image, padded label vector) pairs.

    Each entry of ``lines`` has the form ``"<image_path> <idx> <idx> ..."``:
    an image path followed by the space-separated character-class indices of
    the text in that image (the format written by the data-generation script).
    """

    def __init__(self, lines, train=True, img_width=100):
        super(CRNNDataSet, self).__init__()
        self.lines = lines
        self.train = train
        self.img_width = img_width
        # Fixed length of the padded label vector. For the widths used here
        # it equals the CRNN output sequence length (100 -> 26, 200 -> 51).
        self.T = img_width // 4 + 1

    @staticmethod
    def _parse_line(line):
        """Split one annotation line into (image_path, list of int labels)."""
        fields = line.split()
        if not fields:
            raise ValueError('empty annotation line')
        return fields[0], [int(tok) for tok in fields[1:]]

    def __getitem__(self, index):
        image_path, label = self._parse_line(self.lines[index])

        image = cv2.imread(image_path, 0)  # flag 0: load as single-channel grayscale
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError('could not read image: ' + image_path)

        # Training images are randomly augmented; validation images are only resized.
        if self.train:
            image = self.get_random_data(image)
        else:
            image = cv2.resize(image, (self.img_width, 32))

        # Pad the label to length T with -1 so samples can be batched together.
        if len(label) > self.T:
            raise ValueError('label of length %d exceeds T=%d for %s'
                             % (len(label), self.T, image_path))
        label_max = np.full(self.T, -1, dtype=np.int32)
        label_max[:len(label)] = label

        # Scale to [0, 1] float32 and add the channel axis -> (1, 32, img_width).
        image = (image / 255.).astype('float32')
        image = np.expand_dims(image, axis=0)

        return torch.from_numpy(image), torch.from_numpy(label_max)

    def __len__(self):
        return len(self.lines)

    def get_random_data(self, img):
        """Randomly augment a training image and resize it to 32 x img_width."""
        seq = iaa.Sequential([
            iaa.Multiply((0.8, 1.3)),  # random brightness scaling
            iaa.GaussianBlur(sigma=(0, 1.0)),  # blur, sigma drawn from [0, 1]
            iaa.Crop(percent=(0, 0.05)),
            iaa.Affine(
                scale=(0.95, 1.05),  # random rescaling
                rotate=(-4, 4),  # small random rotation, in degrees
                cval=(100, 250),  # fill value for pixels outside the image
                mode=iaa.ia.ALL),  # random border-fill mode
            iaa.Resize({"height": 32, "width": self.img_width})
        ])
        img = seq.augment(image=img)
        return img

if __name__ == '__main__':
    # Quick check: iterate the training set once and print batch shapes/labels.
    batch_size = 16
    with open('train.txt', 'r') as f:
        lines = f.readlines()
    trainData = CRNNDataSet(lines=lines)
    trainLoader = DataLoader(dataset=trainData, batch_size=batch_size)
    for data, label in trainLoader:
        print(data.shape, label)



4.训练模型

训练代码如下所示:

from model import CRNN
from mydataset import CRNNDataSet
from torch.utils.data import DataLoader
import torch
from torch import optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

def decode(preds):
    char_set = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
                'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
                '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + [" "]
    preds=list(preds)


    pred_text = ''
    for i,j in enumerate(preds):
        if j==n_class-1:
            continue
        if i==0:
            pred_text+=char_set[j]
            continue
        if preds[i-1]!=j:
            pred_text += char_set[j]

    return pred_text
def getAcc(preds, labs):
    """Return the exact-match accuracy of greedy CTC decoding on one batch.

    Args:
        preds: raw network output of shape (T, B, n_class), time-major as
            returned by ``CRNN.forward``.
        labs: (B, T_lab) integer label tensor padded with -1, as produced by
            ``CRNNDataSet``.

    Returns:
        Fraction of the B samples whose decoded string equals the label
        string; 0.0 for an empty batch.
    """
    char_set = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
                'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
                '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + [" "]

    labs = labs.cpu().detach().numpy()
    # Best class per frame, then (T, B) -> (B, T) so each row is one sample.
    best_paths = preds.cpu().detach().numpy().argmax(axis=-1).T

    decoded = [decode(path) for path in best_paths]
    expected = ["".join(char_set[c] for c in lab[lab != -1]) for lab in labs]

    if not expected:
        return 0.0
    # Divide by the actual number of samples. The module-level batch_size is
    # wrong for the last, smaller batch unless the caller keeps it in sync.
    hits = sum(a == b for a, b in zip(decoded, expected))
    return hits / len(expected)

batch_size = 128
n_class = 37  # 36 characters (a-z, 0-9) plus the CTC blank as the last class


def _read_lines(path):
    """Return the non-blank lines of an annotation file, closing it afterwards."""
    with open(path, 'r') as f:
        return [line for line in f if line.strip()]


train_lines = _read_lines('train.txt')
val_lines = _read_lines('val.txt')
trainData = CRNNDataSet(lines=train_lines, train=True, img_width=200)
trainLoader = DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True, num_workers=1)
valData = CRNNDataSet(lines=val_lines, train=False, img_width=200)
valLoader = DataLoader(dataset=valData, batch_size=batch_size, shuffle=False, num_workers=1)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = CRNN(imgHeight=32, nChannel=1, nClass=n_class, nHidden=256)
net = net.to(device)


# `blank` is the index of the CTC blank class: the last one,
# index n_class - 1 = 36 (the 37th class).
loss_func = torch.nn.CTCLoss(blank=n_class - 1)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
# Learning-rate decay: multiply the LR by 0.99 on every scheduler step.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)


# Per-epoch history used for the loss/accuracy plots.
trainLoss = []
valLoss = []
trainAcc = []
valAcc = []
def _ctc_batch(features, label):
    """Forward one batch and return (CTC loss tensor, exact-match accuracy).

    ``label`` is the (B, T) int tensor from the dataset, padded with -1.
    """
    features = features.to(device)
    # CTCLoss accepts all targets concatenated into one 1-D tensor in batch
    # order. Boolean indexing of the row-major (B, T) tensor yields exactly
    # that, with the -1 padding dropped.
    targets = label[label != -1].to(device)
    target_lengths = (label != -1).sum(dim=-1)

    out = net(features)  # (T, B, n_class)
    log_probs = out.log_softmax(2)
    # Every sample uses the full output sequence length T.
    input_lengths = torch.IntTensor([out.size(0)] * int(out.size(1)))
    loss = loss_func(log_probs, targets, input_lengths, target_lengths)
    acc = getAcc(out, label)
    return loss, acc


if __name__ == '__main__':

    # Number of training epochs.
    Epoch = 100

    loss_func = loss_func.to(device)
    epoch_step = len(trainLoader)  # batches per epoch, for the progress bar
    n_train_batches = max(len(trainLoader), 1)
    n_val_batches = max(len(valLoader), 1)

    for epoch in range(1, Epoch + 1):

        net.train()
        train_total_loss = 0.0
        train_total_acc = 0.0

        with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', mininterval=0.3) as pbar:
            for step, (features, label) in enumerate(trainLoader, 1):
                # getAcc may read the module-level batch_size; keep it in sync
                # with the actual (possibly smaller, last) batch.
                batch_size = features.size(0)
                loss, acc = _ctc_batch(features, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accumulate plain floats: summing the loss tensors would keep
                # every step's autograd graph alive for the whole epoch.
                train_total_loss += loss.item()
                train_total_acc += acc
                pbar.set_postfix(**{'loss': train_total_loss / step,
                                    'acc': train_total_acc / step})
                pbar.update(1)
        trainLoss.append(train_total_loss / n_train_batches)
        trainAcc.append(train_total_acc / n_train_batches)

        # Save the latest weights after every epoch.
        torch.save(net.state_dict(), 'model.pth')

        # Validation.
        net.eval()
        val_total_loss = 0.0
        val_total_acc = 0.0
        with torch.no_grad():
            for features, label in valLoader:
                batch_size = features.size(0)
                loss, acc = _ctc_batch(features, label)
                val_total_loss += loss.item()
                val_total_acc += acc
        valLoss.append(val_total_loss / n_val_batches)
        valAcc.append(val_total_acc / n_val_batches)
        lr_scheduler.step()

    # Plot the loss and accuracy curves.
    plt.figure()
    plt.plot(trainLoss, 'r')
    plt.plot(valLoss, 'b')
    plt.title('Training and validation loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(["Loss", "Validation Loss"])
    plt.savefig('loss.png')

    plt.figure()
    plt.plot(trainAcc, 'r')
    plt.plot(valAcc, 'b')
    plt.title('Training and validation acc')
    plt.xlabel("Epochs")
    plt.ylabel("Acc")
    plt.legend(["Acc", "Validation Acc"])
    plt.savefig('acc.png')
    # plt.show()

acc和loss图如下所示:

在这里插入图片描述

在这里插入图片描述

经过验证发现准确率可达95%以上,效果不错。

整体项目下载地址:

项目下载



版权声明:本文为qq_45087786原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。