Skip to content
Snippets Groups Projects
train-ocr-model.py 5.45 KiB
Newer Older
Jiale Song's avatar
Jiale Song committed
import argparse
import csv
import os
import string

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

# Training hyper-parameters.
EPOCHS = 50      # number of epochs to train for
INIT_LR = 0.1    # initial learning rate
BS = 128         # batch size

def parse_argument():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns:
        dict with the keys "dataset", "model" and "plot". "dataset" and
        "model" are required; "plot" defaults to "plot.png".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--dataset",
        required=True,
        help="path to English Handwritten Characters dataset",
    )
    parser.add_argument(
        "-m", "--model",
        type=str,
        required=True,
        help="path to output trained handwriting recognition model",
    )
    parser.add_argument(
        "-p", "--plot",
        type=str,
        default="plot.png",
        help="path to output training history file",
    )
    return vars(parser.parse_args())


def load_eng_dataset(datasetPath, imageDir="eng_dataset"):
    """
    Helper function for training the OCR model. Load the English Handwritten
    Characters dataset from the CSV index at ``datasetPath``.

    The CSV is expected to start with an ``image,label`` header row followed
    by one row per sample: the image path (relative to ``imageDir``) and the
    sample's label.

    Args:
        datasetPath: path to the CSV index file.
        imageDir: directory the image paths in the CSV are relative to.
            Defaults to "eng_dataset", the location previously hard-coded.

    Returns:
        A 2-tuple ``(data, labels)`` of NumPy arrays: the images as returned
        by ``cv2.imread`` (default read mode) and the labels as 1-character
        unicode strings (``dtype="U1"``; longer labels would be truncated).
        Rows whose image cannot be read are reported and skipped.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # csv.reader handles quoting and "\r\n" line endings; `with` guarantees
    # the index file is closed even if loading fails part-way
    with open(datasetPath, newline="", encoding="utf-8") as csvFile:
        for row in csv.reader(csvFile):
            # skip blank lines and the header row
            if not row or row == ["image", "label"]:
                continue

            imageName = row[0]
            label = row[1].strip()
            imagePath = os.path.join(imageDir, imageName)

            # cv2.imread normally signals failure by returning None rather
            # than raising, so check both; a None must never reach `data`
            try:
                image = cv2.imread(imagePath)
            except cv2.error:
                image = None
            if image is None:
                print("[ERROR] loading image ", imageName, " fail")
                continue

            # update the list of data and labels
            data.append(image)
            labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")

    # return a 2-tuple of the English Handwritten Characters data and labels
    return (data, labels)


def process_dataset(data, labels, size=(32, 32)):
    """
    Pre-process the loaded dataset so it is ready for model training.

    Args:
        data: iterable of images (as loaded by ``cv2.imread``).
        labels: array-like of character class labels.
        size: ``(width, height)`` every image is resized to. Defaults to
            32x32, the input size the intended architecture is designed for.

    Returns:
        A 3-tuple ``(data, labels, classWeight)``:
        - ``data``: NumPy array of the resized images. Pixel values are left
          in their original range and no channel axis is added.
        - ``labels``: binarized (one-hot) labels from ``LabelBinarizer``.
        - ``classWeight``: dict mapping label-column index to
          (largest class count / that class's count), which compensates for
          class imbalance in the labeled data.
    """
    # resize every image to the network input size; note that cv2.resize
    # takes the target size as (width, height)
    data = np.array([cv2.resize(image, size) for image in data])

    # NOTE: pixel intensities are NOT scaled to [0, 1] and no channel axis is
    # added here. If the model needs that, enable something like:
    #   data = np.expand_dims(data, axis=-1)
    #   data = data.astype("float32") / 255.0

    # convert the character labels to binary (one-hot) vectors
    labels = LabelBinarizer().fit_transform(labels)

    # account for skew in the labeled data: rarer classes get larger weights
    classTotals = labels.sum(axis=0)
    largest = classTotals.max()
    classWeight = {i: largest / total for i, total in enumerate(classTotals)}

    return data, labels, classWeight


def show_train_data(train_images, train_labels, class_names=None):
    """
    Plot up to the first 25 training images in a 5x5 grid, each captioned
    with its class name, so you can verify that the dataset looks correct.

    Args:
        train_images: sequence of images to display.
        train_labels: one-hot label rows aligned with ``train_images``.
        class_names: display name for each label column. Defaults to the 62
            characters digits, then uppercase, then lowercase ASCII letters.
            That is the sorted order ``LabelBinarizer`` gives these
            characters.
    """
    if class_names is None:
        # built from `string` constants to avoid the typos a hand-written
        # list invites (it previously listed 'G'/'g' instead of 'J'/'j')
        class_names = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)

    plt.figure(figsize=(10, 10))
    # guard against training sets smaller than the 5x5 grid
    for i in range(min(25, len(train_images))):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i])
        # each label is a one-hot row; the position of the 1 is the class index
        index = np.where(train_labels[i] == 1)[0][0]
        plt.xlabel(class_names[index])
    plt.show()


if __name__ == "__main__":
    # Read the command-line options.
    args = parse_argument()

    # Step 1: read the CSV index and the character images it references.
    print("[INFO] loading datasets...")
    data, labels = load_eng_dataset(args["dataset"])

    # Step 2: resize the images, binarize the labels and compute per-class
    # weights that account for skew in the labeled data.
    print("[INFO] pre-processing datasets...")
    data, labels, classWeight = process_dataset(data, labels)

    # Step 3: 80% / 20% train/test split. Stratifying on the labels preserves
    # class proportions in both parts; the fixed seed makes it reproducible.
    train_images, test_images, train_labels, test_labels = train_test_split(
        data,
        labels,
        test_size=0.20,
        stratify=labels,
        random_state=42,
    )

    # Step 4: visually check a sample of the training set.
    show_train_data(train_images, train_labels)

    # TODO: model construction and training are not implemented yet, so
    # `classWeight`, the test split, args["model"] and args["plot"] are
    # currently unused. The draft below references `le` (a local inside
    # process_dataset) and `ResNet`, neither of which is defined in this file.
    # NOTE(review): newer Keras releases renamed SGD's `lr` argument to
    # `learning_rate` and removed `decay` -- check against the installed TF.
    #
    # construct the image generator for data augmentation
    # aug = ImageDataGenerator(
    #     rotation_range=10,
    #     zoom_range=0.05,
    #     width_shift_range=0.1,
    #     height_shift_range=0.1,
    #     shear_range=0.15,
    #     horizontal_flip=False,
    #     fill_mode="nearest")
    #
    # initialize and compile our deep neural network
    # print("[INFO] compiling model...")
    # opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
    # model = ResNet.build(32, 32, 1, len(le.classes_), (3, 3, 3),
    #     (64, 64, 128, 256), reg=0.0005)
    # model.compile(loss="categorical_crossentropy", optimizer=opt,
    #     metrics=["accuracy"])