import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import argparse
import cv2
import time
import keras
import sys

from keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
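
# Example usage (the paths are illustrative; point --dataset at your local copy
# of english.csv or dataset_final.csv):
#
#   python train-ocr-model.py --dataset eng_dataset/english.csv \
#       --model model.keras --plot plot.png --show --timer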

def parse_argument():
    """
    Construct the argument parser and parse the arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--dataset", required=True,
                    help="path to English Handwritten Characters dataset")
    ap.add_argument("-m", "--model", type=str, default="model.keras",
                    help="path to output trained handwriting recognition model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output training history file")
    ap.add_argument('-s', '--show', action='store_true',
                    help='show sample training images and the model summary')
    ap.add_argument('-t', '--timer', action='store_true',
                    help='print the elapsed training time')
    args = vars(ap.parse_args())

    return args


def load_eng_dataset(datasetPath):
    """
    Helper function for training the OCR model. Loads the English Handwritten
    Characters dataset from the CSV file at the given path.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the English Handwritten Characters dataset
    for row in open(datasetPath):
        # Skip the first row
        if row == "image,label\n":
            continue

        # parse the label and image from the row
        row = row.split(",")
        imagePath = "eng_dataset/" + row[0] # hardcode the path
        try:
            image = cv2.imread(imagePath)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # imread returns BGR
        except cv2.error:
            print("[ERROR] failed to load image", row[0])
            continue
        
        label = row[1].rstrip("\n")  # strip the trailing newline

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the English Handwritten Characters data and labels
    return (data, labels)


def load_MNIST_dataset(datasetPath):
    """
    Helper function for training the OCR model. Loads the MNIST handwritten
    dataset from the CSV file at the given path.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the MNIST handwritten dataset
    for row in open(datasetPath):

        # parse the label and image from the row
        row = row.split(",")
        label = row[-1].rstrip("\n")  # strip the trailing newline
        image = np.array([int(x) for x in row[:-1]], dtype="uint8")
        
        # each image in this dataset is 28x28 pixels
        image = image.reshape((28, 28))

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the MNIST handwritten data and labels
    return (data, labels)

def process_dataset(data, labels):
    """
    Helper function to pre-process the dataset so it is ready for training.
    """
    # the architecture we're using is designed for 32x32 images,
    # so we need to resize them to 32x32
    data = [cv2.resize(image, (32, 32)) for image in data]
    data = np.array(data, dtype="float32")

    # a channel dimension is added just before training (via reshape), so the
    # images stay 2-D here
    # data = np.expand_dims(data, axis=-1)

    # scale the pixel intensities of the images from [0, 255] down to [0, 1]
    data /= 255.0

    # convert the labels from integers to vectors
    le = LabelBinarizer()
    labels = le.fit_transform(labels)
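    # e.g. for the 62-class English dataset, a label such as 'A' becomes a
    # one-hot vector with a single 1 at the index assigned to 'A'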

    # account for skew in the labeled data
    classTotals = labels.sum(axis=0)
    classWeight = {}
    # loop over all classes and calculate the class weight
    for i in range(0, len(classTotals)):
        classWeight[i] = classTotals.max() / classTotals[i]
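
    # worked example with illustrative numbers: if classTotals were
    # [100, 50, 25], classWeight would be {0: 1.0, 1: 2.0, 2: 4.0}, so
    # under-represented classes count more heavily in the loss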

    return data, labels, classWeight


def show_train_data(train_images, train_labels):
    """
    To verify that the dataset looks correct, plot the first 25 images from
    the training set and display the class name below each image.
    """
    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
                   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i])
        # the labels are one-hot vectors from LabelBinarizer, so recover the
        # class index from the position of the 1
        index = np.where(train_labels[i] == 1)[0][0]
        plt.xlabel(class_names[index])
    plt.show()


def show_train_result(history, plot_path):
    """
    Helper function to plot the training history and save the plot to disk.
    """
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.savefig(plot_path)
    plt.show()


def save_model(model, model_name):
    """
    Helper function to save the model with the proper file extension.
    """
    if model_name.endswith(".h5") or model_name.endswith(".keras"):
        model.save(model_name)
    else:
        model.save(model_name + ".keras")


if __name__ == "__main__":
    # load arguments
    args = parse_argument()

    # initialize the number of epochs, the batch size, and the fraction of
    # the data held out for testing (the latter two are overridden below for
    # the MNIST dataset)
    epochs = 40
    batch_size = 50
    test_size = 0.1

    # load whichever dataset the --dataset argument selects
    print("[INFO] loading datasets...")
    data, labels = [], []
    if "english.csv" in args["dataset"]:
        (data, labels) = load_eng_dataset(args["dataset"])
    elif "dataset_final.csv" in args["dataset"]:
        batch_size = 500
        test_size = 0.2
        (data, labels) = load_MNIST_dataset(args["dataset"])
    else:
        print("[ERROR] given dataset is not implemented, system quit.")
        sys.exit()

    # pre-process the data and labels for training
    print("[INFO] pre-processing datasets...")
    data, labels, classWeight = process_dataset(data, labels)

    # partition the data into training and testing splits (test_size was set
    # above: 0.1 for the English dataset, 0.2 for the MNIST dataset)
    (train_images, test_images, train_labels, test_labels) = train_test_split(data,
        labels, test_size=test_size, stratify=labels, random_state=42)
    
    # show train data in plot
    if args["show"]:
        show_train_data(train_images, train_labels)
    
    # initialize and compile our deep neural network
    print("[INFO] compiling model...")
    model = models.Sequential([keras.Input(shape=(32, 32, 1))])
    model.add(layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
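
    # feature-map sizes with 'valid' padding:
    # 32x32x1 -> 30x30x32 -> 15x15x32 -> 13x13x64 -> 6x6x64 -> 4x4x64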

    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation='relu'))
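    # 62 output classes: 10 digits + 26 uppercase + 26 lowercase letters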
    model.add(layers.Dense(62, activation='softmax'))

    # Use categorical_crossentropy for one-hot encoded labels
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=keras.losses.categorical_crossentropy,
                  metrics=['accuracy'])

    if args["show"]:
        model.summary()  # summary() prints itself and returns None

    # Keras needs a channel dimension for the model
    # Since the images are greyscale, the channel can be 1
    train_images = train_images.reshape(-1, 32, 32, 1)
    test_images = test_images.reshape(-1, 32, 32, 1)

    # train the network
    print("[INFO] training model...")
    start_time = time.time()
    history = model.fit(x=train_images, 
                        y=train_labels, 
                        validation_data=(test_images, test_labels), 
                        batch_size=batch_size,
                        epochs=epochs,
                        class_weight=classWeight)
    
    end_time = time.time()
    spend = end_time - start_time
    if args["timer"]:
        print("training time: %.3f seconds" % spend)

    # evaluate the network
    print("[INFO] evaluating network...")
    show_train_result(history, args["plot"])
    model.evaluate(test_images, test_labels, verbose=2)

    # save the model to disk
    print("[INFO] saving trained model...")
    save_model(model, args["model"])