import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import argparse
import cv2
import time
import keras
import os
import sys

from keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
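
# This script trains a handwriting-recognition CNN on either the English
# Handwritten Characters dataset or the combined MNIST/EMNIST CSV files.
# Example invocation (script name illustrative):
#   python train_ocr.py --dataset dataset/english.csv --model model.keras --plot plot.png --show --timer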

def parse_argument():
    """
    Construct the argument parser and parse the arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--dataset", required=True,
                    help="path to English Handwritten Characters dataset")
    ap.add_argument("-m", "--model", type=str, default="model.keras",
                    help="path to output trained handwriting recognition model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output training history file")
    ap.add_argument('-s', '--show', action='store_true',
                    help='show all details')
    ap.add_argument('-t', '--timer', action='store_true',
                    help='set timer for training data')
    args = vars(ap.parse_args())

    return args


def load_eng_dataset(datasetPath):
    """
    Helper function for train model OCR. Function will load English Handwritten
    Characters dataset that should be in given path.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the English Handwritten Characters dataset
    for row in open(datasetPath):
        # Skip the first row
        if row == "image,label\n":
            continue

        # parse the label and image from the row
        row = row.split(",")
        # image paths in the CSV are relative to the dataset directory
        imagePath = os.path.join(os.path.dirname(datasetPath), row[0])
        image = cv2.imread(imagePath)
        # cv2.imread returns None (rather than raising) on failure
        if image is None:
            print("[ERROR] failed to load image", row[0])
            continue
        # OpenCV loads images in BGR order, so convert from BGR to grayscale
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        
        label = row[1].strip()  # drop the trailing newline

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the English Handwritten Characters data and labels
    return (data, labels)


def load_MNIST_dataset():
    """
    Helper function for train model OCR. Function will load MNIST handwritten
    dataset that should be in given path.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the MNIST handwritten dataset
    for row in open("dataset/dataset_mnist.csv"):

        # parse the label and image from the row
        row = row.split(",")
        label = row[-1].strip()  # drop the trailing newline
        image = np.array([int(x) for x in row[0: -1]], dtype="uint8")
        
        # the image's size is 28*28
        image = image.reshape((28,28))

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # loop over the rows of the EMNIST handwritten dataset
    for row in open("dataset/dataset_emnist.csv"):

        # parse the label and image from the row
        row = row.split(",")
        label = row[-1].strip()  # drop the trailing newline
        image = np.array([int(x) for x in row[0: -1]], dtype="uint8")
        
        # the image's size is 28*28
        image = image.reshape((28,28))
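        # EMNIST images are stored transposed relative to MNIST; rotating
        # 90 degrees clockwise and flipping horizontally restores the
        # upright orientation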
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        image = cv2.flip(image, 1)

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the MNIST/EMNIST data and labels
    return (data, labels)


def load_dataset(datasetPath):
    """
    Helper function for train model OCR. Function will load MNIST handwritten
    dataset that should be in given path.
    """
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the dataset
    for row in open(datasetPath):

        # parse the label and image from the row
        row = row.split(",")
        label = row[-1].strip()  # drop the trailing newline
        image = np.array([int(x) for x in row[0: -1]], dtype="uint8")
        
        # the image's size is 28*28
        image = image.reshape((28,28))

        # update the list of data and labels
        data.append(image)
        labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data)
    labels = np.array(labels, dtype="U1")
    # return a 2-tuple of the data and labels
    return (data, labels)


def process_dataset(data, labels, img_height, img_width):
    """
    Helper to pre-process the dataset so it is ready for model training.
    """
    # resize the images (default is 32x32); note cv2.resize expects
    # (width, height) order
    data = [cv2.resize(image, (img_width, img_height)) for image in data]
    data = np.array(data, dtype="float32")

    # scale the pixel intensities of the images from [0, 255] down to [0, 1]
    data /= 255.0

    # convert the labels from integers to vectors
    le = LabelBinarizer()
    labels = le.fit_transform(labels)
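    # e.g. labels ['0', '1', 'A'] are one-hot encoded as
    # [[1, 0, 0], [0, 1, 0], [0, 0, 1]]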

    # account for skew in the labeled data
    classTotals = labels.sum(axis=0)
    classWeight = {}
    # loop over all classes and calculate the class weight
    for i in range(0, len(classTotals)):
        classWeight[i] = classTotals.max() / classTotals[i]
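    # e.g. with class totals [100, 50] the weights become {0: 1.0, 1: 2.0},
    # giving under-represented classes more influence on the loss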

    return data, labels, classWeight


def show_train_data(train_images, train_labels):
    """
    To verify that the dataset looks correct, plot the first 25 images from
    the training set and display the class name below each image.
    """
    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
                   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap='gray')
        # the labels are one-hot vectors, so recover the class index
        index = np.where(train_labels[i] == 1)[0][0]
        plt.xlabel(class_names[index])
    plt.show()


def show_train_result(history, plot_path):
    """
    Help function to show training result and save result plot.
    """
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.savefig(plot_path)
    plt.show()


def save_model(model, model_name):
    """
    Help function to save the model with proper postfix.
    """
    if model_name.endswith(".h5") or model_name.endswith(".keras"):
        model.save(model_name)
    else:
        model.save(model_name + ".keras")


if __name__ == "__main__":
    # load arguments
    args = parse_argument()

    # initialize the number of epochs, batch size, test size, classes size
    # and image size
    epochs = 40
    batch_size = 50
    test_size = 0.1
    num_classes = 62
    img_height, img_width = 32, 32

    # load the English Handwritten Characters datasets
    print("[INFO] loading datasets...")
    data, labels = [], []
    if "english.csv" in args["dataset"]:
        (data, labels) = load_eng_dataset(args["dataset"])
    elif "dataset_mnist.csv" in args["dataset"] or "dataset_emnist.csv" in args["dataset"]:
        epochs = 20
        batch_size = 500
        test_size = 0.2
        num_classes = 36
        img_height, img_width = 28, 28
        (data, labels) = load_MNIST_dataset()
    else:
        print("[ERROR] given dataset is not implemented, system quit.")
        sys.exit()

    # pre-process the data and labels for training
    print("[INFO] pre-processing datasets...")
    data, labels, classWeight = process_dataset(data, labels, img_height, img_width)

    # partition the data into training and testing splits, holding out
    # a test_size fraction (10% or 20%) of the data for testing
    (train_images, test_images, train_labels, test_labels) = train_test_split(data,
        labels, test_size=test_size, stratify=labels, random_state=42)
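    # stratify=labels keeps the class proportions consistent between the
    # training and testing splits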
    
    # show train data in plot
    if args["show"]:
        show_train_data(train_images, train_labels)
    
    # initialize and compile our deep neural network
    print("[INFO] compiling model...")
    model = models.Sequential([keras.Input(shape=(img_height, img_width, 1))])
    model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
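    # after four 2x2 max-poolings a 32x32 input is reduced to 2x2x128
    # feature maps (a 28x28 input to 1x1x128) before the dense head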

    model.add(layers.Dropout(0.2))
    model.add(layers.Flatten())
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dense(num_classes, activation='softmax'))

    # Use categorical_crossentropy for one-hot coding labels
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=keras.losses.categorical_crossentropy,
                  metrics=['accuracy'])

    if args["show"]:
        model.summary()  # summary() prints directly and returns None

    # Keras needs a channel dimension for the model
    # Since the images are greyscale, the channel can be 1
    train_images = train_images.reshape(-1, img_height, img_width, 1)
    test_images = test_images.reshape(-1, img_height, img_width, 1)

    # train the network
    print("[INFO] training model...")
    start_time = time.time()
    history = model.fit(x=train_images, 
                        y=train_labels, 
                        validation_data=(test_images, test_labels), 
                        batch_size=batch_size,
                        epochs=epochs,
                        class_weight=classWeight)
    
    end_time = time.time()
    elapsed = end_time - start_time
    if args["timer"]:
        print("[INFO] training time: %.3f seconds" % elapsed)

    # evaluate the network
    print("[INFO] evaluating network...")
    show_train_result(history, args["plot"])
    model.evaluate(test_images, test_labels, verbose=2)

    # save the model to disk
    print("[INFO] saving trained model...")
    save_model(model, args["model"])