Classify Handwritten Digits Using Keras

We can build a small but accurate deep learning model in Python with Keras to classify the handwritten MNIST digits. The model is small enough to train on a laptop CPU.

This assumes Keras and NumPy are installed (the classification script also uses matplotlib). See this post for an example of installing them on macOS.

The labeled data is available through the Keras library itself, in keras.datasets.mnist. The first call to load_data downloads the dataset and caches it locally.

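The load_data function returns a pair of tuples: (training images, training labels) and (test images, test labels). Each element is a NumPy array. The images are 28×28 grayscale pixel arrays, 60,000 for training and 10,000 for testing, and the labels are the corresponding digits 0 to 9. The training script uses only the first tuple, the training data.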
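You can see the structure of the returned value without downloading anything. The sketch below uses a hypothetical stand-in, fake_load_data, which returns zero-filled NumPy arrays with the same shapes as MNIST. In the real scripts the call is mnist.load_data(). The sketch unpacks both pairs at once, whereas train.py indexes [0] to keep only the training pair.

```python
import numpy as np


def fake_load_data():
    """Hypothetical stand-in for keras.datasets.mnist.load_data().

    Returns zero-filled arrays with MNIST's shapes: 60,000 training and
    10,000 test images of 28x28 uint8 pixels, plus one label per image.
    """
    train = (np.zeros((60000, 28, 28), dtype=np.uint8), np.zeros(60000, dtype=np.uint8))
    test = (np.zeros((10000, 28, 28), dtype=np.uint8), np.zeros(10000, dtype=np.uint8))
    return train, test


# Unpack both (images, labels) pairs in one statement.
(trainImages, trainLabels), (testImages, testLabels) = fake_load_data()
print(trainImages.shape, trainLabels.shape)  # (60000, 28, 28) (60000,)
print(testImages.shape, testLabels.shape)    # (10000, 28, 28) (10000,)
```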

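The model is a network of two densely connected (Dense) layers. The first is a hidden layer of 512 units with ReLU activation. It is followed by an output layer of 10 units with softmax activation, which produces one probability per digit.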
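As a rough illustration of what these two layers compute, here is a plain NumPy sketch of the forward pass. It is not how Keras implements it internally. The weights are arbitrary random values rather than Keras's initializer, and the names forward, relu and softmax are hypothetical. The parameter count at the end should match the total reported by network.summary().

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights with the same shapes as the two Dense layers.
W1 = rng.normal(scale=0.05, size=(784, 512)).astype(np.float32)
b1 = np.zeros(512, dtype=np.float32)
W2 = rng.normal(scale=0.05, size=(512, 10)).astype(np.float32)
b2 = np.zeros(10, dtype=np.float32)


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z):
    # Subtract each row's maximum before exponentiating, for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(x):
    """x: batch of flattened images, shape (batch, 784) -> (batch, 10) probabilities."""
    hidden = relu(x @ W1 + b1)        # first Dense layer: 512 ReLU units
    return softmax(hidden @ W2 + b2)  # second Dense layer: 10 softmax units


batch = rng.random((3, 784), dtype=np.float32)
probs = forward(batch)
print(probs.shape)        # (3, 10)
print(probs.sum(axis=1))  # each row sums to 1, up to rounding

# Trainable parameters: weights plus biases of both layers.
params = W1.size + b1.size + W2.size + b2.size
print(params)  # 407050
```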

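We use categorical_crossentropy as the loss function because this is a multi-class classification problem, with one class per digit. This loss expects one-hot encoded labels, so the training labels are converted with to_categorical.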
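The sketch below uses NumPy and hypothetical softmax outputs to show what the one-hot labels look like and how categorical crossentropy scores a prediction. np.eye(10)[labels] produces the same layout as to_categorical(labels, 10). Keras additionally averages the per-example losses over each batch by default.

```python
import numpy as np

labels = np.array([2, 0, 9])

# One-hot encoding: row i has a 1 in column labels[i] and 0 elsewhere.
one_hot = np.eye(10, dtype=np.float32)[labels]

# Hypothetical softmax outputs for three images (each row sums to 1).
probs = np.full((3, 10), 0.01, dtype=np.float32)
probs[0, 2] = 0.91  # confident and correct
probs[1, 0] = 0.91  # confident and correct
probs[2, 3] = 0.91  # confident but wrong: the true label is 9

# Categorical crossentropy per example: -sum(y_true * log(y_pred)).
# Only the probability assigned to the true class contributes.
loss = -np.sum(one_hot * np.log(probs), axis=1)
print(loss.round(3))  # small loss for correct rows, large for the wrong one
```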

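The training images are reshaped to match the input shape of the first Dense layer, with each 28×28 image flattened into a vector of 784 values. They are also converted to float32 and divided by 255, which scales the pixel values from the range 0 to 255 into [0, 1]. The classification script applies the same preprocessing to the test images.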
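Here is a small NumPy sketch of this preprocessing. It runs on a hypothetical random batch of five uint8 images instead of the real MNIST data. It also shows that reshaping one flattened row back to 28×28 restores the original pixel layout, which the classification script relies on to display an image.

```python
import numpy as np

# Hypothetical stand-in batch: 5 uint8 images of 28x28 pixels with values 0..255.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each 28x28 image into a 784-element row to match the Dense input shape,
# then convert to float32 and scale pixel values into [0, 1].
flat = images.reshape((images.shape[0], 28 * 28)).astype('float32') / 255

print(flat.shape, flat.dtype)  # (5, 784) float32

# Reshaping one row back to 28x28 recovers the original pixel grid.
restored = flat[0].reshape(28, 28)
print(np.array_equal(np.rint(restored * 255).astype(np.uint8), images[0]))  # True
```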

The fit function runs the training, here for 5 epochs with batches of 128 images. The trained model is then saved to the file digit_model.keras, which the classification script reads and uses.

The following is the training script.

train.py

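from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical

# load_data() returns ((trainImages, trainLabels), (testImages, testLabels)); keep only the training pair.
(trainImages, trainLabels) = mnist.load_data()[0]

# Two fully connected layers: 512 ReLU hidden units, then 10 softmax outputs (one per digit).
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))

network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
network.summary()

# Flatten each 28x28 image into a 784-element vector and scale pixels to [0, 1].
trainImages = trainImages.reshape((60000, 28 * 28))
trainImages = trainImages.astype('float32') / 255

# One-hot encode the labels, as required by categorical_crossentropy.
trainLabels = to_categorical(trainLabels)

network.fit(trainImages, trainLabels, epochs=5, batch_size=128)
network.save('digit_model.keras')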

Train the network with: 

$ python train.py

 

To classify a specific image, the classification script loads the test data (also provided by Keras), extracts only the test images array, and opens the saved model file. The user chooses which test image to classify.

The trained model is loaded with models.load_model.

The test images get the same preprocessing as in training: each is flattened to 784 values and scaled to [0, 1]. An index read from the user then selects one of them. For display, the selected image is reshaped back to 28×28.

We use pyplot to show the selected image before running the classification. pyplot.show() blocks until the window is closed.

The classification runs when we call predict. Because predict expects a batch of inputs, the single 784-value image is wrapped with array([inputImage]) into a NumPy array of shape (1, 784).
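Here is a short NumPy sketch of adding the batch dimension, using a zero-filled placeholder in place of a real test image. The array([...]) form is the one used in classify.py. The other two forms are common equivalents.

```python
import numpy as np

inputImage = np.zeros(28 * 28, dtype=np.float32)  # placeholder flattened image, shape (784,)

# predict expects a batch, so wrap the single image in a batch of size 1.
batchA = np.array([inputImage])               # the approach used in classify.py
batchB = np.expand_dims(inputImage, axis=0)   # equivalent
batchC = inputImage[np.newaxis, :]            # equivalent, using indexing

print(inputImage.shape, batchA.shape)  # (784,) (1, 784)
print(np.array_equal(batchA, batchB) and np.array_equal(batchA, batchC))  # True
```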

The argmax function then returns the index of the highest value in the predictions. The predictions are probabilities, one per digit, each indicating how likely it is that the input belongs to the class at that position. Positions start from zero, so the index is the digit itself. For example, for an image most likely to be a 2, the third value (index 2) is the highest. The raw prediction array is also printed so the individual probabilities can be inspected.
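The sketch below applies argmax to hypothetical prediction arrays shaped like the output of predict. With a single row, argmax without an axis gives the digit directly. For a batch of several images, axis=1 gives one digit per row.

```python
import numpy as np

# Hypothetical prediction for one image: shape (1, 10), one probability per digit 0-9.
resultPredictions = np.full((1, 10), 0.01)
resultPredictions[0, 2] = 0.91

# argmax returns the index of the largest value, which is the predicted digit.
resultClass = int(np.argmax(resultPredictions))
print(resultClass)  # 2

# For several images, take argmax along axis 1 to get one digit per row.
batch = np.eye(10)[[7, 3]] * 0.9 + 0.01  # hypothetical rows favouring 7 and 3
print(np.argmax(batch, axis=1))  # [7 3]
```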

The following is the classification script.

classify.py

from keras import models
from keras.datasets import mnist
from matplotlib import pyplot
from numpy import array, argmax

# load_data() returns ((trainImages, trainLabels), (testImages, testLabels)); keep only the test images.
mnistDataTuples = mnist.load_data()
testData = mnistDataTuples[1]
testImages = testData[0]

network = models.load_model('digit_model.keras')

# Apply the same preprocessing as in training: flatten to 784 values and scale to [0, 1].
testImages = testImages.reshape((10000, 28 * 28))
testImages = testImages.astype('float32') / 255

testImageIndex = int(input('Select test image index (0-9999): '))
inputImage = testImages[testImageIndex]
# Reshape the flattened image back to 28x28, for display only.
inputImageScaled = inputImage.reshape(28, 28)

pyplot.imshow(inputImageScaled, cmap='gray')
pyplot.show()  # Opens a blocking window; close it to continue.

# predict expects a batch, so wrap the single image in an array of shape (1, 784).
resultPredictions = network.predict(array([inputImage]))
print(resultPredictions)

# The index of the highest probability is the predicted digit.
resultClass = argmax(resultPredictions)

print('The digit is: ' + str(resultClass))

To run the classification, execute the following and select an index: 

$ python classify.py
Select test image index (0-9999):

The original image will be displayed in a window; after it is closed the result of the classification will be printed in the terminal. 

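This small model classifies the MNIST digits with about 98% accuracy. The accuracy that fit prints during training is measured on the training data. The 10,000 test images give a fairer estimate.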
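To measure accuracy on the test set, you could run network.predict on the whole preprocessed testImages array and compare each row's argmax with the test labels (testData[1] in classify.py). Alternatively, Keras can do this with network.evaluate(testImages, to_categorical(testLabels)). Below is a minimal NumPy sketch of the comparison. It uses a hypothetical helper, accuracy, and made-up predictions rather than real model output.

```python
import numpy as np


def accuracy(predictions, labels):
    """Fraction of rows whose highest-probability class equals the true label.

    predictions: array of shape (n, 10), as returned by network.predict.
    labels: integer digit labels of shape (n,), e.g. the MNIST test labels.
    """
    predicted = np.argmax(predictions, axis=1)
    return float(np.mean(predicted == labels))


# Hypothetical predictions for four images favouring digits 3, 1, 4, 1.
preds = np.full((4, 10), 0.01)
for row, digit in enumerate([3, 1, 4, 1]):
    preds[row, digit] = 0.91

labels = np.array([3, 1, 4, 7])  # the last prediction is wrong
print(accuracy(preds, labels))  # 0.75
```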