import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import argparse
import cv2

from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

# initialize the number of epochs to train for and the batch size
EPOCHS = 50
BS = 64

def parse_argument():
    """
    Construct the argument parser and parse the arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--dataset", required=True,
                    help="path to English Handwritten Characters dataset")
    ap.add_argument("-m", "--model", type=str, default="model",
                    help="path to output trained handwriting recognition model")
    ap.add_argument("-s", "--show", action="store_true",
                    help="show a training-data preview and the model summary")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output training history plot")
    args = vars(ap.parse_args())

    return args
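
# A sketch of a typical invocation (the CSV filename "english.csv" is an
# assumption; the eng_dataset/ prefix matches the image path hardcoded in
# load_eng_dataset below):
#
#   python train-ocr-model.py --dataset eng_dataset/english.csv \
#       --model ocr_model --plot history.png --show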


def load_eng_dataset(datasetPath):
    """
    Load the English Handwritten Characters dataset from the CSV file at the
    given path and return the images and their character labels.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the English Handwritten Characters CSV
    for row in open(datasetPath):
        # skip the header row
        if row == "image,label\n":
            continue

        # parse the label and image from the row
        row = row.split(",")
        imagePath = "eng_dataset/" + row[0]  # dataset root is hardcoded
        # cv2.imread does not raise on a missing file; it returns None
        image = cv2.imread(imagePath)
        if image is None:
            print("[ERROR] failed to load image", row[0])
            continue

        label = row[1].strip()  # remove the trailing newline

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # the raw images vary in size, so leave the image data as a list here;
    # it is resized to a uniform shape and converted to an array in
    # process_dataset
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the English Handwritten Characters data and labels
    return (data, labels)
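
# For reference, the CSV this loader expects looks like the sketch below
# (the concrete rows are illustrative, not copied from the dataset):
#
#   image,label
#   Img/img001-001.png,0
#   Img/img011-042.png,A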


def process_dataset(data, labels):
    """
    Pre-process the dataset so it is ready for training: resize the images,
    scale the pixels, one-hot encode the labels, and compute class weights.
    """
    # the architecture we're using is designed for 32x32 images,
    # so we need to resize them to 32x32
    data = [cv2.resize(image, (32, 32)) for image in data]
    data = np.array(data, dtype="float32")

    # cv2.imread loads images in color (BGR) with 3 channels, so the data
    # already has a channel dimension and no expand_dims step is needed

    # scale the pixel intensities of the images from [0, 255] down to [0, 1]
    data /= 255.0

    # one-hot encode the character labels
    le = LabelBinarizer()
    labels = le.fit_transform(labels)

    # account for skew in the labeled data
    classTotals = labels.sum(axis=0)
    classWeight = {}
    # loop over all classes and calculate the class weight
    for i in range(len(classTotals)):
        classWeight[i] = classTotals.max() / classTotals[i]

    return data, labels, classWeight
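
# To make the class weighting concrete: if one class had 55 samples and the
# largest class had 62, its weight would be 62 / 55 ≈ 1.13, and Keras would
# scale that class's per-sample loss up by that factor (the counts here are
# made-up illustrations, not dataset statistics):
#
#   classWeight = {0: 1.0, 1: 1.13, ...}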


def show_train_data(train_images, train_labels):
    """
    To verify that the dataset looks correct, plot the first 25 images from
    the training set and display the class name below each image.
    """
    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
                   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i])
        # the labels are one-hot vectors, so recover the class index
        # from the position of the 1
        index = np.where(train_labels[i] == 1)[0][0]
        plt.xlabel(class_names[index])
    plt.show()


if __name__ == "__main__":
    # load arguments
    args = parse_argument()

    # load the English Handwritten Characters datasets
    print("[INFO] loading datasets...")
    (data, labels) = load_eng_dataset(args["dataset"])

    # pre-process the data and labels for training
    print("[INFO] pre-processing datasets...")
    data, labels, classWeight = process_dataset(data, labels)

    # partition the data into training and testing splits using 80% of
    # the data for training and the remaining 20% for testing
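    # (stratify=labels keeps every class represented in the same proportion
    # in both splits)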
    (train_images, test_images, train_labels, test_labels) = train_test_split(data,
        labels, test_size=0.20, stratify=labels, random_state=42)
    
    # show train data in plot
    if args["show"]:
        show_train_data(train_images, train_labels)
    
    # initialize and compile our deep neural network
    print("[INFO] compiling model...")
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))

    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(62, activation='softmax'))
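    # shape check for the stack above (a valid 3x3 convolution shrinks each
    # spatial dimension by 2; 2x2 max pooling halves it, rounding down):
    #   input 32x32x3 -> conv 30x30x32 -> pool 15x15x32 -> conv 13x13x64
    #   -> pool 6x6x64 -> conv 4x4x64 -> flatten 1024 -> dense 64 -> dense 62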

    # use categorical_crossentropy for one-hot encoded labels
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    if args["show"]:
        model.summary()  # summary() prints itself; print() would just show None

    # train the network
    print("[INFO] training model...")
    # with array inputs, batch_size alone determines the number of steps per
    # epoch, so no explicit steps_per_epoch is needed
    history = model.fit(x=train_images,
                        y=train_labels,
                        validation_data=(test_images, test_labels),
                        batch_size=BS,
                        epochs=EPOCHS,
                        class_weight=classWeight)
    
    # plot the training history
    print("[INFO] plotting training history...")
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.savefig(args["plot"])
    plt.show()

    # evaluate the network on the held-out test split
    print("[INFO] evaluating network...")
    test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    print("accuracy:", test_acc)

    # save the model to disk
    print("[INFO] saving trained model...")
    model.save(args["model"] + ".h5")
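
    # The saved file can later be reloaded for inference (a sketch, not
    # exercised by this script):
    #
    #   from tensorflow.keras.models import load_model
    #   model = load_model(args["model"] + ".h5")
    #   predictions = model.predict(test_images)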