Ship recognition in satellite images
Satellite imagery provides unique insights into various markets, including agriculture, defense and intelligence, energy, and finance. New commercial imagery providers, such as Planet, are using constellations of small satellites to capture images of the entire Earth every day.
This flood of new imagery is outgrowing organizations' ability to manually review every captured image, creating a need for machine learning and computer vision algorithms to help automate the analysis.
This notebook describes using a CNN (convolutional neural network) to decide whether a given satellite image contains a ship. I apply transfer learning to two publicly available pretrained CNNs (VGG16 and InceptionV3), build classifiers on top of them, and compare the results.
Many thanks to Gotam Dahiya for his excellent notebook Ship Detection using Faster R-CNN: Part 1 and to the TensorFlow team for the Classification on imbalanced data tutorial.
This is an overview of the entire project structure, or pipeline: load and explore the data, split it into train/dev/test sets, extract feature vectors with two pretrained CNNs (VGG16 and InceptionV3), train a small classifier on top of each, then fight class imbalance with class weights and with data augmentation, comparing all resulting models on the test set via ROC curves.
%config IPCompleter.greedy=True
%config Completer.use_jedi = False
import os
import tempfile
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import imgaug.augmenters as iaa # data augmentation
from tqdm import tqdm # progress bar for loops
import cv2 # OpenCV for computer vision
import seaborn as sns # heatmap plotting
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import vgg16
from tensorflow.keras.applications import inception_v3
from tensorflow.keras.layers import Dense, Dropout # layers to build my classifiers
from tensorflow.keras import Sequential, optimizers # layer stacking and optimizers
mpl.rcParams['figure.figsize'] = (5, 5)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
The data was downloaded manually and separated into ship and no-ship folders
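The loading code below expects the following layout, where the folder name serves as the class label (a sketch; the counts and the 80x80 px size come from the dataset itself):

data/shipsnet/
    no-ship/   # 3000 images, 80x80 px RGB
    ship/      # 1000 images, 80x80 px RGB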
data_location = 'data/shipsnet'
class_names = ["no-ship","ship"]
class_name_labels = {class_name:i for i,class_name in enumerate(class_names)} # dictionary of class names: ids
num_classes = len(class_names)
class_name_labels
{'no-ship': 0, 'ship': 1}
Load the images with OpenCV, normalize pixel values, and collect labels:
images, labels = [], []
for folder in os.listdir(data_location):
    label = class_name_labels[folder]
    for file in tqdm(os.listdir(os.path.join(data_location, folder))):
        img_path = os.path.join(data_location, folder, file)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR while matplotlib needs RGB to display them
        images.append(img)
        labels.append(label)
images = np.array(images, dtype=np.float32) / 255.0  # normalize pixel values to [0, 1]
labels = np.array(labels, dtype=np.int32)
images.shape, labels.shape
100%|████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:01<00:00, 2194.01it/s]
100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2248.16it/s]
((4000, 80, 80, 3), (4000,))
Let's see the distribution of data between classes.
n_labels = labels.shape[0]
_, count = np.unique(labels, return_counts=True)
df = pd.DataFrame(data = count)
df['Class Label'] = class_names
df.columns = ['Count','Class-Label']
df.set_index('Class-Label',inplace=True)
df
| Class-Label | Count |
| --- | --- |
| no-ship | 3000 |
| ship | 1000 |
df.plot.bar(rot=0)
plt.title("distribution of images per class");
plt.pie(count,
explode=(0,0),
labels=class_names,
autopct="%1.2f%%");
For each ship image there are three no-ship ones. This skew will have to be corrected for; otherwise the algorithm could reach 75% accuracy just by always predicting "no-ship", creating an unnecessarily biased model.
Let's now look at a few images. The first row shows no-ship examples, the second row ships.
columns = 4
rows = 1
# no-ship images are at indices 0 to 2999
fig = plt.figure(figsize=(20, 20))
for i in range(columns * rows):
    img = images[np.random.choice(3000)]
    fig.add_subplot(rows, columns, i + 1)
    plt.imshow(img)
plt.show()

# ship images are at indices 3000 to 3999
fig = plt.figure(figsize=(20, 20))
for i in range(columns * rows):
    img = images[np.random.choice(1000) + 3000]
    fig.add_subplot(rows, columns, i + 1)
    plt.imshow(img)
plt.show()
It is worth noting that only images containing a complete vessel belong to the ship class; a no-ship image can still contain part of a ship.
I will split the entire dataset into train/dev/test sets using a 70/20/10 ratio, appropriate for the relatively small number of examples I have. The train_test_split function from scikit-learn only splits data into two sets, so I use it twice: first splitting everything into train (70%) and a 30% holdout, then splitting the holdout 2:1 into dev and test (30% × 1/3 = 10% for test).
images_train, images_test, labels_train, labels_test = train_test_split(images, labels, test_size=0.3, random_state=42)
images_dev, images_test, labels_dev, labels_test = train_test_split(images_test, labels_test, test_size=1/3, random_state=43)
images_train.shape, images_dev.shape, images_test.shape
((2800, 80, 80, 3), (800, 80, 80, 3), (400, 80, 80, 3))
The amount of original data (only 4000 examples) is not enough to train a CNN from scratch. It makes more sense to reuse publicly available models that have been trained over several weeks on GPUs using millions of examples. This way I can transfer the existing knowledge about low-level features (edges, corners, etc.) and only learn the actual classification between ships and no-ships.
This will be a two-step process: (1) run every image once through the pretrained convolutional layers to get a feature vector, and (2) train a small classifier on those feature vectors.
Generally speaking, the classifier at step 2 does not have to be a neural network at all. One could use an SVM, logistic regression, or anything else, but I will stick with a NN, as that is the topic of this notebook.
VGG16 is a deep CNN with 16 trainable layers that made history by achieving 92.7% top-5 test accuracy in the ILSVRC-2014 ImageNet competition. You can read about it in this article, but here is the architecture for general understanding.
There are two parts to a CNN model: the network architecture (displayed above) and the weights learned during training. For transfer learning we need both, and both are conveniently available via the keras package.
However, of VGG16's 16 trainable layers (not counting pooling and softmax), only the first 13 are convolutional; the last 3 are fully connected layers that perform the classification. I will load only the convolutional layers and build my own classifier on top of them. A nice perk is that this allows input images of any size (producing feature tensors of different sizes), whereas the full VGG16 only accepts 224x224 px inputs.
vgg_conv = vgg16.VGG16(weights='imagenet', # load pretrained weights; downloaded from GitHub on first use, then cached locally
                       include_top=False, # do not load the fully connected layers
                       input_shape=(80, 80, 3)) # shape of the images in the ships dataset
First, pass all the images through VGG16 to get feature tensors. With 224x224x3 inputs the result would be 7x7x512; with the 80x80x3 images I am using it will be smaller, some n x n x 512 with n < 7. This step takes a few minutes.
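Before spending those minutes, the output shape can be read off the model statically (a quick sanity check of the reasoning above; for 80x80 inputs it should report a small n and 512 channels):

print(vgg_conv.output_shape)  # expecting (None, n, n, 512) with n < 7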
vgg_train, vgg_dev, vgg_test = vgg_conv.predict(images_train), vgg_conv.predict(images_dev), vgg_conv.predict(images_test)
vgg_train.shape, vgg_dev.shape, vgg_test.shape
((2800, 2, 2, 512), (800, 2, 2, 512), (400, 2, 2, 512))
Okay, the resulting tensors are 2x2x512. They need to be flattened into feature vectors, because that is what fully connected layers expect.
def flatten(tensor):  # reused later for the Inception features, which have a different shape
    return np.reshape(tensor, (tensor.shape[0], tensor.shape[1] * tensor.shape[2] * tensor.shape[3]))
vgg_train, vgg_dev, vgg_test = flatten(vgg_train), flatten(vgg_dev), flatten(vgg_test)
vgg_train.shape, vgg_dev.shape, vgg_test.shape
((2800, 2048), (800, 2048), (400, 2048))
Perfect, now each image is represented with a 2048-long feature vector.
While VGG is conceptually a stack of convolutional layers, Inception has a more complicated structure. Fully describing it is beyond the scope of this notebook (you can read a good overview here); I will only mention that it creates a "sparsely connected architecture" by stacking a new kind of building block called the "Inception module":
Among other things, this makes the network less expensive to train and less prone to overfitting.
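To make the "Inception module" idea concrete, here is a toy sketch of one in the Keras functional API. It is deliberately simplified (real InceptionV3 modules add 1x1 bottlenecks, factorized convolutions and batch normalization), and the branch widths are made-up numbers:

from tensorflow.keras import layers

def toy_inception_module(x, f1=16, f3=32, f5=8, fp=8):
    # four parallel branches look at the same input at different scales
    b1 = layers.Conv2D(f1, 1, padding='same', activation='relu')(x)
    b3 = layers.Conv2D(f3, 3, padding='same', activation='relu')(x)
    b5 = layers.Conv2D(f5, 5, padding='same', activation='relu')(x)
    bp = layers.MaxPooling2D(3, strides=1, padding='same')(x)
    bp = layers.Conv2D(fp, 1, padding='same', activation='relu')(bp)
    return layers.concatenate([b1, b3, b5, bp])  # stack branch outputs along the channel axis

inp = tf.keras.Input(shape=(80, 80, 3))
out = toy_inception_module(inp)
tf.keras.Model(inp, out).summary()  # spatial size preserved, 16+32+8+8 = 64 output channels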
For my particular task, I will just load the architecture and weights (without the top fully connected layers) and use them to calculate feature vectors as with VGG.
incep_conv = inception_v3.InceptionV3(weights='imagenet',
include_top=False,
input_shape=(80, 80, 3))
By analogy with VGG, first get features as tensors:
incep_train, incep_dev, incep_test = incep_conv.predict(images_train), incep_conv.predict(images_dev), incep_conv.predict(images_test)
incep_train.shape, incep_dev.shape, incep_test.shape
((2800, 1, 1, 2048), (800, 1, 1, 2048), (400, 1, 1, 2048))
It's interesting that both VGG16 and Inception end up with the same size features, although shaped differently.
Now flatten:
incep_train, incep_dev, incep_test = flatten(incep_train), flatten(incep_dev), flatten(incep_test)
incep_train.shape, incep_dev.shape, incep_test.shape
((2800, 2048), (800, 2048), (400, 2048))
Thanks to the lucky coincidence that VGG16 and Inception output feature vectors of the same size, I can use the same classifier architecture, feed it both kinds of features, and compare.
Generally speaking, my classifier does not even have to be a neural network: I could feed the CNN-generated features into an SVM or even a logistic regression, as in the quick sketch below. For the rest of the notebook, though, I use a small neural network of the following architecture, defined right after the sketch.
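A quick, untuned sketch of that alternative, reusing the flattened VGG features from above (for illustration only; it is not used anywhere else in the notebook):

from sklearn.linear_model import LogisticRegression

lr_clf = LogisticRegression(solver='lbfgs', max_iter=1000)
lr_clf.fit(vgg_train, labels_train)  # 2048-dimensional feature vectors in, 0/1 labels out
print('dev accuracy:', lr_clf.score(vgg_dev, labels_dev))

Now, the neural network classifier and the metrics it will track: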
METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
]

def make_model(metrics=METRICS):
    model = Sequential()
    model.add(Dense(1024, activation='relu', input_dim=2048,
                    kernel_initializer=tf.keras.initializers.glorot_uniform(seed=42)))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid',
                    kernel_initializer=tf.keras.initializers.glorot_uniform(seed=40)))
    model.compile(optimizer=optimizers.RMSprop(lr=2e-4),
                  loss='binary_crossentropy',  # binary crossentropy because this is a binary classification problem
                  metrics=metrics)
    return model
Notice the metrics defined above: the model computes them during training and validation, and they will be helpful when evaluating performance:
Accuracy: $\frac{\text{true samples}}{\text{total samples}}$

Precision: $\frac{\text{true positives}}{\text{true positives + false positives}}$

Recall: $\frac{\text{true positives}}{\text{true positives + false negatives}}$

F1 score: $2 \cdot \frac{\text{precision} \cdot \text{recall}}{\text{precision + recall}}$
Note: accuracy may not be the best metric for this task. You can achieve 75% accuracy just by always predicting no-ship. However, it cannot be discarded completely.
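For reference, here are the same definitions in plain Python (my own helper, not part of the model). Running it on the degenerate "always predict no-ship" classifier shows why accuracy alone is not enough:

def summarize(tp, fp, tn, fn):
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if tp + fp else float('nan')  # undefined when nothing is predicted positive
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1

# always predicting no-ship on this dataset: tp=0, fp=0, tn=3000, fn=1000
print(summarize(0, 0, 3000, 1000))  # accuracy 0.75, recall 0.0, precision and F1 undefined (nan)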
Create a classifier using the previously defined function. This baseline version will probably not work very well and will be improved in later sections.
EPOCHS = 100
BATCH_SIZE = 32
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
np.random.seed(32)
tf.compat.v1.set_random_seed(33)
model = make_model()
model.summary()
Model: "sequential_28" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_54 (Dense) (None, 1024) 2098176 _________________________________________________________________ dropout_27 (Dropout) (None, 1024) 0 _________________________________________________________________ dense_55 (Dense) (None, 1) 1025 ================================================================= Total params: 2,099,201 Trainable params: 2,099,201 Non-trainable params: 0 _________________________________________________________________
Saving initial weights to a temp file so that I can load them and re-train from the same point later.
initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)
In this section, I will train two classifiers, using VGG16 and Inception embeddings as input, without correcting for the imbalanced classes.
# Train VGG classifier (or rather my own classifier on top of VGG convolutional layers)
vgg_model = make_model()
vgg_model.load_weights(initial_weights)
baseline_vgg_history = vgg_model.fit(
vgg_train,
labels_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(vgg_dev, labels_dev))
Train on 2800 samples, validate on 800 samples Epoch 1/100 2800/2800 [==============================] - 5s 2ms/sample - loss: 0.1390 - tp: 626.0000 - fp: 56.0000 - tn: 2028.0000 - fn: 90.0000 - accuracy: 0.9479 - precision: 0.9179 - recall: 0.8743 - auc: 0.9851 - val_loss: 0.1117 - val_tp: 188.0000 - val_fp: 34.0000 - val_tn: 574.0000 - val_fn: 4.0000 - val_accuracy: 0.9525 - val_precision: 0.8468 - val_recall: 0.9792 - val_auc: 0.9955 Epoch 2/100 2800/2800 [==============================] - 2s 747us/sample - loss: 0.0558 - tp: 683.0000 - fp: 18.0000 - tn: 2066.0000 - fn: 33.0000 - accuracy: 0.9818 - precision: 0.9743 - recall: 0.9539 - auc: 0.9976 - val_loss: 0.0709 - val_tp: 189.0000 - val_fp: 20.0000 - val_tn: 588.0000 - val_fn: 3.0000 - val_accuracy: 0.9712 - val_precision: 0.9043 - val_recall: 0.9844 - val_auc: 0.9977 Epoch 3/100 2800/2800 [==============================] - 2s 748us/sample - loss: 0.0380 - tp: 694.0000 - fp: 17.0000 - tn: 2067.0000 - fn: 22.0000 - accuracy: 0.9861 - precision: 0.9761 - recall: 0.9693 - auc: 0.9988 - val_loss: 0.0456 - val_tp: 188.0000 - val_fp: 12.0000 - val_tn: 596.0000 - val_fn: 4.0000 - val_accuracy: 0.9800 - val_precision: 0.9400 - val_recall: 0.9792 - val_auc: 0.9984 Epoch 4/100 2800/2800 [==============================] - 2s 739us/sample - loss: 0.0302 - tp: 701.0000 - fp: 11.0000 - tn: 2073.0000 - fn: 15.0000 - accuracy: 0.9907 - precision: 0.9846 - recall: 0.9791 - auc: 0.9993 - val_loss: 0.0502 - val_tp: 189.0000 - val_fp: 18.0000 - val_tn: 590.0000 - val_fn: 3.0000 - val_accuracy: 0.9737 - val_precision: 0.9130 - val_recall: 0.9844 - val_auc: 0.9988 Epoch 5/100 2800/2800 [==============================] - 2s 729us/sample - loss: 0.0240 - tp: 701.0000 - fp: 12.0000 - tn: 2072.0000 - fn: 15.0000 - accuracy: 0.9904 - precision: 0.9832 - recall: 0.9791 - auc: 0.9996 - val_loss: 0.0362 - val_tp: 183.0000 - val_fp: 0.0000e+00 - val_tn: 608.0000 - val_fn: 9.0000 - val_accuracy: 0.9887 - val_precision: 1.0000 - val_recall: 0.9531 - val_auc: 0.9968 Epoch 6/100 2800/2800 [==============================] - 2s 726us/sample - loss: 0.0199 - tp: 706.0000 - fp: 9.0000 - tn: 2075.0000 - fn: 10.0000 - accuracy: 0.9932 - precision: 0.9874 - recall: 0.9860 - auc: 0.9997 - val_loss: 0.0360 - val_tp: 189.0000 - val_fp: 8.0000 - val_tn: 600.0000 - val_fn: 3.0000 - val_accuracy: 0.9862 - val_precision: 0.9594 - val_recall: 0.9844 - val_auc: 0.9990 Epoch 7/100 2800/2800 [==============================] - 2s 723us/sample - loss: 0.0138 - tp: 708.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 8.0000 - accuracy: 0.9961 - precision: 0.9958 - recall: 0.9888 - auc: 0.9997 - val_loss: 0.0278 - val_tp: 189.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 3.0000 - val_accuracy: 0.9925 - val_precision: 0.9844 - val_recall: 0.9844 - val_auc: 0.9992 Epoch 8/100 2800/2800 [==============================] - 2s 731us/sample - loss: 0.0126 - tp: 708.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 8.0000 - accuracy: 0.9961 - precision: 0.9958 - recall: 0.9888 - auc: 0.9999 - val_loss: 0.0247 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 0.9844 - val_auc: 0.9969 Epoch 9/100 2800/2800 [==============================] - 2s 752us/sample - loss: 0.0110 - tp: 711.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 5.0000 - accuracy: 0.9968 - precision: 0.9944 - recall: 0.9930 - auc: 0.9999 - val_loss: 0.0268 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - 
val_recall: 0.9844 - val_auc: 0.9968 Epoch 10/100 2800/2800 [==============================] - 2s 774us/sample - loss: 0.0096 - tp: 709.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 7.0000 - accuracy: 0.9961 - precision: 0.9944 - recall: 0.9902 - auc: 0.9999 - val_loss: 0.0603 - val_tp: 191.0000 - val_fp: 20.0000 - val_tn: 588.0000 - val_fn: 1.0000 - val_accuracy: 0.9737 - val_precision: 0.9052 - val_recall: 0.9948 - val_auc: 0.9992 Epoch 11/100 2800/2800 [==============================] - 2s 776us/sample - loss: 0.0092 - tp: 710.0000 - fp: 5.0000 - tn: 2079.0000 - fn: 6.0000 - accuracy: 0.9961 - precision: 0.9930 - recall: 0.9916 - auc: 1.0000 - val_loss: 0.0250 - val_tp: 189.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 3.0000 - val_accuracy: 0.9925 - val_precision: 0.9844 - val_recall: 0.9844 - val_auc: 0.9970 Epoch 12/100 2800/2800 [==============================] - 2s 763us/sample - loss: 0.0065 - tp: 714.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 2.0000 - accuracy: 0.9979 - precision: 0.9944 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.0260 - val_tp: 190.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 2.0000 - val_accuracy: 0.9925 - val_precision: 0.9794 - val_recall: 0.9896 - val_auc: 0.9970 Epoch 13/100 2800/2800 [==============================] - 2s 764us/sample - loss: 0.0086 - tp: 711.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 5.0000 - accuracy: 0.9968 - precision: 0.9944 - recall: 0.9930 - auc: 1.0000 - val_loss: 0.0288 - val_tp: 190.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 2.0000 - val_accuracy: 0.9925 - val_precision: 0.9794 - val_recall: 0.9896 - val_auc: 0.9969 Epoch 14/100 2800/2800 [==============================] - 2s 769us/sample - loss: 0.0064 - tp: 713.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 3.0000 - accuracy: 0.9979 - precision: 0.9958 - recall: 0.9958 - auc: 1.0000 - val_loss: 0.0256 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 0.9844 - val_auc: 0.9970 Epoch 15/100 2800/2800 [==============================] - 2s 742us/sample - loss: 0.0051 - tp: 714.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 2.0000 - accuracy: 0.9989 - precision: 0.9986 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.0283 - val_tp: 189.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 3.0000 - val_accuracy: 0.9912 - val_precision: 0.9793 - val_recall: 0.9844 - val_auc: 0.9969 Epoch 16/100 2800/2800 [==============================] - 2s 723us/sample - loss: 0.0037 - tp: 715.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 1.0000 - accuracy: 0.9986 - precision: 0.9958 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0272 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 0.9844 - val_auc: 0.9970 Epoch 17/100 2800/2800 [==============================] - 2s 725us/sample - loss: 0.0034 - tp: 714.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 2.0000 - accuracy: 0.9989 - precision: 0.9986 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.0297 - val_tp: 190.0000 - val_fp: 7.0000 - val_tn: 601.0000 - val_fn: 2.0000 - val_accuracy: 0.9887 - val_precision: 0.9645 - val_recall: 0.9896 - val_auc: 0.9970 Epoch 18/100 2800/2800 [==============================] - 2s 728us/sample - loss: 0.0023 - tp: 715.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 1.0000 - accuracy: 0.9993 - precision: 0.9986 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0255 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 
0.9844 - val_auc: 0.9971 Epoch 19/100 2800/2800 [==============================] - 2s 736us/sample - loss: 0.0032 - tp: 715.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 1.0000 - accuracy: 0.9993 - precision: 0.9986 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0314 - val_tp: 190.0000 - val_fp: 5.0000 - val_tn: 603.0000 - val_fn: 2.0000 - val_accuracy: 0.9912 - val_precision: 0.9744 - val_recall: 0.9896 - val_auc: 0.9970 Epoch 20/100 2720/2800 [============================>.] - ETA: 0s - loss: 0.0021 - tp: 702.0000 - fp: 1.0000 - tn: 2017.0000 - fn: 0.0000e+00 - accuracy: 0.9996 - precision: 0.9986 - recall: 1.0000 - auc: 1.0000 Restoring model weights from the end of the best epoch. 2800/2800 [==============================] - 3s 1ms/sample - loss: 0.0022 - tp: 716.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 0.0000e+00 - accuracy: 0.9996 - precision: 0.9986 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.0381 - val_tp: 186.0000 - val_fp: 1.0000 - val_tn: 607.0000 - val_fn: 6.0000 - val_accuracy: 0.9912 - val_precision: 0.9947 - val_recall: 0.9688 - val_auc: 0.9945 Epoch 00020: early stopping
def plot_metrics(history):
    metrics = ['loss', 'auc', 'precision', 'recall']
    for n, metric in enumerate(metrics):
        name = metric.replace("_", " ").capitalize()
        plt.subplot(2, 2, n + 1)
        plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
        plt.plot(history.epoch, history.history['val_' + metric],
                 color=colors[0], linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.98, 1])
        else:
            plt.ylim([0.89, 1])
        plt.legend()
mpl.rcParams['figure.figsize'] = (12, 10)
plot_metrics(baseline_vgg_history)
One way to evaluate the resulting model is to use a confusion matrix to summarize the actual vs. predicted labels where the X axis is the predicted label and the Y axis is the actual label.
A function to plot the confusion matrix:
def plot_cm(labels, predictions, p=0.5):
    cm = confusion_matrix(labels, predictions > p)
    plt.figure(figsize=(5, 5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion matrix @{:.2f}'.format(p))
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')
    print('No-ships Detected (True Negatives): ', cm[0][0])
    print('No-ships Incorrectly Detected (False Positives): ', cm[0][1])
    print('Ships Missed (False Negatives): ', cm[1][0])
    print('Ships Detected (True Positives): ', cm[1][1])
    print('Total Ships: ', np.sum(cm[1]))
    return cm
vgg_test_predictions_baseline = vgg_model.predict(vgg_test, batch_size=BATCH_SIZE)
vgg_baseline_eval = vgg_model.evaluate(vgg_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(vgg_model.metrics_names, vgg_baseline_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, vgg_test_predictions_baseline)
loss :  0.048311117980629203
tp :  89.0
fp :  4.0
tn :  304.0
fn :  3.0
accuracy :  0.9825
precision :  0.9569892
recall :  0.9673913
auc :  0.9934359

No-ships Detected (True Negatives):  304
No-ships Incorrectly Detected (False Positives):  4
Ships Missed (False Negatives):  3
Ships Detected (True Positives):  89
Total Ships:  92
The ROC curve is useful because it shows how the classifier can be tuned by adjusting the prediction threshold. I will only plot it for test data: with a total of 6 models, the chart will get cluttered enough as it is.
def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)
    plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
    plt.xlabel('False positives [%]')
    plt.ylabel('True positives [%]')
    plt.xlim([-0.5, 12.5])
    plt.ylim([87, 100.5])
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_roc("VGG baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plt.legend(loc='lower right')
# load same initial weights as were used for VGG
incep_model = make_model()
incep_model.load_weights(initial_weights)
# train my classifier using Inception features
baseline_incep_history = incep_model.fit(
incep_train,
labels_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(incep_dev, labels_dev))
Train on 2800 samples, validate on 800 samples Epoch 1/100 2800/2800 [==============================] - 4s 1ms/sample - loss: 0.1779 - tp: 598.0000 - fp: 96.0000 - tn: 1988.0000 - fn: 118.0000 - accuracy: 0.9236 - precision: 0.8617 - recall: 0.8352 - auc: 0.9749 - val_loss: 0.1096 - val_tp: 186.0000 - val_fp: 27.0000 - val_tn: 581.0000 - val_fn: 6.0000 - val_accuracy: 0.9588 - val_precision: 0.8732 - val_recall: 0.9688 - val_auc: 0.9919 Epoch 2/100 2800/2800 [==============================] - 2s 731us/sample - loss: 0.0860 - tp: 678.0000 - fp: 37.0000 - tn: 2047.0000 - fn: 38.0000 - accuracy: 0.9732 - precision: 0.9483 - recall: 0.9469 - auc: 0.9930 - val_loss: 0.0830 - val_tp: 187.0000 - val_fp: 17.0000 - val_tn: 591.0000 - val_fn: 5.0000 - val_accuracy: 0.9725 - val_precision: 0.9167 - val_recall: 0.9740 - val_auc: 0.9941 Epoch 3/100 2800/2800 [==============================] - 2s 724us/sample - loss: 0.0549 - tp: 690.0000 - fp: 27.0000 - tn: 2057.0000 - fn: 26.0000 - accuracy: 0.9811 - precision: 0.9623 - recall: 0.9637 - auc: 0.9976 - val_loss: 0.0741 - val_tp: 185.0000 - val_fp: 16.0000 - val_tn: 592.0000 - val_fn: 7.0000 - val_accuracy: 0.9712 - val_precision: 0.9204 - val_recall: 0.9635 - val_auc: 0.9952 Epoch 4/100 2800/2800 [==============================] - 2s 724us/sample - loss: 0.0375 - tp: 700.0000 - fp: 22.0000 - tn: 2062.0000 - fn: 16.0000 - accuracy: 0.9864 - precision: 0.9695 - recall: 0.9777 - auc: 0.9989 - val_loss: 0.0740 - val_tp: 178.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 14.0000 - val_accuracy: 0.9688 - val_precision: 0.9418 - val_recall: 0.9271 - val_auc: 0.9958 Epoch 5/100 2800/2800 [==============================] - 2s 723us/sample - loss: 0.0246 - tp: 706.0000 - fp: 14.0000 - tn: 2070.0000 - fn: 10.0000 - accuracy: 0.9914 - precision: 0.9806 - recall: 0.9860 - auc: 0.9996 - val_loss: 0.0799 - val_tp: 183.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 9.0000 - val_accuracy: 0.9750 - val_precision: 0.9433 - val_recall: 0.9531 - val_auc: 0.9946 Epoch 6/100 2800/2800 [==============================] - 2s 723us/sample - loss: 0.0177 - tp: 707.0000 - fp: 7.0000 - tn: 2077.0000 - fn: 9.0000 - accuracy: 0.9943 - precision: 0.9902 - recall: 0.9874 - auc: 0.9999 - val_loss: 0.0848 - val_tp: 186.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 6.0000 - val_accuracy: 0.9750 - val_precision: 0.9300 - val_recall: 0.9688 - val_auc: 0.9943 Epoch 7/100 2800/2800 [==============================] - 2s 729us/sample - loss: 0.0130 - tp: 713.0000 - fp: 7.0000 - tn: 2077.0000 - fn: 3.0000 - accuracy: 0.9964 - precision: 0.9903 - recall: 0.9958 - auc: 0.9999 - val_loss: 0.0831 - val_tp: 182.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 10.0000 - val_accuracy: 0.9737 - val_precision: 0.9430 - val_recall: 0.9479 - val_auc: 0.9949 Epoch 8/100 2800/2800 [==============================] - 2s 725us/sample - loss: 0.0110 - tp: 714.0000 - fp: 8.0000 - tn: 2076.0000 - fn: 2.0000 - accuracy: 0.9964 - precision: 0.9889 - recall: 0.9972 - auc: 0.9999 - val_loss: 0.0856 - val_tp: 182.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 10.0000 - val_accuracy: 0.9737 - val_precision: 0.9430 - val_recall: 0.9479 - val_auc: 0.9924 Epoch 9/100 2800/2800 [==============================] - 2s 731us/sample - loss: 0.0071 - tp: 713.0000 - fp: 2.0000 - tn: 2082.0000 - fn: 3.0000 - accuracy: 0.9982 - precision: 0.9972 - recall: 0.9958 - auc: 1.0000 - val_loss: 0.0896 - val_tp: 186.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 
0.9442 - val_recall: 0.9688 - val_auc: 0.9945 Epoch 10/100 2800/2800 [==============================] - 2s 727us/sample - loss: 0.0050 - tp: 716.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 0.0000e+00 - accuracy: 0.9989 - precision: 0.9958 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.1001 - val_tp: 184.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 8.0000 - val_accuracy: 0.9725 - val_precision: 0.9293 - val_recall: 0.9583 - val_auc: 0.9922 Epoch 11/100 2800/2800 [==============================] - 2s 729us/sample - loss: 0.0043 - tp: 714.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 2.0000 - accuracy: 0.9989 - precision: 0.9986 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.1022 - val_tp: 181.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 11.0000 - val_accuracy: 0.9725 - val_precision: 0.9427 - val_recall: 0.9427 - val_auc: 0.9894 Epoch 12/100 2800/2800 [==============================] - 2s 728us/sample - loss: 0.0030 - tp: 715.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 1.0000 - accuracy: 0.9993 - precision: 0.9986 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.1125 - val_tp: 187.0000 - val_fp: 12.0000 - val_tn: 596.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.9397 - val_recall: 0.9740 - val_auc: 0.9917 Epoch 13/100 2800/2800 [==============================] - 2s 739us/sample - loss: 0.0012 - tp: 716.0000 - fp: 0.0000e+00 - tn: 2084.0000 - fn: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.1151 - val_tp: 185.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 7.0000 - val_accuracy: 0.9737 - val_precision: 0.9296 - val_recall: 0.9635 - val_auc: 0.9917 Epoch 14/100 2784/2800 [============================>.] - ETA: 0s - loss: 0.0019 - tp: 710.0000 - fp: 0.0000e+00 - tn: 2073.0000 - fn: 1.0000 - accuracy: 0.9996 - precision: 1.0000 - recall: 0.9986 - auc: 1.0000Restoring model weights from the end of the best epoch. 2800/2800 [==============================] - 3s 1ms/sample - loss: 0.0019 - tp: 715.0000 - fp: 0.0000e+00 - tn: 2084.0000 - fn: 1.0000 - accuracy: 0.9996 - precision: 1.0000 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.1182 - val_tp: 181.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 11.0000 - val_accuracy: 0.9725 - val_precision: 0.9427 - val_recall: 0.9427 - val_auc: 0.9899 Epoch 00014: early stopping
plot_metrics(baseline_incep_history)
We can see a severe case of overfitting: training loss keeps going down while dev loss goes up. Normally I would attack this with the usual remedies (stronger regularization such as more dropout, a smaller classifier, or more training data).
However, in this case I have another network that is learning well enough using VGG features as input, so I will keep that as the main candidate and just keep an eye on Inception out of curiosity.
incep_test_predictions_baseline = incep_model.predict(incep_test, batch_size=BATCH_SIZE)
incep_baseline_eval = incep_model.evaluate(incep_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(incep_model.metrics_names, incep_baseline_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, incep_test_predictions_baseline)
loss :  0.09401846636086703
tp :  86.0
fp :  5.0
tn :  303.0
fn :  6.0
accuracy :  0.9725
precision :  0.94505495
recall :  0.9347826
auc :  0.99347126

No-ships Detected (True Negatives):  303
No-ships Incorrectly Detected (False Positives):  5
Ships Missed (False Negatives):  6
Ships Detected (True Positives):  86
Total Ships:  92
As could be expected from the training plots, the performance is not as good as the VGG option's, but still pretty decent. Maybe this classification task is just relatively easy in itself.
plot_roc("VGG Baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plot_roc("Inception Baseline", labels_test, incep_test_predictions_baseline, color=colors[1])
plt.legend(loc='lower right')
One way to improve a model's performance on imbalanced data is to assign a higher weight to errors on the minority class during training, making the model "pay more attention" to that class.
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / df.loc['no-ship'].Count)*(sum(df.Count))/2.0
weight_for_1 = (1 / df.loc['ship'].Count)*(sum(df.Count))/2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.67
Weight for class 1: 2.00
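Sanity-checking the arithmetic: weight_for_0 = (1/3000) * 4000/2 ≈ 0.67 and weight_for_1 = (1/1000) * 4000/2 = 2.00. A single ship example now contributes three times the loss of a no-ship example, so in aggregate the two classes pull on the loss equally (3000 * 0.67 ≈ 1000 * 2.00 = 2000).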
Keras accepts a class_weight argument when training a model.
weighted_vgg_model = make_model()
weighted_vgg_model.load_weights(initial_weights)
weighted_vgg_history = weighted_vgg_model.fit(
vgg_train,
labels_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(vgg_dev, labels_dev),
# The class weights go here
class_weight=class_weight)
Train on 2800 samples, validate on 800 samples Epoch 1/100 2800/2800 [==============================] - 4s 1ms/sample - loss: 0.1580 - tp: 672.0000 - fp: 128.0000 - tn: 1956.0000 - fn: 44.0000 - accuracy: 0.9386 - precision: 0.8400 - recall: 0.9385 - auc: 0.9863 - val_loss: 0.0795 - val_tp: 182.0000 - val_fp: 18.0000 - val_tn: 590.0000 - val_fn: 10.0000 - val_accuracy: 0.9650 - val_precision: 0.9100 - val_recall: 0.9479 - val_auc: 0.9961 Epoch 2/100 2800/2800 [==============================] - 2s 731us/sample - loss: 0.0610 - tp: 703.0000 - fp: 45.0000 - tn: 2039.0000 - fn: 13.0000 - accuracy: 0.9793 - precision: 0.9398 - recall: 0.9818 - auc: 0.9977 - val_loss: 0.0786 - val_tp: 190.0000 - val_fp: 24.0000 - val_tn: 584.0000 - val_fn: 2.0000 - val_accuracy: 0.9675 - val_precision: 0.8879 - val_recall: 0.9896 - val_auc: 0.9981 Epoch 3/100 2800/2800 [==============================] - 2s 741us/sample - loss: 0.0423 - tp: 706.0000 - fp: 28.0000 - tn: 2056.0000 - fn: 10.0000 - accuracy: 0.9864 - precision: 0.9619 - recall: 0.9860 - auc: 0.9989 - val_loss: 0.0433 - val_tp: 189.0000 - val_fp: 13.0000 - val_tn: 595.0000 - val_fn: 3.0000 - val_accuracy: 0.9800 - val_precision: 0.9356 - val_recall: 0.9844 - val_auc: 0.9988 Epoch 4/100 2800/2800 [==============================] - 2s 730us/sample - loss: 0.0335 - tp: 706.0000 - fp: 28.0000 - tn: 2056.0000 - fn: 10.0000 - accuracy: 0.9864 - precision: 0.9619 - recall: 0.9860 - auc: 0.9994 - val_loss: 0.0506 - val_tp: 189.0000 - val_fp: 17.0000 - val_tn: 591.0000 - val_fn: 3.0000 - val_accuracy: 0.9750 - val_precision: 0.9175 - val_recall: 0.9844 - val_auc: 0.9988 Epoch 5/100 2800/2800 [==============================] - 2s 731us/sample - loss: 0.0264 - tp: 710.0000 - fp: 21.0000 - tn: 2063.0000 - fn: 6.0000 - accuracy: 0.9904 - precision: 0.9713 - recall: 0.9916 - auc: 0.9994 - val_loss: 0.0342 - val_tp: 187.0000 - val_fp: 5.0000 - val_tn: 603.0000 - val_fn: 5.0000 - val_accuracy: 0.9875 - val_precision: 0.9740 - val_recall: 0.9740 - val_auc: 0.9988 Epoch 6/100 2800/2800 [==============================] - 2s 733us/sample - loss: 0.0228 - tp: 709.0000 - fp: 18.0000 - tn: 2066.0000 - fn: 7.0000 - accuracy: 0.9911 - precision: 0.9752 - recall: 0.9902 - auc: 0.9997 - val_loss: 0.0518 - val_tp: 190.0000 - val_fp: 17.0000 - val_tn: 591.0000 - val_fn: 2.0000 - val_accuracy: 0.9762 - val_precision: 0.9179 - val_recall: 0.9896 - val_auc: 0.9991 Epoch 7/100 2800/2800 [==============================] - 2s 740us/sample - loss: 0.0181 - tp: 712.0000 - fp: 15.0000 - tn: 2069.0000 - fn: 4.0000 - accuracy: 0.9932 - precision: 0.9794 - recall: 0.9944 - auc: 0.9998 - val_loss: 0.0290 - val_tp: 189.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 3.0000 - val_accuracy: 0.9925 - val_precision: 0.9844 - val_recall: 0.9844 - val_auc: 0.9991 Epoch 8/100 2800/2800 [==============================] - 2s 739us/sample - loss: 0.0147 - tp: 714.0000 - fp: 12.0000 - tn: 2072.0000 - fn: 2.0000 - accuracy: 0.9950 - precision: 0.9835 - recall: 0.9972 - auc: 0.9999 - val_loss: 0.0256 - val_tp: 188.0000 - val_fp: 1.0000 - val_tn: 607.0000 - val_fn: 4.0000 - val_accuracy: 0.9937 - val_precision: 0.9947 - val_recall: 0.9792 - val_auc: 0.9969 Epoch 9/100 2800/2800 [==============================] - 2s 735us/sample - loss: 0.0143 - tp: 712.0000 - fp: 10.0000 - tn: 2074.0000 - fn: 4.0000 - accuracy: 0.9950 - precision: 0.9861 - recall: 0.9944 - auc: 0.9999 - val_loss: 0.0835 - val_tp: 191.0000 - val_fp: 24.0000 - val_tn: 584.0000 - val_fn: 1.0000 - val_accuracy: 0.9688 - val_precision: 0.8884 
- val_recall: 0.9948 - val_auc: 0.9992 Epoch 10/100 2800/2800 [==============================] - 2s 734us/sample - loss: 0.0107 - tp: 715.0000 - fp: 7.0000 - tn: 2077.0000 - fn: 1.0000 - accuracy: 0.9971 - precision: 0.9903 - recall: 0.9986 - auc: 0.9997 - val_loss: 0.0278 - val_tp: 187.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 5.0000 - val_accuracy: 0.9912 - val_precision: 0.9894 - val_recall: 0.9740 - val_auc: 0.9970 Epoch 11/100 2800/2800 [==============================] - 2s 732us/sample - loss: 0.0098 - tp: 713.0000 - fp: 8.0000 - tn: 2076.0000 - fn: 3.0000 - accuracy: 0.9961 - precision: 0.9889 - recall: 0.9958 - auc: 1.0000 - val_loss: 0.0374 - val_tp: 182.0000 - val_fp: 1.0000 - val_tn: 607.0000 - val_fn: 10.0000 - val_accuracy: 0.9862 - val_precision: 0.9945 - val_recall: 0.9479 - val_auc: 0.9970 Epoch 12/100 2800/2800 [==============================] - 2s 738us/sample - loss: 0.0086 - tp: 715.0000 - fp: 8.0000 - tn: 2076.0000 - fn: 1.0000 - accuracy: 0.9968 - precision: 0.9889 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0261 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 0.9844 - val_auc: 0.9970 Epoch 13/100 2800/2800 [==============================] - 2s 739us/sample - loss: 0.0078 - tp: 714.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 2.0000 - accuracy: 0.9979 - precision: 0.9944 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.0379 - val_tp: 190.0000 - val_fp: 9.0000 - val_tn: 599.0000 - val_fn: 2.0000 - val_accuracy: 0.9862 - val_precision: 0.9548 - val_recall: 0.9896 - val_auc: 0.9969 Epoch 14/100 2800/2800 [==============================] - 2s 737us/sample - loss: 0.0071 - tp: 714.0000 - fp: 7.0000 - tn: 2077.0000 - fn: 2.0000 - accuracy: 0.9968 - precision: 0.9903 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.0311 - val_tp: 190.0000 - val_fp: 8.0000 - val_tn: 600.0000 - val_fn: 2.0000 - val_accuracy: 0.9875 - val_precision: 0.9596 - val_recall: 0.9896 - val_auc: 0.9969 Epoch 15/100 2800/2800 [==============================] - 2s 746us/sample - loss: 0.0042 - tp: 716.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 0.0000e+00 - accuracy: 0.9989 - precision: 0.9958 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.0288 - val_tp: 190.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 2.0000 - val_accuracy: 0.9937 - val_precision: 0.9845 - val_recall: 0.9896 - val_auc: 0.9970 Epoch 16/100 2800/2800 [==============================] - 2s 733us/sample - loss: 0.0042 - tp: 716.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 0.0000e+00 - accuracy: 0.9986 - precision: 0.9944 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.0292 - val_tp: 188.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 4.0000 - val_accuracy: 0.9925 - val_precision: 0.9895 - val_recall: 0.9792 - val_auc: 0.9970 Epoch 17/100 2800/2800 [==============================] - 2s 733us/sample - loss: 0.0056 - tp: 715.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 1.0000 - accuracy: 0.9986 - precision: 0.9958 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0326 - val_tp: 190.0000 - val_fp: 7.0000 - val_tn: 601.0000 - val_fn: 2.0000 - val_accuracy: 0.9887 - val_precision: 0.9645 - val_recall: 0.9896 - val_auc: 0.9969 Epoch 18/100 2800/2800 [==============================] - 2s 736us/sample - loss: 0.0030 - tp: 715.0000 - fp: 1.0000 - tn: 2083.0000 - fn: 1.0000 - accuracy: 0.9993 - precision: 0.9986 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0371 - val_tp: 190.0000 - val_fp: 9.0000 - val_tn: 599.0000 - val_fn: 2.0000 - val_accuracy: 0.9862 - val_precision: 0.9548 - 
val_recall: 0.9896 - val_auc: 0.9970 Epoch 19/100 2720/2800 [============================>.] - ETA: 0s - loss: 0.0039 - tp: 694.0000 - fp: 4.0000 - tn: 2021.0000 - fn: 1.0000 - accuracy: 0.9982 - precision: 0.9943 - recall: 0.9986 - auc: 1.0000Restoring model weights from the end of the best epoch. 2800/2800 [==============================] - 3s 1ms/sample - loss: 0.0038 - tp: 715.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 1.0000 - accuracy: 0.9982 - precision: 0.9944 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0285 - val_tp: 189.0000 - val_fp: 2.0000 - val_tn: 606.0000 - val_fn: 3.0000 - val_accuracy: 0.9937 - val_precision: 0.9895 - val_recall: 0.9844 - val_auc: 0.9970 Epoch 00019: early stopping
plot_metrics(weighted_vgg_history)
The loss plot looks similar to the baseline VGG model's; precision shows a lot of noise.
vgg_test_predictions_weights = weighted_vgg_model.predict(vgg_test, batch_size=BATCH_SIZE)
vgg_weights_eval = weighted_vgg_model.evaluate(vgg_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_vgg_model.metrics_names, vgg_weights_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, vgg_test_predictions_weights)
loss :  0.10322743088006973
tp :  91.0
fp :  12.0
tn :  296.0
fn :  1.0
accuracy :  0.9675
precision :  0.88349515
recall :  0.98913044
auc :  0.99520046

No-ships Detected (True Negatives):  296
No-ships Incorrectly Detected (False Positives):  12
Ships Missed (False Negatives):  1
Ships Detected (True Positives):  91
Total Ships:  92
Compared to the baseline VGG model, this one has noticeably higher recall (1 missed ship instead of 3) at the cost of lower precision (12 false positives instead of 4), which is exactly the trade-off class weighting is expected to produce.
plot_roc("VGG baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plot_roc("VGG with class weights", labels_test, vgg_test_predictions_weights, color=colors[0], linestyle='--')
plot_roc("Inception baseline", labels_test, incep_test_predictions_baseline, color=colors[1])
plt.legend(loc='lower right')
Let's see what happens with the previously overfitting Inception model when the same class weights are applied.
weighted_incep_model = make_model()
weighted_incep_model.load_weights(initial_weights)
weighted_incep_history = weighted_incep_model.fit(
incep_train,
labels_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(incep_dev, labels_dev),
# The class weights go here
class_weight=class_weight)
Train on 2800 samples, validate on 800 samples Epoch 1/100 2800/2800 [==============================] - 4s 1ms/sample - loss: 0.1903 - tp: 675.0000 - fp: 154.0000 - tn: 1930.0000 - fn: 41.0000 - accuracy: 0.9304 - precision: 0.8142 - recall: 0.9427 - auc: 0.9775 - val_loss: 0.1102 - val_tp: 190.0000 - val_fp: 29.0000 - val_tn: 579.0000 - val_fn: 2.0000 - val_accuracy: 0.9613 - val_precision: 0.8676 - val_recall: 0.9896 - val_auc: 0.9938 Epoch 2/100 2800/2800 [==============================] - 2s 738us/sample - loss: 0.0832 - tp: 702.0000 - fp: 72.0000 - tn: 2012.0000 - fn: 14.0000 - accuracy: 0.9693 - precision: 0.9070 - recall: 0.9804 - auc: 0.9949 - val_loss: 0.0798 - val_tp: 187.0000 - val_fp: 18.0000 - val_tn: 590.0000 - val_fn: 5.0000 - val_accuracy: 0.9712 - val_precision: 0.9122 - val_recall: 0.9740 - val_auc: 0.9943 Epoch 3/100 2800/2800 [==============================] - 2s 743us/sample - loss: 0.0579 - tp: 703.0000 - fp: 42.0000 - tn: 2042.0000 - fn: 13.0000 - accuracy: 0.9804 - precision: 0.9436 - recall: 0.9818 - auc: 0.9974 - val_loss: 0.0818 - val_tp: 190.0000 - val_fp: 18.0000 - val_tn: 590.0000 - val_fn: 2.0000 - val_accuracy: 0.9750 - val_precision: 0.9135 - val_recall: 0.9896 - val_auc: 0.9953 Epoch 4/100 2800/2800 [==============================] - 2s 735us/sample - loss: 0.0381 - tp: 712.0000 - fp: 30.0000 - tn: 2054.0000 - fn: 4.0000 - accuracy: 0.9879 - precision: 0.9596 - recall: 0.9944 - auc: 0.9985 - val_loss: 0.0845 - val_tp: 190.0000 - val_fp: 17.0000 - val_tn: 591.0000 - val_fn: 2.0000 - val_accuracy: 0.9762 - val_precision: 0.9179 - val_recall: 0.9896 - val_auc: 0.9951 Epoch 5/100 2800/2800 [==============================] - 2s 729us/sample - loss: 0.0284 - tp: 712.0000 - fp: 25.0000 - tn: 2059.0000 - fn: 4.0000 - accuracy: 0.9896 - precision: 0.9661 - recall: 0.9944 - auc: 0.9992 - val_loss: 0.0833 - val_tp: 187.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 5.0000 - val_accuracy: 0.9762 - val_precision: 0.9303 - val_recall: 0.9740 - val_auc: 0.9951 Epoch 6/100 2800/2800 [==============================] - 2s 734us/sample - loss: 0.0220 - tp: 714.0000 - fp: 21.0000 - tn: 2063.0000 - fn: 2.0000 - accuracy: 0.9918 - precision: 0.9714 - recall: 0.9972 - auc: 0.9995 - val_loss: 0.0919 - val_tp: 189.0000 - val_fp: 16.0000 - val_tn: 592.0000 - val_fn: 3.0000 - val_accuracy: 0.9762 - val_precision: 0.9220 - val_recall: 0.9844 - val_auc: 0.9946 Epoch 7/100 2800/2800 [==============================] - 2s 739us/sample - loss: 0.0167 - tp: 714.0000 - fp: 15.0000 - tn: 2069.0000 - fn: 2.0000 - accuracy: 0.9939 - precision: 0.9794 - recall: 0.9972 - auc: 0.9996 - val_loss: 0.0919 - val_tp: 187.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 5.0000 - val_accuracy: 0.9762 - val_precision: 0.9303 - val_recall: 0.9740 - val_auc: 0.9944 Epoch 8/100 2800/2800 [==============================] - 2s 752us/sample - loss: 0.0108 - tp: 715.0000 - fp: 9.0000 - tn: 2075.0000 - fn: 1.0000 - accuracy: 0.9964 - precision: 0.9876 - recall: 0.9986 - auc: 0.9999 - val_loss: 0.0938 - val_tp: 185.0000 - val_fp: 12.0000 - val_tn: 596.0000 - val_fn: 7.0000 - val_accuracy: 0.9762 - val_precision: 0.9391 - val_recall: 0.9635 - val_auc: 0.9950 Epoch 9/100 2800/2800 [==============================] - 2s 741us/sample - loss: 0.0100 - tp: 715.0000 - fp: 7.0000 - tn: 2077.0000 - fn: 1.0000 - accuracy: 0.9971 - precision: 0.9903 - recall: 0.9986 - auc: 0.9999 - val_loss: 0.0962 - val_tp: 187.0000 - val_fp: 13.0000 - val_tn: 595.0000 - val_fn: 5.0000 - val_accuracy: 0.9775 - val_precision: 0.9350 
- val_recall: 0.9740 - val_auc: 0.9946 Epoch 10/100 2800/2800 [==============================] - 2s 736us/sample - loss: 0.0064 - tp: 715.0000 - fp: 6.0000 - tn: 2078.0000 - fn: 1.0000 - accuracy: 0.9975 - precision: 0.9917 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0948 - val_tp: 185.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 7.0000 - val_accuracy: 0.9775 - val_precision: 0.9439 - val_recall: 0.9635 - val_auc: 0.9922 Epoch 11/100 2800/2800 [==============================] - 2s 740us/sample - loss: 0.0062 - tp: 714.0000 - fp: 6.0000 - tn: 2078.0000 - fn: 2.0000 - accuracy: 0.9971 - precision: 0.9917 - recall: 0.9972 - auc: 1.0000 - val_loss: 0.1036 - val_tp: 186.0000 - val_fp: 15.0000 - val_tn: 593.0000 - val_fn: 6.0000 - val_accuracy: 0.9737 - val_precision: 0.9254 - val_recall: 0.9688 - val_auc: 0.9922 Epoch 12/100 2800/2800 [==============================] - 2s 737us/sample - loss: 0.0036 - tp: 716.0000 - fp: 2.0000 - tn: 2082.0000 - fn: 0.0000e+00 - accuracy: 0.9993 - precision: 0.9972 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.1089 - val_tp: 186.0000 - val_fp: 13.0000 - val_tn: 595.0000 - val_fn: 6.0000 - val_accuracy: 0.9762 - val_precision: 0.9347 - val_recall: 0.9688 - val_auc: 0.9923 Epoch 13/100 2784/2800 [============================>.] - ETA: 0s - loss: 0.0029 - tp: 710.0000 - fp: 3.0000 - tn: 2071.0000 - fn: 0.0000e+00 - accuracy: 0.9989 - precision: 0.9958 - recall: 1.0000 - auc: 1.0000Restoring model weights from the end of the best epoch. 2800/2800 [==============================] - 3s 1ms/sample - loss: 0.0029 - tp: 716.0000 - fp: 3.0000 - tn: 2081.0000 - fn: 0.0000e+00 - accuracy: 0.9989 - precision: 0.9958 - recall: 1.0000 - auc: 1.0000 - val_loss: 0.1110 - val_tp: 187.0000 - val_fp: 12.0000 - val_tn: 596.0000 - val_fn: 5.0000 - val_accuracy: 0.9787 - val_precision: 0.9397 - val_recall: 0.9740 - val_auc: 0.9923 Epoch 00013: early stopping
plot_metrics(weighted_incep_history)
A similar overfitting picture. The Inception-based model is unlikely to make it to production.
incep_test_predictions_weights = weighted_incep_model.predict(incep_test, batch_size=BATCH_SIZE)
incep_weights_eval = weighted_incep_model.evaluate(incep_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_incep_model.metrics_names, incep_weights_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, incep_test_predictions_weights)
loss :  0.11990830048918724
tp :  87.0
fp :  8.0
tn :  300.0
fn :  5.0
accuracy :  0.9675
precision :  0.9157895
recall :  0.9456522
auc :  0.9918126

No-ships Detected (True Negatives):  300
No-ships Incorrectly Detected (False Positives):  8
Ships Missed (False Negatives):  5
Ships Detected (True Positives):  87
Total Ships:  92
plot_roc("VGG baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plot_roc("VGG with class weights", labels_test, vgg_test_predictions_weights, color=colors[0], linestyle='--')
plot_roc("Inception baseline", labels_test, incep_test_predictions_baseline, color=colors[1])
plot_roc("Inception with class weights", labels_test, incep_test_predictions_weights, color=colors[1], linestyle='--')
plt.legend(loc='lower right')
Another way to deal with imbalanced classes is to use data augmentation to create more examples of the minority class.
# this function creates two new augmented copies of every ship image it receives
def augment_add(images, seq, labels):
    augmented_images, augmented_labels = [], []
    for idx, img in tqdm(enumerate(images)):
        if labels[idx] == 1:  # only the minority (ship) class is augmented
            image_aug_1 = seq.augment_image(image=img)
            image_aug_2 = seq.augment_image(image=img)
            augmented_images.append(image_aug_1)
            augmented_images.append(image_aug_2)
            augmented_labels.append(labels[idx])
            augmented_labels.append(labels[idx])
    augmented_images = np.array(augmented_images, dtype=np.float32)
    augmented_labels = np.array(augmented_labels, dtype=np.float32)
    return (augmented_images, augmented_labels)
Several augmentations will be applied in random order. Some other transforms, like cropping or rotation, would not work well here: the ship usually occupies almost the entire image, so we would risk either cutting away part of it or creating significant black corners when rotating the square image.
seq = iaa.Sequential([
    iaa.Fliplr(1),  # always flip left-to-right
    iaa.Flipud(1),  # always flip upside down (fine for satellite images)
    iaa.LinearContrast((0.75, 1.5)),  # random contrast change
    iaa.Multiply((0.8, 1.2), per_channel=0.2),  # random brightness change
], random_order=True)
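A quick way to eyeball what the pipeline does (my own check on an arbitrary training image; the brightness multiply can push pixel values slightly outside [0, 1], hence the clip before plotting):

aug_example = seq.augment_image(image=images_train[0])
fig = plt.figure(figsize=(8, 4))
fig.add_subplot(1, 2, 1); plt.imshow(images_train[0]); plt.title('original')
fig.add_subplot(1, 2, 2); plt.imshow(np.clip(aug_example, 0, 1)); plt.title('augmented')
plt.show()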
Augmentation is applied to the training set only, so the dev and test sets stay the same across all models, allowing an apples-to-apples comparison.
np.random.seed(41) # augmentation is a random process so setting the seed for reproducible results
(aug_images, aug_labels) = augment_add(images_train, seq, labels_train)
aug_images = np.concatenate([images_train, aug_images])
aug_labels = np.concatenate([labels_train, aug_labels])
2800it [00:02, 1175.65it/s]
images_train.shape, labels_train.shape, aug_images.shape, aug_labels.shape
((2800, 80, 80, 3), (2800,), (4232, 80, 80, 3), (4232,))
_, count = np.unique(aug_labels, return_counts=True)
mpl.rcParams['figure.figsize'] = (5, 5)
plt.pie(count,
explode=(0,0),
labels=class_names,
autopct="%1.2f%%")
plt.axis('equal')
plt.title("After augmentation");
An (almost) balanced dataset (the original images_train was not exactly a 3:1 ratio due to the random nature of the train/dev/test split).
I now have more training examples, for which the features have to be calculated and flattened again.
vgg_train_aug = vgg_conv.predict(aug_images)
vgg_train_aug.shape
(4232, 2, 2, 512)
vgg_train_aug = flatten(vgg_train_aug)
vgg_train_aug.shape
(4232, 2048)
aug_vgg_model = make_model()
aug_vgg_model.load_weights(initial_weights)
aug_vgg_history = aug_vgg_model.fit(
vgg_train_aug,
aug_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(vgg_dev, labels_dev))
Train on 4232 samples, validate on 800 samples Epoch 1/100 4232/4232 [==============================] - 5s 1ms/sample - loss: 0.1414 - tp: 2042.0000 - fp: 131.0000 - tn: 1953.0000 - fn: 106.0000 - accuracy: 0.9440 - precision: 0.9397 - recall: 0.9507 - auc: 0.9886 - val_loss: 0.0660 - val_tp: 184.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 8.0000 - val_accuracy: 0.9725 - val_precision: 0.9293 - val_recall: 0.9583 - val_auc: 0.9970 Epoch 2/100 4232/4232 [==============================] - 3s 722us/sample - loss: 0.0571 - tp: 2111.0000 - fp: 49.0000 - tn: 2035.0000 - fn: 37.0000 - accuracy: 0.9797 - precision: 0.9773 - recall: 0.9828 - auc: 0.9978 - val_loss: 0.0479 - val_tp: 190.0000 - val_fp: 14.0000 - val_tn: 594.0000 - val_fn: 2.0000 - val_accuracy: 0.9800 - val_precision: 0.9314 - val_recall: 0.9896 - val_auc: 0.9987 Epoch 3/100 4232/4232 [==============================] - 3s 714us/sample - loss: 0.0354 - tp: 2120.0000 - fp: 30.0000 - tn: 2054.0000 - fn: 28.0000 - accuracy: 0.9863 - precision: 0.9860 - recall: 0.9870 - auc: 0.9993 - val_loss: 0.0622 - val_tp: 191.0000 - val_fp: 19.0000 - val_tn: 589.0000 - val_fn: 1.0000 - val_accuracy: 0.9750 - val_precision: 0.9095 - val_recall: 0.9948 - val_auc: 0.9992 Epoch 4/100 4232/4232 [==============================] - 3s 715us/sample - loss: 0.0254 - tp: 2128.0000 - fp: 17.0000 - tn: 2067.0000 - fn: 20.0000 - accuracy: 0.9913 - precision: 0.9921 - recall: 0.9907 - auc: 0.9996 - val_loss: 0.0356 - val_tp: 191.0000 - val_fp: 12.0000 - val_tn: 596.0000 - val_fn: 1.0000 - val_accuracy: 0.9837 - val_precision: 0.9409 - val_recall: 0.9948 - val_auc: 0.9992 Epoch 5/100 4232/4232 [==============================] - 3s 717us/sample - loss: 0.0222 - tp: 2134.0000 - fp: 16.0000 - tn: 2068.0000 - fn: 14.0000 - accuracy: 0.9929 - precision: 0.9926 - recall: 0.9935 - auc: 0.9995 - val_loss: 0.0338 - val_tp: 191.0000 - val_fp: 11.0000 - val_tn: 597.0000 - val_fn: 1.0000 - val_accuracy: 0.9850 - val_precision: 0.9455 - val_recall: 0.9948 - val_auc: 0.9993 Epoch 6/100 4232/4232 [==============================] - 3s 716us/sample - loss: 0.0183 - tp: 2133.0000 - fp: 11.0000 - tn: 2073.0000 - fn: 15.0000 - accuracy: 0.9939 - precision: 0.9949 - recall: 0.9930 - auc: 0.9998 - val_loss: 0.0271 - val_tp: 191.0000 - val_fp: 7.0000 - val_tn: 601.0000 - val_fn: 1.0000 - val_accuracy: 0.9900 - val_precision: 0.9646 - val_recall: 0.9948 - val_auc: 0.9994 Epoch 7/100 4232/4232 [==============================] - 3s 719us/sample - loss: 0.0160 - tp: 2133.0000 - fp: 11.0000 - tn: 2073.0000 - fn: 15.0000 - accuracy: 0.9939 - precision: 0.9949 - recall: 0.9930 - auc: 0.9998 - val_loss: 0.0280 - val_tp: 190.0000 - val_fp: 6.0000 - val_tn: 602.0000 - val_fn: 2.0000 - val_accuracy: 0.9900 - val_precision: 0.9694 - val_recall: 0.9896 - val_auc: 0.9969 Epoch 8/100 4232/4232 [==============================] - 3s 723us/sample - loss: 0.0115 - tp: 2141.0000 - fp: 9.0000 - tn: 2075.0000 - fn: 7.0000 - accuracy: 0.9962 - precision: 0.9958 - recall: 0.9967 - auc: 0.9999 - val_loss: 0.0357 - val_tp: 191.0000 - val_fp: 8.0000 - val_tn: 600.0000 - val_fn: 1.0000 - val_accuracy: 0.9887 - val_precision: 0.9598 - val_recall: 0.9948 - val_auc: 0.9994 Epoch 9/100 4232/4232 [==============================] - 3s 721us/sample - loss: 0.0102 - tp: 2142.0000 - fp: 6.0000 - tn: 2078.0000 - fn: 6.0000 - accuracy: 0.9972 - precision: 0.9972 - recall: 0.9972 - auc: 0.9999 - val_loss: 0.0243 - val_tp: 190.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 2.0000 - val_accuracy: 0.9925 - 
val_precision: 0.9794 - val_recall: 0.9896 - val_auc: 0.9970 Epoch 10/100 4232/4232 [==============================] - 3s 725us/sample - loss: 0.0107 - tp: 2142.0000 - fp: 8.0000 - tn: 2076.0000 - fn: 6.0000 - accuracy: 0.9967 - precision: 0.9963 - recall: 0.9972 - auc: 0.9997 - val_loss: 0.0475 - val_tp: 191.0000 - val_fp: 13.0000 - val_tn: 595.0000 - val_fn: 1.0000 - val_accuracy: 0.9825 - val_precision: 0.9363 - val_recall: 0.9948 - val_auc: 0.9987 Epoch 11/100 4232/4232 [==============================] - 3s 718us/sample - loss: 0.0078 - tp: 2145.0000 - fp: 8.0000 - tn: 2076.0000 - fn: 3.0000 - accuracy: 0.9974 - precision: 0.9963 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0301 - val_tp: 187.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 5.0000 - val_accuracy: 0.9900 - val_precision: 0.9842 - val_recall: 0.9740 - val_auc: 0.9969 Epoch 12/100 4232/4232 [==============================] - 3s 700us/sample - loss: 0.0091 - tp: 2143.0000 - fp: 6.0000 - tn: 2078.0000 - fn: 5.0000 - accuracy: 0.9974 - precision: 0.9972 - recall: 0.9977 - auc: 0.9999 - val_loss: 0.0315 - val_tp: 191.0000 - val_fp: 7.0000 - val_tn: 601.0000 - val_fn: 1.0000 - val_accuracy: 0.9900 - val_precision: 0.9646 - val_recall: 0.9948 - val_auc: 0.9970 Epoch 13/100 4232/4232 [==============================] - 3s 699us/sample - loss: 0.0065 - tp: 2141.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 7.0000 - accuracy: 0.9974 - precision: 0.9981 - recall: 0.9967 - auc: 1.0000 - val_loss: 0.0301 - val_tp: 191.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 1.0000 - val_accuracy: 0.9937 - val_precision: 0.9795 - val_recall: 0.9948 - val_auc: 0.9970 Epoch 14/100 4232/4232 [==============================] - 3s 718us/sample - loss: 0.0064 - tp: 2146.0000 - fp: 6.0000 - tn: 2078.0000 - fn: 2.0000 - accuracy: 0.9981 - precision: 0.9972 - recall: 0.9991 - auc: 1.0000 - val_loss: 0.0247 - val_tp: 190.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 2.0000 - val_accuracy: 0.9937 - val_precision: 0.9845 - val_recall: 0.9896 - val_auc: 0.9971 Epoch 15/100 4232/4232 [==============================] - 3s 730us/sample - loss: 0.0058 - tp: 2145.0000 - fp: 4.0000 - tn: 2080.0000 - fn: 3.0000 - accuracy: 0.9983 - precision: 0.9981 - recall: 0.9986 - auc: 1.0000 - val_loss: 0.0255 - val_tp: 190.0000 - val_fp: 3.0000 - val_tn: 605.0000 - val_fn: 2.0000 - val_accuracy: 0.9937 - val_precision: 0.9845 - val_recall: 0.9896 - val_auc: 0.9971 Epoch 16/100 4160/4232 [============================>.] - ETA: 0s - loss: 0.0064 - tp: 2105.0000 - fp: 5.0000 - tn: 2045.0000 - fn: 5.0000 - accuracy: 0.9976 - precision: 0.9976 - recall: 0.9976 - auc: 1.0000Restoring model weights from the end of the best epoch. 4232/4232 [==============================] - 4s 936us/sample - loss: 0.0063 - tp: 2143.0000 - fp: 5.0000 - tn: 2079.0000 - fn: 5.0000 - accuracy: 0.9976 - precision: 0.9977 - recall: 0.9977 - auc: 1.0000 - val_loss: 0.0281 - val_tp: 191.0000 - val_fp: 4.0000 - val_tn: 604.0000 - val_fn: 1.0000 - val_accuracy: 0.9937 - val_precision: 0.9795 - val_recall: 0.9948 - val_auc: 0.9971 Epoch 00016: early stopping
mpl.rcParams['figure.figsize'] = (12, 10)
plot_metrics(aug_vgg_history)
vgg_test_predictions_aug = aug_vgg_model.predict(vgg_test, batch_size=BATCH_SIZE)
vgg_aug_eval = aug_vgg_model.evaluate(vgg_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(aug_vgg_model.metrics_names, vgg_aug_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, vgg_test_predictions_aug)
loss : 0.03780091498978436
tp : 90.0
fp : 4.0
tn : 304.0
fn : 2.0
accuracy : 0.985
precision : 0.9574468
recall : 0.9782609
auc : 0.9936829

No-ships Detected (True Negatives): 304
No-ships Incorrectly Detected (False Positives): 4
Ships Missed (False Negatives): 2
Ships Detected (True Positives): 90
Total Ships: 92
Compared with the defending champion (the baseline VGG model), this one shows a tiny improvement: one fewer false negative (hence slightly higher recall and precision) and a marginally better AUC. Such is the world of modern computer vision: fighting for improvements somewhere in the fourth digit after the decimal point.
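To make that fourth-digit fight concrete, here is the arithmetic behind the comparison, using the confusion-matrix counts printed by the two evaluations (a quick sanity check, not new evaluation code):

# Confusion-matrix counts copied from the baseline and augmented evaluations above
baseline = dict(tp=89, fp=4, tn=304, fn=3)
augmented = dict(tp=90, fp=4, tn=304, fn=2)
for name, m in [("baseline", baseline), ("augmented", augmented)]:
    precision = m['tp'] / (m['tp'] + m['fp'])
    recall = m['tp'] / (m['tp'] + m['fn'])
    print("%s: precision=%.4f, recall=%.4f" % (name, precision, recall))
# baseline: precision=0.9570, recall=0.9674
# augmented: precision=0.9574, recall=0.9783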
plot_roc("VGG baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plot_roc("VGG with class weights", labels_test, vgg_test_predictions_weights, color=colors[0], linestyle='--')
plot_roc("VGG with augmentation", labels_test, vgg_test_predictions_aug, color=colors[0], linestyle=':')
plot_roc("Inception baseline", labels_test, incep_test_predictions_baseline, color=colors[1])
plot_roc("Inception with class weights", labels_test, incep_test_predictions_weights, color=colors[1], linestyle='--')
plt.legend(loc='lower right')
Last chance for the Inception-based model. Maybe with more training data there will be less overfitting?
incep_train_aug = incep_conv.predict(aug_images)
incep_train_aug.shape
(4232, 1, 1, 2048)
incep_train_aug = flatten(incep_train_aug)
incep_train_aug.shape
(4232, 2048)
aug_incep_model = make_model()
aug_incep_model.load_weights(initial_weights)
aug_incep_history = aug_incep_model.fit(
incep_train_aug,
aug_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(incep_dev, labels_dev))
Train on 4232 samples, validate on 800 samples
Epoch 1/100 - loss: 0.1655 - accuracy: 0.9412 - auc: 0.9817 - val_loss: 0.1020 - val_accuracy: 0.9663 - val_auc: 0.9943
...
Epoch 13/100 - loss: 0.0017 - accuracy: 0.9998 - auc: 1.0000 - val_loss: 0.1327 - val_accuracy: 0.9712 - val_auc: 0.9905
Restoring model weights from the end of the best epoch.
Epoch 00013: early stopping
plot_metrics(aug_incep_history)
Wow, the more data it gets, the worse it performs on the dev set: training accuracy climbs to 99.98% while the validation loss keeps growing. Too bad for you, Inception!
incep_test_predictions_aug = aug_incep_model.predict(incep_test, batch_size=BATCH_SIZE)
incep_aug_eval = aug_incep_model.evaluate(incep_test, labels_test,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(aug_incep_model.metrics_names, incep_aug_eval):
print(name, ': ', value)
print()
_ = plot_cm(labels_test, incep_test_predictions_aug)
loss : 0.11528322748839855
tp : 88.0
fp : 7.0
tn : 301.0
fn : 4.0
accuracy : 0.9725
precision : 0.9263158
recall : 0.95652175
auc : 0.99107134

No-ships Detected (True Negatives): 301
No-ships Incorrectly Detected (False Positives): 7
Ships Missed (False Negatives): 4
Ships Detected (True Positives): 88
Total Ships: 92
plot_roc("VGG baseline", labels_test, vgg_test_predictions_baseline, color=colors[0])
plot_roc("VGG with class weights", labels_test, vgg_test_predictions_weights, color=colors[0], linestyle='--')
plot_roc("VGG with augmentation", labels_test, vgg_test_predictions_aug, color=colors[0], linestyle=':')
plot_roc("Inception baseline", labels_test, incep_test_predictions_baseline, color=colors[1])
plot_roc("Inception with classe weights", labels_test, incep_test_predictions_weights, color=colors[1], linestyle='--')
plot_roc("Inception with augmentation", labels_test, incep_test_predictions_aug, color=colors[1], linestyle=':')
plt.legend(loc='lower right')
In the previous section, I trained a total of 6 classifiers = 2 underlying embeddings (VGG16 and Inception) x 3 options for each embedding (baseline, class weights and data augmentation). Time to compare all of them together.
# build a dataframe with all models' metrics
chart = pd.DataFrame(columns = model.metrics_names)
for m in [vgg_baseline_eval, incep_baseline_eval, vgg_weights_eval, incep_weights_eval, vgg_aug_eval, incep_aug_eval]:
chart.loc[len(chart)] = m
models = ['vgg_baseline', 'incep_baseline', 'vgg_weights', 'incep_weights', 'vgg_augment', 'incep_augment']
chart['model'] = models
chart.set_index('model', inplace=True)
chart['F1'] = 2 * chart['precision'] * chart['recall'] / (chart['precision'] + chart['recall'])
chart.sort_values(by='F1', ascending=False)
| model | loss | tp | fp | tn | fn | accuracy | precision | recall | auc | F1 |
|---|---|---|---|---|---|---|---|---|---|---|
| vgg_augment | 0.037801 | 90.0 | 4.0 | 304.0 | 2.0 | 0.9850 | 0.957447 | 0.978261 | 0.993683 | 0.967742 |
| vgg_baseline | 0.048311 | 89.0 | 4.0 | 304.0 | 3.0 | 0.9825 | 0.956989 | 0.967391 | 0.993436 | 0.962162 |
| incep_augment | 0.115283 | 88.0 | 7.0 | 301.0 | 4.0 | 0.9725 | 0.926316 | 0.956522 | 0.991071 | 0.941176 |
| incep_baseline | 0.094018 | 86.0 | 5.0 | 303.0 | 6.0 | 0.9725 | 0.945055 | 0.934783 | 0.993471 | 0.939891 |
| vgg_weights | 0.103227 | 91.0 | 12.0 | 296.0 | 1.0 | 0.9675 | 0.883495 | 0.989130 | 0.995200 | 0.933333 |
| incep_weights | 0.119908 | 87.0 | 8.0 | 300.0 | 5.0 | 0.9675 | 0.915789 | 0.945652 | 0.991813 | 0.930481 |
Without a specific business problem to solve, it is hard to choose between six models that all show decent performance (even the worst one gives 96.75% accuracy). It is always good to use a single real-valued metric to compare all the models, but which should it be? Three options come to mind: accuracy, F1 score, and AUC-ROC.
The VGG-based model with augmented data scores best both in terms of accuracy (only 6 mislabelled examples out of 400, or 98.5%) and F1 score. However, the VGG-based model with class weights scores better on the AUC-ROC metric.
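These claims are easy to verify directly from the chart dataframe built above:

# Name of the best-scoring model per metric; higher is better for all three
chart[['accuracy', 'F1', 'auc']].idxmax()
# accuracy    vgg_augment
# F1          vgg_augment
# auc         vgg_weights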
With the current decision threshold of 0.5, the class-weights model produced 12 false positives (detecting ships where there are none) and missed one true ship, but here is what we can get out of it simply by raising that threshold. I played with p manually until I found the value that yields the fewest false positives without adding any new false negatives. This may be considered data leakage, since I am tuning the threshold on test data, so I would not do this in production; here my task is just to show what can theoretically be squeezed out of this model.
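For the record, the manual search is easy to automate. Below is a minimal sketch of such a sweep, assuming the vgg_test_predictions_weights and labels_test arrays from the earlier cells (again tuning on test data, so for illustration only):

# Sweep thresholds; keep the one with the fewest FPs and no more FNs than at p=0.5
scores = vgg_test_predictions_weights.ravel()
fn_at_default = np.sum((scores <= 0.5) & (labels_test == 1))
best_p = 0.5
best_fp = np.sum((scores > 0.5) & (labels_test == 0))
for p in np.arange(0.50, 1.00, 0.01):
    fp = np.sum((scores > p) & (labels_test == 0))
    fn = np.sum((scores <= p) & (labels_test == 1))
    if fn <= fn_at_default and fp < best_fp:
        best_p, best_fp = p, fp
print("best threshold: %.2f, false positives: %d" % (best_p, best_fp))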
cm = plot_cm(labels_test, vgg_test_predictions_weights, p=0.81)
No-ships Detected (True Negatives): 302
No-ships Incorrectly Detected (False Positives): 6
Ships Missed (False Negatives): 1
Ships Detected (True Positives): 91
Total Ships: 92
If my goal were to avoid missing an actual ship at any cost, this would probably be my best choice. Let's manually calculate precision, recall, and F1 score for this model:
tn = cm[0][0] # confusion matrix layout: rows = true class, columns = predicted class
fp = cm[0][1] # true no-ship, predicted ship
fn = cm[1][0] # true ship, predicted no-ship
tp = cm[1][1] # true ship, predicted ship
precision = tp / (tp+fp)
recall = tp / (tp+fn)
f1 = 2 * precision * recall / (precision + recall)
print("Improved F1 score for VGG based model with class weights: %.3f" % f1)
Improved F1 score for VGG based model with class weights: 0.963
So even with this cheating adjustment of the threshold on test data, the F1 score does not beat that of the data augmentation model.
Using the VGG model with data augmentation, let's display the mislabeled examples from the training, dev and test sets to see if they have anything in common.
vgg_train_predictions_aug = aug_vgg_model.predict(vgg_train, batch_size=BATCH_SIZE) # look only at original train data, not the augmented data
errors_train = np.reshape(vgg_train_predictions_aug > 0.5, labels_train.shape) != labels_train # all mislabeled train examples as boolean array
fp_train = images_train[errors_train & (labels_train==0)] # false positive images
fn_train = images_train[errors_train & (labels_train==1)] # false negative images
fp_train.shape, fn_train.shape
((5, 80, 80, 3), (1, 80, 80, 3))
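The image-grid cells below all follow the same pattern; a small hypothetical helper like show_row could replace them (I keep the original cells inline to preserve the notebook flow):

def show_row(imgs, size=(20, 20)):
    # Plot a single row of images side by side
    fig = plt.figure(figsize=size)
    for i, img in enumerate(imgs):
        fig.add_subplot(1, len(imgs), i + 1)
        plt.imshow(img)
    plt.show()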
False positives:
columns = 5
rows = 1
fig=plt.figure(figsize=(20, 20))
for i in range(0, columns*rows):
img = fp_train[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
False negatives:
columns = 1
rows = 1
fig=plt.figure(figsize=(4, 4))
for i in range(0, columns*rows):
img = fn_train[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
vgg_dev_predictions_aug = aug_vgg_model.predict(vgg_dev, batch_size=BATCH_SIZE)
errors_dev = np.reshape(vgg_dev_predictions_aug >0.5, labels_dev.shape) != labels_dev
fp_dev = images_dev[errors_dev & (labels_dev==0)]
fn_dev = images_dev[errors_dev & (labels_dev==1)] # false negative images
fp_dev.shape, fn_dev.shape
((7, 80, 80, 3), (1, 80, 80, 3))
False positives:
columns = 7
rows = 1
fig=plt.figure(figsize=(20, 20))
for i in range(0, columns*rows):
img = fp_dev[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
False negatives:
columns = 1
rows = 1
fig=plt.figure(figsize=(4, 4))
for i in range(0, columns*rows):
img = fn_dev[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
# vgg_test_predictions_aug was calculated earlier
errors_test = np.reshape(vgg_test_predictions_aug >0.5, labels_test.shape) != labels_test
fp_test = images_test[errors_test & (labels_test==0)]
fn_test = images_test[errors_test & (labels_test==1)]
fp_test.shape, fn_test.shape
((4, 80, 80, 3), (2, 80, 80, 3))
False positives:
columns = 4
rows = 1
fig=plt.figure(figsize=(20, 20))
for i in range(0, columns*rows):
img = fp_test[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
False negatives:
columns = 2
rows = 1
fig=plt.figure(figsize=(6, 6))
for i in range(0, columns*rows):
img = fn_test[i]
fig.add_subplot(rows, columns, i+1)
plt.imshow(img)
plt.show()
Findings are not super clear, but it looks like the false positives often contain patterns the model mistakes for ships. By providing more images of this kind in the training data, it may be possible to improve the model's results.
The next step would be moving from image classification to object detection: exploring bigger satellite images, finding regions of interest (ROIs), and predicting bounding boxes for all the ships present.
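A naive first pass in that direction is a sliding window: slide an 80x80 crop across a large scene and score each crop with the patch classifier trained here. Below is a minimal sketch under stated assumptions: scene is a normalized HxWx3 array, conv_base is the pretrained feature extractor used for embedding above, flatten is the helper defined earlier, and classifier is one of the trained heads (e.g. aug_vgg_model):

# Hypothetical sketch: score every size x size window of a large scene
def sliding_window_scores(scene, conv_base, classifier, size=80, stride=40):
    boxes, crops = [], []
    h, w, _ = scene.shape
    for y in range(0, h - size + 1, stride):
        for x in range(0, w - size + 1, stride):
            boxes.append((x, y))
            crops.append(scene[y:y + size, x:x + size])
    # Same embedding-then-classify pipeline as used for the test set above
    features = flatten(conv_base.predict(np.array(crops, dtype=np.float32)))
    probs = classifier.predict(features).ravel()  # ship probability per window
    return boxes, probs

Windows scoring above the decision threshold would then be merged into detections, for example with non-maximum suppression; dedicated detection architectures such as Faster R-CNN do all of this far more efficiently.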