Skip to content
Snippets Groups Projects
ocr_project.ipynb 126 KiB
Newer Older

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-Project OCR Code\n",
    "\n",
    "This is a walkthrough of the process we went through to develop an OCR model that recognizes handwriting.\n",
    "\n",
    "The libraries we use are listed in the README; we install them all at once from `requirements.txt`.\n",
    "\n",
    "**IMPORTANT** Make sure to use `python3.6` (the version we use for class projects) so that our program runs with minimal issues. Installation instructions can be found [here](https://courses.cs.vt.edu/cs4804/Spring24/projects/project0.html#python-installation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide_output"
    ]
   },
   "outputs": [],
   "source": [
    "# Install libraries\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load libraries\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "\n",
    "from keras import layers, models\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper functions for loading dataset\n",
    "\n",
    "We need to load our dataset, so we create a helper function for the OCR model. This function loads the English Handwritten Characters dataset from the CSV index file at the given path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_eng_dataset(datasetPath, imageDir=\"eng_dataset/\"):\n",
    "    \"\"\"\n",
    "    Load the English Handwritten Characters dataset listed in a CSV index.\n",
    "\n",
    "    Every row after the `image,label` header holds an image path and its\n",
    "    label. Each image is read from disk and converted to grayscale.\n",
    "\n",
    "    Args:\n",
    "        datasetPath: path of the CSV index file.\n",
    "        imageDir: prefix put in front of every image path from the CSV.\n",
    "\n",
    "    Returns:\n",
    "        A 2-tuple (data, labels) of NumPy arrays. Rows that are malformed\n",
    "        or whose image cannot be read are reported and skipped.\n",
    "    \"\"\"\n",
    "    # initialize the list of data and labels\n",
    "    data = []\n",
    "    labels = []\n",
    "\n",
    "    # `with` closes the CSV file even if loading fails part-way\n",
    "    with open(datasetPath) as datasetFile:\n",
    "        # loop over the rows of the handwritten character index\n",
    "        for row in datasetFile:\n",
    "            row = row.strip()\n",
    "\n",
    "            # skip blank lines and the header row\n",
    "            if not row or row == \"image,label\":\n",
    "                continue\n",
    "\n",
    "            # parse the image path and label from the row\n",
    "            fields = row.split(\",\")\n",
    "            if len(fields) < 2:\n",
    "                print(\"[ERROR] malformed row \", row, \" skipped\")\n",
    "                continue\n",
    "            imagePath = imageDir + fields[0]\n",
    "            label = fields[1].strip()\n",
    "\n",
    "            # cv2.imread returns None (it does not raise) when the file\n",
    "            # is missing or cannot be decoded\n",
    "            image = cv2.imread(imagePath)\n",
    "            if image is None:\n",
    "                print(\"[ERROR] loading image \", fields[0], \" fail\")\n",
    "                continue\n",
    "\n",
    "            # cv2.imread yields BGR channel order, so convert from BGR\n",
    "            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
    "\n",
    "            # update the list of data and labels\n",
    "            data.append(image)\n",
    "            labels.append(label)\n",
    "\n",
    "    # convert the data and labels to NumPy arrays\n",
    "    data = np.array(data)\n",
    "    labels = np.array(labels, dtype=\"U1\")\n",
    "\n",
    "    # return a 2-tuple of the English Handwritten Characters data and labels\n",
    "    return (data, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Pre-Processing\n",
    "\n",
    "Next we will pre-process the dataset in order to train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_dataset(data, labels):\n",
    "    \"\"\"\n",
    "    Prepare raw images and their labels for model training.\n",
    "\n",
    "    Images are resized to 32x32, stored as float32 and divided by 255;\n",
    "    labels are binarized with LabelBinarizer; a weight is computed for\n",
    "    every label column to offset class imbalance.\n",
    "\n",
    "    Returns:\n",
    "        (data, labels, classWeight): the image array, the binarized label\n",
    "        matrix, and a dict mapping column index -> weight, where a weight\n",
    "        is the largest column total divided by that column's total.\n",
    "    \"\"\"\n",
    "    # the network is designed for 32x32 inputs, so bring every image\n",
    "    # to that size before stacking them into one float32 array\n",
    "    resized = [cv2.resize(image, (32, 32)) for image in data]\n",
    "    data = np.array(resized, dtype=\"float32\")\n",
    "\n",
    "    # map pixel intensities from [0, 255] down to [0, 1]\n",
    "    data /= 255.0\n",
    "\n",
    "    # turn each label into a one-hot vector\n",
    "    binarizer = LabelBinarizer()\n",
    "    labels = binarizer.fit_transform(labels)\n",
    "\n",
    "    # per-class sample counts expose the skew in the labelled data;\n",
    "    # rarer classes receive proportionally larger weights\n",
    "    classTotals = labels.sum(axis=0)\n",
    "    largest = classTotals.max()\n",
    "    classWeight = {\n",
    "        idx: largest / total for idx, total in enumerate(classTotals)\n",
    "    }\n",
    "\n",
    "    return data, labels, classWeight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verification\n",
    "\n",
    "To verify that the dataset looks correct, let's plot the first 25 images from the training set and display the class name below each image.\n",
    "\n",
    "We define the helper function here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_train_data(train_images, train_labels):\n",
    "    \"\"\"\n",
    "    Plot up to the first 25 training images in a 5x5 grid with class names.\n",
    "\n",
    "    Args:\n",
    "        train_images: indexable collection of images accepted by `plt.imshow`.\n",
    "        train_labels: one-hot label rows aligned with `train_images`.\n",
    "    \"\"\"\n",
    "    # class order: digits, then upper-case, then lower-case letters\n",
    "    # (assumes the one-hot label columns follow this same order)\n",
    "    class_names = list(\"0123456789\"\n",
    "                       \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n",
    "                       \"abcdefghijklmnopqrstuvwxyz\")\n",
    "\n",
    "    plt.figure(figsize=(10,10))\n",
    "    # never index past the end when fewer than 25 images are supplied\n",
    "    for i in range(min(25, len(train_images))):\n",
    "        plt.subplot(5,5,i+1)\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "        plt.grid(False)\n",
    "        plt.imshow(train_images[i])\n",
    "        # each label is a one-hot row, so the position of the 1\n",
    "        # is the class index\n",
    "        index = np.where(train_labels[i] == 1)[0][0]\n",
    "        plt.xlabel(class_names[index])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and Pre-Process\n",
    "\n",
    "First, we need to load and pre-process the data. We will use the functions we defined previously.\n",
    "\n",
    "Make sure the English Handwritten Characters dataset is located in an `eng_dataset/` folder next to this notebook, in the following format:\n",
    "\n",
    "```\n",
    ".\n",
    "├── ocr_project.ipynb\n",
    "└── eng_dataset\n",
    "    ├── english.csv\n",
    "    └── Img\n",
    "        ├── imgXXX-XXX.png\n",
    "        └── ...\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define directories here\n",
    "datasetPath = \"eng_dataset/english.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] loading datasets...\n",
      "[INFO] pre-processing datasets...\n"
     ]
    }
   ],
   "source": [
    "# load the English Handwritten Characters datasets\n",
    "print(\"[INFO] loading datasets...\")\n",
    "(data, labels) = load_eng_dataset(datasetPath)\n",
    "\n",
    "# pre-process the data and labels for training\n",
    "print(\"[INFO] pre-processing datasets...\")\n",
    "data, labels, classWeight = process_dataset(data, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Time to begin the training, we need to split the data for training and testing first. The training data will be shown here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 720x720 with 25 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'Training image Shape: (2898, 32, 32)'"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'Testing image Shape: (512, 32, 32)'"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# partition the data into training and testing splits using 85% of\n",
    "# the data for training and the remaining 15% for testing\n",
    "(train_images, test_images, train_labels, test_labels) = train_test_split(data,\n",
    "        labels, test_size=0.15, stratify=labels, random_state=42)\n",
    "    \n",
    "# show train data in plot\n",
    "show_train_data(train_images, train_labels)\n",
    "\n",
    "# Show shapes\n",
    "display(f\"Training image Shape: {train_images.shape}\")\n",
    "display(f\"Testing image Shape: {test_images.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks great!\n",
    "\n",
    "### Modeling\n",
    "\n",
    "We need to define some parameters for our model. Here is what they mean:\n",
    "\n",
    "- `EPOCHS` - number of passes over the training data used to fit the model\n",
    "- `BATCH_SIZE` - number of samples the model processes per training step\n",
    "\n",
    "Info about model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 80\n",
    "BATCH_SIZE = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] compiling model...\n",
      "[INFO] building model...\n"
     ]
    }
   ],
   "source": [
    "# define the layers of our CNN (model.compile() is called separately)\n",
    "print(\"[INFO] compiling model...\")\n",
    "model = models.Sequential(name = \"OCR_CNN\")\n",
    "# two convolution + max-pooling stages extract image features\n",
    "model.add(layers.Conv2D(32, (3, 3), activation='relu', use_bias=False))\n",
    "model.add(layers.MaxPooling2D((2, 2)))\n",
    "model.add(layers.Conv2D(64, (3, 3), activation='relu', use_bias=False))\n",
    "model.add(layers.MaxPooling2D((2, 2)))\n",
    "# dense classifier head with dropout between the two dense layers\n",
    "model.add(layers.Flatten())\n",
    "model.add(layers.Dense(128, activation='relu'))\n",
    "model.add(layers.Dropout(0.5))\n",
    "# 62 outputs: 10 digits + 26 upper-case + 26 lower-case characters\n",
    "model.add(layers.Dense(62, activation='softmax'))\n",
    "# build model\n",
    "print(\"[INFO] building model...\")\n",
    "# leave the batch dimension as None so the model is not tied to\n",
    "# BATCH_SIZE and accepts batches of any size (e.g. a single image)\n",
    "model.build(input_shape=(None, 32, 32, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"OCR_CNN\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_51 (Conv2D)           (50, 30, 30, 32)          288       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_34 (MaxPooling (50, 15, 15, 32)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_52 (Conv2D)           (50, 13, 13, 64)          18432     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_35 (MaxPooling (50, 6, 6, 64)            0         \n",
      "_________________________________________________________________\n",
      "flatten_17 (Flatten)         (50, 2304)                0         \n",
      "_________________________________________________________________\n",
      "dense_32 (Dense)             (50, 128)                 295040    \n",
      "_________________________________________________________________\n",
      "dropout_8 (Dropout)          (50, 128)                 0         \n",
      "_________________________________________________________________\n",
      "dense_33 (Dense)             (50, 62)                  7998      \n",
      "=================================================================\n",
      "Total params: 321,758\n",
      "Trainable params: 321,758\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "None"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Use categorical_crossentropy because the labels are one-hot encoded\n",
    "model.compile(optimizer='adam',\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so wrapping it\n",
    "# in display() would only add a stray \"None\" to the output\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are using a CNN, we need to add a channel dimension to our training and test data. This way the data is shaped as `(batch_size, height, width, channels)`, with a single channel because the images are greyscale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(512, 32, 32, 1)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Keras needs a channel dimension for the model\n",
    "# Since the images are greyscale, the channel can be 1\n",
    "train_images = train_images.reshape(-1, 32, 32, 1)\n",
    "test_images = test_images.reshape(-1, 32, 32, 1)\n",
    "display(test_images.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] training model...\n",
      "Epoch 1/80\n",
      "57/57 [==============================] - 3s 40ms/step - loss: 4.1263 - accuracy: 0.0200 - val_loss: 4.0838 - val_accuracy: 0.0762\n",
      "Epoch 2/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 3.9482 - accuracy: 0.0579 - val_loss: 3.6387 - val_accuracy: 0.1348\n",
      "Epoch 3/80\n",
      "57/57 [==============================] - 2s 40ms/step - loss: 3.4824 - accuracy: 0.1268 - val_loss: 3.0436 - val_accuracy: 0.2773\n",
      "Epoch 4/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 2.9919 - accuracy: 0.2152 - val_loss: 2.5431 - val_accuracy: 0.3848\n",
      "Epoch 5/80\n",
      "57/57 [==============================] - 2s 41ms/step - loss: 2.6030 - accuracy: 0.2886 - val_loss: 2.1482 - val_accuracy: 0.5020\n",
      "Epoch 6/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 2.2868 - accuracy: 0.3490 - val_loss: 1.8642 - val_accuracy: 0.5605\n",
      "Epoch 7/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 2.0885 - accuracy: 0.4115 - val_loss: 1.6513 - val_accuracy: 0.5957\n",
      "Epoch 8/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.8573 - accuracy: 0.4670 - val_loss: 1.5443 - val_accuracy: 0.6035\n",
      "Epoch 9/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 1.6904 - accuracy: 0.5032 - val_loss: 1.3878 - val_accuracy: 0.6465\n",
      "Epoch 10/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.5709 - accuracy: 0.5421 - val_loss: 1.2869 - val_accuracy: 0.6543\n",
      "Epoch 11/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.4515 - accuracy: 0.5702 - val_loss: 1.2251 - val_accuracy: 0.6777\n",
      "Epoch 12/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.3549 - accuracy: 0.5797 - val_loss: 1.1648 - val_accuracy: 0.6738\n",
      "Epoch 13/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.2908 - accuracy: 0.6015 - val_loss: 1.1176 - val_accuracy: 0.7070\n",
      "Epoch 14/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 1.2471 - accuracy: 0.6173 - val_loss: 1.0772 - val_accuracy: 0.7188\n",
      "Epoch 15/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 1.1482 - accuracy: 0.6513 - val_loss: 1.0406 - val_accuracy: 0.7129\n",
      "Epoch 16/80\n",
      "57/57 [==============================] - 2s 30ms/step - loss: 1.0662 - accuracy: 0.6534 - val_loss: 1.0068 - val_accuracy: 0.7344\n",
      "Epoch 17/80\n",
      "57/57 [==============================] - 2s 31ms/step - loss: 1.0530 - accuracy: 0.6633 - val_loss: 0.9945 - val_accuracy: 0.7344\n",
      "Epoch 18/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 1.0029 - accuracy: 0.6675 - val_loss: 0.9942 - val_accuracy: 0.7246\n",
      "Epoch 19/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.9219 - accuracy: 0.7103 - val_loss: 0.9844 - val_accuracy: 0.7246\n",
      "Epoch 20/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.9073 - accuracy: 0.7100 - val_loss: 0.9860 - val_accuracy: 0.7266\n",
      "Epoch 21/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.8509 - accuracy: 0.7310 - val_loss: 0.9577 - val_accuracy: 0.7461\n",
      "Epoch 22/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.8417 - accuracy: 0.7265 - val_loss: 0.9592 - val_accuracy: 0.7480\n",
      "Epoch 23/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.8113 - accuracy: 0.7317 - val_loss: 0.9147 - val_accuracy: 0.7402\n",
      "Epoch 24/80\n",
      "57/57 [==============================] - 2s 40ms/step - loss: 0.7345 - accuracy: 0.7475 - val_loss: 0.9604 - val_accuracy: 0.7422\n",
      "Epoch 25/80\n",
      "57/57 [==============================] - 2s 40ms/step - loss: 0.7125 - accuracy: 0.7683 - val_loss: 0.9708 - val_accuracy: 0.7344\n",
      "Epoch 26/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.6743 - accuracy: 0.7658 - val_loss: 0.9696 - val_accuracy: 0.7344\n",
      "Epoch 27/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.6806 - accuracy: 0.7707 - val_loss: 0.9611 - val_accuracy: 0.7520\n",
      "Epoch 28/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.6684 - accuracy: 0.7704 - val_loss: 0.9740 - val_accuracy: 0.7480\n",
      "Epoch 29/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.6236 - accuracy: 0.7893 - val_loss: 0.9719 - val_accuracy: 0.7441\n",
      "Epoch 30/80\n",
      "57/57 [==============================] - 2s 39ms/step - loss: 0.6148 - accuracy: 0.7925 - val_loss: 0.9768 - val_accuracy: 0.7441\n",
      "Epoch 31/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.5969 - accuracy: 0.7971 - val_loss: 0.9384 - val_accuracy: 0.7461\n",
      "Epoch 32/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.5668 - accuracy: 0.8104 - val_loss: 0.9279 - val_accuracy: 0.7637\n",
      "Epoch 33/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.5786 - accuracy: 0.7978 - val_loss: 0.9732 - val_accuracy: 0.7500\n",
      "Epoch 34/80\n",
      "57/57 [==============================] - 2s 41ms/step - loss: 0.5362 - accuracy: 0.8139 - val_loss: 0.9789 - val_accuracy: 0.7617\n",
      "Epoch 35/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.5333 - accuracy: 0.8188 - val_loss: 0.9604 - val_accuracy: 0.7578\n",
      "Epoch 36/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.5207 - accuracy: 0.8220 - val_loss: 0.9480 - val_accuracy: 0.7500\n",
      "Epoch 37/80\n",
      "57/57 [==============================] - 2s 40ms/step - loss: 0.4855 - accuracy: 0.8346 - val_loss: 1.0021 - val_accuracy: 0.7500\n",
      "Epoch 38/80\n",
      "57/57 [==============================] - 2s 39ms/step - loss: 0.4904 - accuracy: 0.8241 - val_loss: 0.9739 - val_accuracy: 0.7559\n",
      "Epoch 39/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.4762 - accuracy: 0.8329 - val_loss: 1.0013 - val_accuracy: 0.7422\n",
      "Epoch 40/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.4636 - accuracy: 0.8332 - val_loss: 1.0122 - val_accuracy: 0.7520\n",
      "Epoch 41/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.4644 - accuracy: 0.8413 - val_loss: 0.9958 - val_accuracy: 0.7676\n",
      "Epoch 42/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.4382 - accuracy: 0.8430 - val_loss: 1.0455 - val_accuracy: 0.7461\n",
      "Epoch 43/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.4152 - accuracy: 0.8539 - val_loss: 1.0369 - val_accuracy: 0.7598\n",
      "Epoch 44/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.4241 - accuracy: 0.8430 - val_loss: 1.1106 - val_accuracy: 0.7402\n",
      "Epoch 45/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.4116 - accuracy: 0.8508 - val_loss: 1.0125 - val_accuracy: 0.7461\n",
      "Epoch 46/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.3994 - accuracy: 0.8596 - val_loss: 1.0470 - val_accuracy: 0.7695\n",
      "Epoch 47/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.4325 - accuracy: 0.8452 - val_loss: 1.0454 - val_accuracy: 0.7461\n",
      "Epoch 48/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.3701 - accuracy: 0.8718 - val_loss: 1.0575 - val_accuracy: 0.7578\n",
      "Epoch 49/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.3781 - accuracy: 0.8666 - val_loss: 1.0279 - val_accuracy: 0.7734\n",
      "Epoch 50/80\n",
      "57/57 [==============================] - 2s 34ms/step - loss: 0.3822 - accuracy: 0.8722 - val_loss: 1.0828 - val_accuracy: 0.7559\n",
      "Epoch 51/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.3535 - accuracy: 0.8715 - val_loss: 1.0420 - val_accuracy: 0.7500\n",
      "Epoch 52/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.3446 - accuracy: 0.8785 - val_loss: 1.0947 - val_accuracy: 0.7520\n",
      "Epoch 53/80\n",
      "57/57 [==============================] - 2s 39ms/step - loss: 0.3421 - accuracy: 0.8824 - val_loss: 1.0699 - val_accuracy: 0.7461\n",
      "Epoch 54/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.3586 - accuracy: 0.8694 - val_loss: 1.0417 - val_accuracy: 0.7656\n",
      "Epoch 55/80\n",
      "57/57 [==============================] - 2s 40ms/step - loss: 0.3285 - accuracy: 0.8768 - val_loss: 1.0780 - val_accuracy: 0.7617\n",
      "Epoch 56/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.3199 - accuracy: 0.8841 - val_loss: 1.0870 - val_accuracy: 0.7656\n",
      "Epoch 57/80\n",
      "57/57 [==============================] - 2s 42ms/step - loss: 0.3201 - accuracy: 0.8894 - val_loss: 1.0548 - val_accuracy: 0.7676\n",
      "Epoch 58/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2920 - accuracy: 0.8954 - val_loss: 1.0703 - val_accuracy: 0.7695\n",
      "Epoch 59/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.3039 - accuracy: 0.8916 - val_loss: 1.1101 - val_accuracy: 0.7617\n",
      "Epoch 60/80\n",
      "57/57 [==============================] - 2s 42ms/step - loss: 0.3221 - accuracy: 0.8897 - val_loss: 1.1130 - val_accuracy: 0.7520\n",
      "Epoch 61/80\n",
      "57/57 [==============================] - 2s 39ms/step - loss: 0.3108 - accuracy: 0.8919 - val_loss: 1.0888 - val_accuracy: 0.7520\n",
      "Epoch 62/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.3157 - accuracy: 0.8915 - val_loss: 1.1258 - val_accuracy: 0.7578\n",
      "Epoch 63/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.3056 - accuracy: 0.8933 - val_loss: 1.1095 - val_accuracy: 0.7676\n",
      "Epoch 64/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2912 - accuracy: 0.8964 - val_loss: 1.1501 - val_accuracy: 0.7520\n",
      "Epoch 65/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2861 - accuracy: 0.8982 - val_loss: 1.1846 - val_accuracy: 0.7500\n",
      "Epoch 66/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.2736 - accuracy: 0.9020 - val_loss: 1.1524 - val_accuracy: 0.7441\n",
      "Epoch 67/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.2727 - accuracy: 0.8985 - val_loss: 1.1769 - val_accuracy: 0.7715\n",
      "Epoch 68/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.2652 - accuracy: 0.9077 - val_loss: 1.2352 - val_accuracy: 0.7461\n",
      "Epoch 69/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.2589 - accuracy: 0.9062 - val_loss: 1.2367 - val_accuracy: 0.7422\n",
      "Epoch 70/80\n",
      "57/57 [==============================] - 2s 36ms/step - loss: 0.2690 - accuracy: 0.8968 - val_loss: 1.1756 - val_accuracy: 0.7559\n",
      "Epoch 71/80\n",
      "57/57 [==============================] - 2s 31ms/step - loss: 0.2689 - accuracy: 0.9041 - val_loss: 1.1749 - val_accuracy: 0.7441\n",
      "Epoch 72/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.2475 - accuracy: 0.9168 - val_loss: 1.2030 - val_accuracy: 0.7402\n",
      "Epoch 73/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.2404 - accuracy: 0.9105 - val_loss: 1.2158 - val_accuracy: 0.7480\n",
      "Epoch 74/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2526 - accuracy: 0.9101 - val_loss: 1.1893 - val_accuracy: 0.7559\n",
      "Epoch 75/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2342 - accuracy: 0.9210 - val_loss: 1.2666 - val_accuracy: 0.7266\n",
      "Epoch 76/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2523 - accuracy: 0.9105 - val_loss: 1.2919 - val_accuracy: 0.7383\n",
      "Epoch 77/80\n",
      "57/57 [==============================] - 2s 38ms/step - loss: 0.2451 - accuracy: 0.9077 - val_loss: 1.2883 - val_accuracy: 0.7480\n",
      "Epoch 78/80\n",
      "57/57 [==============================] - 2s 35ms/step - loss: 0.2559 - accuracy: 0.9094 - val_loss: 1.2495 - val_accuracy: 0.7480\n",
      "Epoch 79/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2562 - accuracy: 0.8992 - val_loss: 1.2393 - val_accuracy: 0.7598\n",
      "Epoch 80/80\n",
      "57/57 [==============================] - 2s 37ms/step - loss: 0.2341 - accuracy: 0.9189 - val_loss: 1.3221 - val_accuracy: 0.7500\n"
     ]
    }
   ],
   "source": [
    "# train the network\n",
    "print(\"[INFO] training model...\")\n",
    "history = model.fit(x=train_images, \n",
    "    y=train_labels, \n",
    "    validation_data=(test_images, test_labels), \n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS, \n",
    "    steps_per_epoch=len(train_images)//BATCH_SIZE,\n",
    "    class_weight=classWeight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation\n",
    "\n",
    "Now, we need to validate our model. Using the test data we split off earlier, we can evaluate it. Let's plot our training and validation accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot training and validation accuracy recorded for each epoch\n",
    "plt.plot(history.history['accuracy'], label='Training Accuracy')\n",
    "plt.plot(history.history['val_accuracy'], label = 'Validation Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Accuracy')\n",
    "# accuracy is a fraction, so fix the y-axis to the full [0, 1] range\n",
    "plt.ylim([0, 1])\n",
    "plt.legend(loc='lower right')\n",
    "# the second curve is validation accuracy, matching its legend label\n",
    "plt.title(\"Training and Validation Accuracy\")\n",
    "# write the figure to disk before showing it\n",
    "plt.savefig(\"accuracy_plot\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A big difference between training accuracy and validation accuracy means we could have potentially overfitted the model.\n",
    "\n",
    "### Evaluation\n",
    "\n",
    "We can evaluate the loss and accuracy with the `evaluate` function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] evaluating network...\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "('Unrecognized keyword arguments:', dict_keys(['batch_shape']))",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-100-d3feebf6d1e4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"[INFO] evaluating network...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"./model.h5\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mBATCH_SIZE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"test loss, test acc:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/saving/save.py\u001b[0m in \u001b[0;36mload_model\u001b[0;34m(filepath, custom_objects, compile, options)\u001b[0m\n\u001b[1;32m    199\u001b[0m             (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):\n\u001b[1;32m    200\u001b[0m           return hdf5_format.load_model_from_hdf5(filepath, custom_objects,\n\u001b[0;32m--> 201\u001b[0;31m                                                   compile)\n\u001b[0m\u001b[1;32m    202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    203\u001b[0m         \u001b[0mfilepath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpath_to_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/saving/hdf5_format.py\u001b[0m in \u001b[0;36mload_model_from_hdf5\u001b[0;34m(filepath, custom_objects, compile)\u001b[0m\n\u001b[1;32m    179\u001b[0m     \u001b[0mmodel_config\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjson_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    180\u001b[0m     model = model_config_lib.model_from_config(model_config,\n\u001b[0;32m--> 181\u001b[0;31m                                                custom_objects=custom_objects)\n\u001b[0m\u001b[1;32m    182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    183\u001b[0m     \u001b[0;31m# set weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/saving/model_config.py\u001b[0m in \u001b[0;36mmodel_from_config\u001b[0;34m(config, custom_objects)\u001b[0m\n\u001b[1;32m     50\u001b[0m                     '`Sequential.from_config(config)`?')\n\u001b[1;32m     51\u001b[0m   \u001b[0;32mfrom\u001b[0m \u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdeserialize\u001b[0m  \u001b[0;31m# pylint: disable=g-import-not-at-top\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mdeserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcustom_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/layers/serialization.py\u001b[0m in \u001b[0;36mdeserialize\u001b[0;34m(config, custom_objects)\u001b[0m\n\u001b[1;32m    210\u001b[0m       \u001b[0mmodule_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mLOCAL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mALL_OBJECTS\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m       \u001b[0mcustom_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m       printable_module_name='layer')\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36mdeserialize_keras_object\u001b[0;34m(identifier, module_objects, custom_objects, printable_module_name)\u001b[0m\n\u001b[1;32m    676\u001b[0m             custom_objects=dict(\n\u001b[1;32m    677\u001b[0m                 \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_GLOBAL_CUSTOM_OBJECTS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 678\u001b[0;31m                 list(custom_objects.items())))\n\u001b[0m\u001b[1;32m    679\u001b[0m       \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    680\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mCustomObjectScope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/engine/sequential.py\u001b[0m in \u001b[0;36mfrom_config\u001b[0;34m(cls, config, custom_objects)\u001b[0m\n\u001b[1;32m    431\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mlayer_config\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlayer_configs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    432\u001b[0m       layer = layer_module.deserialize(layer_config,\n\u001b[0;32m--> 433\u001b[0;31m                                        custom_objects=custom_objects)\n\u001b[0m\u001b[1;32m    434\u001b[0m       \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    435\u001b[0m     if (not model.inputs and build_input_shape and\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/layers/serialization.py\u001b[0m in \u001b[0;36mdeserialize\u001b[0;34m(config, custom_objects)\u001b[0m\n\u001b[1;32m    210\u001b[0m       \u001b[0mmodule_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mLOCAL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mALL_OBJECTS\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m       \u001b[0mcustom_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m       printable_module_name='layer')\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/utils/generic_utils.py\u001b[0m in \u001b[0;36mdeserialize_keras_object\u001b[0;34m(identifier, module_objects, custom_objects, printable_module_name)\u001b[0m\n\u001b[1;32m    679\u001b[0m       \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    680\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mCustomObjectScope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 681\u001b[0;31m           \u001b[0mdeserialized_obj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    682\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    683\u001b[0m       \u001b[0;31m# Then `cls` may be a function returning a class.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/engine/base_layer.py\u001b[0m in \u001b[0;36mfrom_config\u001b[0;34m(cls, config)\u001b[0m\n\u001b[1;32m    746\u001b[0m         \u001b[0mA\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0minstance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    747\u001b[0m     \"\"\"\n\u001b[0;32m--> 748\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    749\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    750\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mcompute_output_shape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/intro_to_ai/lib/python3.6/site-packages/keras/engine/input_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, input_shape, batch_size, dtype, input_tensor, sparse, name, ragged, type_spec, **kwargs)\u001b[0m\n\u001b[1;32m    132\u001b[0m         \u001b[0minput_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_input_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unrecognized keyword arguments:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0msparse\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mragged\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: ('Unrecognized keyword arguments:', dict_keys(['batch_shape']))"
     ]
    }
   ],
   "source": [
    "# Evaluate the `model` object defined by the earlier cells on the held-out test set.\n",
    "# NOTE: reloading from disk with load_model(\"./model.h5\") was abandoned because it raised\n",
    "#   ValueError: ('Unrecognized keyword arguments:', dict_keys(['batch_shape']))\n",
    "# in this environment (presumably the file was written by a different Keras\n",
    "# version - TODO confirm), so the in-memory model is evaluated instead.\n",
    "print(\"[INFO] evaluating network...\")\n",
    "results = model.evaluate(test_images, test_labels, batch_size=BATCH_SIZE)\n",
    "display(\"test loss, test acc:\", results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictions\n",
    "\n",
    "Let's predict with our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] predicting test samples...\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'predictions shape:'"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(512, 62)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.60      0.67      0.63         9\n",
      "           1       0.88      0.78      0.82         9\n",
      "           2       0.89      1.00      0.94         8\n",
      "           3       0.75      0.75      0.75         8\n",
      "           4       0.62      0.62      0.62         8\n",
      "           5       0.71      0.62      0.67         8\n",
      "           6       0.73      1.00      0.84         8\n",
      "           7       0.88      0.88      0.88         8\n",
      "           8       0.73      1.00      0.84         8\n",
      "           9       0.57      0.50      0.53         8\n",
      "           A       0.89      0.89      0.89         9\n",
      "           B       1.00      0.62      0.77         8\n",
      "           C       0.67      0.89      0.76         9\n",
      "           D       1.00      1.00      1.00         8\n",
      "           E       0.86      0.75      0.80         8\n",
      "           F       0.89      1.00      0.94         8\n",
      "           G       1.00      0.75      0.86         8\n",
      "           H       0.78      0.88      0.82         8\n",
      "           I       0.57      0.50      0.53         8\n",
      "           G       0.71      0.56      0.63         9\n",
      "           K       0.88      0.88      0.88         8\n",
      "           L       1.00      0.88      0.93         8\n",
      "           M       0.86      0.67      0.75         9\n",
      "           N       0.67      0.67      0.67         9\n",
      "           O       0.50      0.62      0.56         8\n",
      "           P       0.46      0.75      0.57         8\n",
      "           Q       1.00      0.78      0.88         9\n",
      "           R       1.00      0.75      0.86         8\n",
      "           S       0.80      1.00      0.89         8\n",
      "           T       0.83      0.62      0.71         8\n",
      "           U       0.70      0.88      0.78         8\n",
      "           V       0.88      0.88      0.88         8\n",
      "           W       0.71      0.62      0.67         8\n",
      "           X       0.73      0.89      0.80         9\n",
      "           Y       1.00      1.00      1.00         8\n",
      "           Z       1.00      0.62      0.77         8\n",
      "           a       0.67      0.50      0.57         8\n",
      "           b       1.00      0.62      0.77         8\n",
      "           c       0.43      0.33      0.38         9\n",
      "           d       0.89      0.89      0.89         9\n",
      "           e       0.46      0.75      0.57         8\n",
      "           f       1.00      0.56      0.71         9\n",
      "           g       0.86      0.75      0.80         8\n",
      "           h       0.88      0.78      0.82         9\n",
      "           i       0.78      0.88      0.82         8\n",
      "           g       0.64      0.88      0.74         8\n",
      "           k       0.86      0.75      0.80         8\n",
      "           l       0.54      0.78      0.64         9\n",
      "           m       0.89      1.00      0.94         8\n",
      "           n       0.75      0.75      0.75         8\n",
      "           o       0.86      0.75      0.80         8\n",
      "           p       1.00      0.67      0.80         9\n",
      "           q       0.67      0.50      0.57         8\n",
      "           r       0.50      0.62      0.56         8\n",
      "           s       0.50      0.56      0.53         9\n",
      "           t       1.00      0.75      0.86         8\n",
      "           u       0.83      0.62      0.71         8\n",
      "           v       0.56      0.62      0.59         8\n",
      "           w       0.67      0.75      0.71         8\n",
      "           x       0.67      0.75      0.71         8\n",
      "           y       0.86      0.75      0.80         8\n",
      "           z       0.58      0.88      0.70         8\n",
      "\n",
      "    accuracy                           0.75       512\n",
      "   macro avg       0.77      0.75      0.75       512\n",
      "weighted avg       0.77      0.75      0.75       512\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"[INFO] predicting test samples...\")\n",
    "predictions = model.predict(test_images, batch_size=BATCH_SIZE)\n",
    "display(\"predictions shape:\", predictions.shape)\n",
    "\n",
    "# Display names for the classes, in the order they are handed to\n",
    "# classification_report below: digits, upper-case letters, lower-case letters.\n",
    "labelNames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',\n",
    "              'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',\n",
    "              'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',\n",
    "              'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',\n",
    "              'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n",
    "\n",
    "# Guard against typos in the list above: a duplicated or missing name would\n",
    "# silently mislabel rows of the report.\n",
    "assert len(set(labelNames)) == predictions.shape[1], \\\n",
    "    \"labelNames must contain exactly one unique name per predicted class\"\n",
    "\n",
    "# argmax over axis 1 turns each row into a single class index\n",
    "# (assumes test_labels is one-hot encoded like the predictions - TODO confirm).\n",
    "true_classes = test_labels.argmax(axis=1)\n",
    "predicted_classes = predictions.argmax(axis=1)\n",
    "print(classification_report(true_classes, predicted_classes,\n",
    "                            target_names=labelNames))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving\n",
    "\n",
    "Let's save the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] saving trained model...\n"
     ]
    }
   ],
   "source": [
    "# Write the trained network to OCR_CNN.h5 (path is relative to the working directory).\n",
    "saved_model_path = \"OCR_CNN.h5\"\n",
    "print(\"[INFO] saving trained model...\")\n",
    "model.save(saved_model_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Intro_to_AI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}