A Convolutional Neural Network for Face Keypoint Detection

Yesterday, I read this recent article on Medium about facial keypoint detection. The article suggests that deep learning methods can easily be applied to this task, and it ends by encouraging everyone to try it, since the data and the toolkits are all open source. This post is my attempt at doing just that, since I've been interested in face detection for a long time and have written about it before.

This is the outline of what we'll try:

  • loading the data
  • analyzing the data
  • building a Keras model
  • checking the results
  • applying the method to a fun problem

Loading the data

The data we will use comes from a Kaggle challenge called Facial Keypoints Detection. I've downloaded the .csv file and put it in a data/ directory. Let's use pandas to read it.

In [1]:
import pandas as pd
In [2]:
df = pd.read_csv('data/training.csv')
In [3]:
df.head()
Out[3]:
   left_eye_center_x  left_eye_center_y  right_eye_center_x  right_eye_center_y  ...  mouth_center_bottom_lip_y                                              Image
0          66.033564          39.002274           30.227008           36.421678  ...                  84.485774  238 236 237 238 240 240 239 241 241 243 240 23...
1          64.332936          34.970077           29.949277           33.448715  ...                  85.480170  219 215 204 196 204 211 212 200 180 168 178 19...
2          65.057053          34.909642           30.903789           34.909642  ...                  78.659368  144 142 159 180 188 188 184 180 167 132 84 59 ...
3          65.225739          37.261774           32.023096           37.261774  ...                  78.268383  193 192 193 194 194 194 193 192 168 111 50 12 ...
4          66.725301          39.621261           32.244810           38.042032  ...                  86.871166  147 148 160 196 215 214 216 217 219 220 206 18...

5 rows × 31 columns

In [4]:
df.shape
Out[4]:
(7049, 31)

Analyzing the data

The Image column contains the pixel data of each face, while the first 30 columns contain the keypoint coordinates (15 x-coordinates and 15 y-coordinates). Let's try to get a feel for the data. First, let's display some faces.

In [5]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [6]:
def string2image(string):
    """Converts a string to a numpy array."""
    return np.array([int(item) for item in string.split()]).reshape((96, 96))

def plot_faces(nrows=5, ncols=5):
    """Randomly displays some faces from the training data."""
    selection = np.random.choice(df.index, size=(nrows*ncols), replace=False)
    image_strings = df.loc[selection]['Image']
    fig, axes = plt.subplots(figsize=(10, 10), nrows=nrows, ncols=ncols)
    for string, ax in zip(image_strings, axes.ravel()):
        ax.imshow(string2image(string), cmap='gray')
        ax.axis('off')
In [7]:
plot_faces()

Let's now add to that plot the facial keypoints that were tagged. First, let's do an example:

In [8]:
keypoint_cols = list(df.columns)[:-1]
In [9]:
xy = df.iloc[0][keypoint_cols].values.reshape((15, 2))
xy 
Out[9]:
array([[66.033563909799994, 39.002273684199999],
       [30.227007518800001, 36.4216781955],
       [59.582075188000005, 39.647422556399995],
       [73.130345864700004, 39.9699969925],
       [36.356571428599999, 37.389401503800002],
       [23.452872180500002, 37.389401503800002],
       [56.953263157899997, 29.033648120300001],
       [80.227127819499998, 32.2281383459],
       [40.227609022599999, 29.002321804499999],
       [16.3563789474, 29.647470676699999],
       [44.420571428599999, 57.066803007499999],
       [61.195308270699996, 79.970165413499998],
       [28.614496240600001, 77.388992481199992],
       [43.312601503800003, 72.935458646599997],
       [43.130706766899998, 84.485774436100002]], dtype=object)
In [10]:
plt.plot(xy[:, 0], xy[:, 1], 'ro')
plt.imshow(string2image(df.iloc[0]['Image']), cmap='gray')
Out[10]:

Now, let's add this to the function we wrote before.

In [11]:
def plot_faces_with_keypoints(nrows=5, ncols=5):
    """Randomly displays some faces from the training data with their keypoints."""
    selection = np.random.choice(df.index, size=(nrows*ncols), replace=False)
    image_strings = df.loc[selection]['Image']
    keypoint_cols = list(df.columns)[:-1]
    keypoints = df.loc[selection][keypoint_cols]
    fig, axes = plt.subplots(figsize=(10, 10), nrows=nrows, ncols=ncols)
    for string, (iloc, keypoint), ax in zip(image_strings, keypoints.iterrows(), axes.ravel()):
        xy = keypoint.values.reshape((15, 2))
        ax.imshow(string2image(string), cmap='gray')
        ax.plot(xy[:, 0], xy[:, 1], 'ro')
        ax.axis('off')
In [12]:
plot_faces_with_keypoints()

We can make several observations from this image:

  • some images are high resolution, some are low
  • some images have all 15 keypoints, while some have only a few

Let's compute some statistics on the keypoints to investigate that last observation:

In [13]:
df.describe().loc['count'].plot.bar()
Out[13]:

What this plot tells us is that in this dataset, only about 2,000 images are "high quality", with all 15 keypoints labelled, while the remaining 5,000 or so are "low quality", with only 4 keypoints labelled.
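
To double-check that last point, we can count the rows that have every keypoint present. This is only a quick sketch using the dataframe loaded above (dropna is also what we use in the next cell):

# Rows with all 30 keypoint columns filled in vs. rows with at least one missing value
n_complete = df.dropna().shape[0]
n_partial = len(df) - n_complete
n_complete, n_partial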

Let's start by training on the high-quality images and see how far we get.

In [14]:
fully_annotated = df.dropna()
In [15]:
fully_annotated.shape
Out[15]:
(2140, 31)

Building a Keras model

Now on to the machine learning part. Let's build a Keras model with our data. Actually, before we do that, let's do some preprocessing, using scikit-learn pipelines (inspired by this great post on scalable Machine Learning by Tom Augspurger).

The idea behind pipelining is that it makes it easy to keep track of the transformations applied to the data. We need two scalings: one for the input images and one for the output keypoints. Since I couldn't get the scaling to work for 3D image data, we will only use a pipeline for our outputs.
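
As an aside, if you wanted the input scaling to live in a pipeline as well, you could wrap it in a FunctionTransformer. The snippet below is only a sketch of that idea and is not used in the rest of this post:

from sklearn.preprocessing import FunctionTransformer

# Sketch only: wrap the "divide by 255" scaling in a transformer so it could sit in a pipeline.
# validate=False lets the 4D image array pass through without being coerced to 2D.
input_scaler = FunctionTransformer(lambda x: x / 255., validate=False)

In practice, we will simply divide the images by 255 by hand below.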

In [16]:
X = np.stack([string2image(string) for string in fully_annotated['Image']]).astype(float)[:, :, :, np.newaxis]
In [17]:
y = np.vstack(fully_annotated[fully_annotated.columns[:-1]].values)
In [30]:
X.shape, X.dtype
Out[30]:
((2140, 96, 96, 1), dtype('float64'))
In [31]:
y.shape, y.dtype
Out[31]:
((2140, 30), dtype('float64'))
In [32]:
X_train = X / 255.
In [33]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

output_pipe = make_pipeline(
    MinMaxScaler(feature_range=(-1, 1))
)

y_train = output_pipe.fit_transform(y)

In this case, the pipelining process is, how to say this, not very spectacular. Let's move on and train a Keras model! We will start with a simple model, as found in this blog post: a single fully connected layer with 100 hidden units.

In [34]:
from keras.models import Sequential
from keras.layers import BatchNormalization, Conv2D, Activation, MaxPooling2D, Dense, GlobalAveragePooling2D
In [44]:
model = Sequential()
model.add(Dense(100, input_shape=(96*96,)))
model.add(Activation('relu'))
model.add(Dense(30))

Now let's compile the model and run the training.

In [47]:
from keras import optimizers

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='mse', metrics=['accuracy'])
epochs = 200
history = model.fit(X_train.reshape(y_train.shape[0], -1), y_train, 
                 validation_split=0.2, shuffle=True, 
                 epochs=epochs, batch_size=20)
Train on 1712 samples, validate on 428 samples
Epoch 1/200
1712/1712 [==============================] - 1s - loss: 0.0173 - acc: 0.4836 - val_loss: 0.0542 - val_acc: 0.0794
Epoch 2/200
1712/1712 [==============================] - 1s - loss: 0.0170 - acc: 0.4842 - val_loss: 0.0554 - val_acc: 0.0911
Epoch 3/200
1712/1712 [==============================] - 1s - loss: 0.0167 - acc: 0.4842 - val_loss: 0.0546 - val_acc: 0.0888
[... epochs 4 to 198 omitted ...]
Epoch 199/200
1712/1712 [==============================] - 1s - loss: 0.0073 - acc: 0.5567 - val_loss: 0.0517 - val_acc: 0.1449
Epoch 200/200
1712/1712 [==============================] - 1s - loss: 0.0074 - acc: 0.5456 - val_loss: 0.0493 - val_acc: 0.1612

Let's plot our training curves with this model.

In [48]:
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

What we see here is that with this model, the learning quickly reaches a plateau. How can we improve on this? There are a lot of options:

  • adjust the optimizer settings (see the sketch below)
    • learning rate
    • batch size
    • momentum
  • change the model
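
The first option, for instance, just amounts to re-compiling and re-fitting with different hyperparameters; the values below are purely illustrative, not tuned:

# Illustrative only: smaller learning rate, more momentum, larger batches.
sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.95, nesterov=True)
model.compile(optimizer=sgd, loss='mse', metrics=['accuracy'])
history = model.fit(X_train.reshape(y_train.shape[0], -1), y_train,
                    validation_split=0.2, shuffle=True,
                    epochs=epochs, batch_size=64)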

However, one thing that is pretty clear from the above plot is that our model overfits: the training and validation losses are not comparable (the validation loss is several times higher). Let's see what the results of the net are on some samples from our data.

In [84]:
img = X_train[0, :, :, :].reshape(1, -1)
predictions = model.predict(img)
In [85]:
img
Out[85]:
array([[ 0.93333333,  0.9254902 ,  0.92941176, ...,  0.2745098 ,
         0.29411765,  0.35294118]])
In [87]:
xy_predictions = output_pipe.inverse_transform(predictions).reshape(15, 2)
In [88]:
plt.imshow(X_train[0, :, :, 0], cmap='gray')
plt.plot(xy_predictions[:, 0], xy_predictions[:, 1], 'b*')
Out[88]:
[]
In [111]:
def plot_faces_with_keypoints_and_predictions(model, nrows=5, ncols=5, model_input='flat'):
    """Plots sampled faces with their truth and predictions."""
    selection = np.random.choice(np.arange(X.shape[0]), size=(nrows*ncols), replace=False)
    fig, axes = plt.subplots(figsize=(10, 10), nrows=nrows, ncols=ncols)
    for ind, ax in zip(selection, axes.ravel()):
        img = X_train[ind, :, :, 0]
        if model_input == 'flat':
            predictions = model.predict(img.reshape(1, -1))
        else:
            predictions = model.predict(img[np.newaxis, :, :, np.newaxis])
        xy_predictions = output_pipe.inverse_transform(predictions).reshape(15, 2)
        ax.imshow(img, cmap='gray')
        ax.plot(xy_predictions[:, 0], xy_predictions[:, 1], 'bo')
        ax.axis('off')
In [96]:
plot_faces_with_keypoints_and_predictions(model)