Updated On : Feb-24,2022 Time Investment : ~30 mins

SHAP Values for Image Classification Tasks (Keras)

Deep learning models like convolutional neural networks give very good results on many computer vision tasks. However, high accuracy alone does not tell us whether a model bases its predictions on the parts of the data that it should use. For example, in an image classification task that predicts cat vs dog, the model should look at the pixels of the cat's or dog's face and body, not at the background pixels. If it does, we can be more confident that the model has generalized well and is actually learning features of cats and dogs. There are many prediction interpretation libraries; in this tutorial, we use SHAP. SHAP is a Python library that generates SHAP values for predictions using a game-theoretic approach (Shapley values). We can then visualize these SHAP values in various ways to understand which features contributed to a prediction. We also have a starter tutorial on SHAP that discusses how to use it with tabular (structured) datasets.
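To make the "game-theoretic approach" concrete, below is a minimal sketch (our own illustration, not SHAP's implementation) that computes exact Shapley values by brute force for a hypothetical three-feature toy model. Each feature's value is its weighted average marginal contribution over all coalitions of the other features. Enumeration needs 2^n evaluations, which is why SHAP uses approximations for images with 784 pixels.

```python
from itertools import combinations
from math import factorial


def shapley_values(value_fn, n_features):
    """Exact Shapley values by enumerating every coalition (feasible only for tiny n)."""
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            # Shapley weight for a coalition S of this size: |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # marginal contribution of feature i when added to coalition S
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi


# Hypothetical toy "model": its output when only the features in `present` are visible.
# Features 0 and 1 interact; feature 2 plays the role of background and has no effect.
def toy_model(present):
    score = 0.0
    if 0 in present:
        score += 2.0
    if 1 in present:
        score += 1.0
    if 0 in present and 1 in present:
        score += 1.0  # interaction credit, split equally between features 0 and 1
    return score


phi = shapley_values(toy_model, 3)
print([round(p, 6) for p in phi])  # [2.5, 1.5, 0.0]
# Efficiency: the values add up to f(all features) - f(no features) = 4.0
print(round(sum(phi), 6))  # 4.0
```

The background feature gets exactly zero credit, which is the property we hope to see for background pixels in the image explanations.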

As a part of this tutorial, we have designed a simple CNN using Keras and trained it on the Fashion MNIST dataset. Then, we have explained its correct and incorrect predictions using the SHAP Python library.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Data
  2. Define CNN
  3. Compile And Train Model
  4. Evaluate Model Performance
  5. Explain Predictions Using SHAP Partition Explainer
    • Visualize SHAP Values For Correct Predictions
      • Image Plots
    • Visualize SHAP Values For Incorrect Predictions
  6. Explain Predictions Using SHAP Permutation Explainer
    • Visualize SHAP Values For Correct Predictions
    • Visualize SHAP Values For Incorrect Predictions

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

import tensorflow
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0
import shap

print("SHAP Version : {}".format(shap.__version__))
SHAP Version : 0.40.0

1. Load Data

In this section, we have loaded the Fashion MNIST dataset available from Keras. The dataset has grayscale images of shape (28,28) pixels for 10 different fashion items. It is already divided into train (60k images) and test (10k images) sets. We add a channel dimension so that each image has shape (28,28,1), as expected by the Conv2D layers, and scale pixel values from [0,255] to [0,1]. Below, we have included the mapping from label index to fashion item name.

Label   Description
0       T-shirt/top
1       Trouser
2       Pullover
3       Dress
4       Coat
5       Sandal
6       Shirt
7       Sneaker
8       Bag
9       Ankle boot

from tensorflow import keras
import numpy as np

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

X_train, X_test = X_train/255.0, X_test/255.0

classes =  np.unique(Y_train)
class_labels = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot"]
mapping = dict(zip(classes, class_labels))

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

2. Define CNN

In this section, we have defined a convolutional neural network that we'll use to classify the Fashion MNIST images loaded earlier. The network is simple, with 2 convolution layers and one dense layer. The two convolution layers have 32 and 16 filters respectively, both with a kernel size of (3,3) and "same" padding, so the 28x28 spatial size is preserved. We have applied relu (rectified linear unit) activation after both convolution layers. The output of the second convolution layer is flattened and then given to the dense layer, which has 10 output units (one per target class) with softmax activation.

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

model = Sequential([
    layers.Input(shape=X_train.shape[1:]),
    layers.Conv2D(filters=32, kernel_size=(3,3), padding="same", activation="relu"),
    layers.Conv2D(filters=16, kernel_size=(3,3), padding="same", activation="relu"),

    layers.Flatten(),
    layers.Dense(len(classes), activation="softmax")
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 28, 28, 32)        320
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 16)        4624
_________________________________________________________________
flatten (Flatten)            (None, 12544)             0
_________________________________________________________________
dense (Dense)                (None, 10)                125450
=================================================================
Total params: 130,394
Trainable params: 130,394
Non-trainable params: 0
_________________________________________________________________
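As a sanity check on the summary above, the parameter counts follow from simple arithmetic: a convolution layer has one kernel_h x kernel_w x in_channels weight tensor plus one bias per filter, and a dense layer has a full weight matrix plus one bias per unit. The helper names below are our own, not Keras functions.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # weights of every filter plus one bias per filter
    return (kernel_h * kernel_w * in_channels + 1) * filters


def dense_params(in_units, out_units):
    # full weight matrix plus one bias per output unit
    return (in_units + 1) * out_units


conv1 = conv2d_params(3, 3, 1, 32)   # input is a 28x28x1 grayscale image
conv2 = conv2d_params(3, 3, 32, 16)  # consumes the 32 feature maps of conv1
flat = 28 * 28 * 16                  # padding="same" keeps the 28x28 spatial size
dense = dense_params(flat, 10)

print(conv1, conv2, flat, dense, conv1 + conv2 + dense)
# 320 4624 12544 125450 130394
```

The dense layer holds over 96% of the parameters, because flattening keeps all 12,544 activations of the second convolution layer.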

3. Compile And Train Model

In this section, we have first compiled our model with the Adam optimizer, sparse categorical cross-entropy loss (the targets are integer class indices), and an accuracy metric. Then, we have trained the model for 10 epochs with a batch size of 256, using the test set as validation data. We can notice from the results printed after each epoch that the model reaches about 90% accuracy on validation data (0.9044 after the last epoch) and about 93% accuracy on train data (0.9313). This is reasonable accuracy, so we can now evaluate the predictions made by our model.

model.compile("adam", "sparse_categorical_crossentropy", ["accuracy"])
model.fit(X_train, Y_train, batch_size=256, epochs=10, validation_data=(X_test, Y_test))
Epoch 1/10
235/235 [==============================] - 20s 80ms/step - loss: 0.5326 - accuracy: 0.8166 - val_loss: 0.3939 - val_accuracy: 0.8626
Epoch 2/10
235/235 [==============================] - 18s 77ms/step - loss: 0.3452 - accuracy: 0.8769 - val_loss: 0.3492 - val_accuracy: 0.8784
Epoch 3/10
235/235 [==============================] - 18s 79ms/step - loss: 0.3043 - accuracy: 0.8906 - val_loss: 0.3293 - val_accuracy: 0.8789
Epoch 4/10
235/235 [==============================] - 19s 79ms/step - loss: 0.2772 - accuracy: 0.9008 - val_loss: 0.3176 - val_accuracy: 0.8849
Epoch 5/10
235/235 [==============================] - 18s 77ms/step - loss: 0.2538 - accuracy: 0.9089 - val_loss: 0.2900 - val_accuracy: 0.8950
Epoch 6/10
235/235 [==============================] - 19s 79ms/step - loss: 0.2381 - accuracy: 0.9133 - val_loss: 0.2797 - val_accuracy: 0.9011
Epoch 7/10
235/235 [==============================] - 18s 78ms/step - loss: 0.2214 - accuracy: 0.9203 - val_loss: 0.2771 - val_accuracy: 0.9018
Epoch 8/10
235/235 [==============================] - 18s 79ms/step - loss: 0.2092 - accuracy: 0.9249 - val_loss: 0.2716 - val_accuracy: 0.9065
Epoch 9/10
235/235 [==============================] - 18s 78ms/step - loss: 0.1960 - accuracy: 0.9298 - val_loss: 0.2731 - val_accuracy: 0.9018
Epoch 10/10
235/235 [==============================] - 19s 79ms/step - loss: 0.1870 - accuracy: 0.9313 - val_loss: 0.2681 - val_accuracy: 0.9044
<keras.callbacks.History at 0x7fba369bd3d0>
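For a single image, the sparse categorical cross-entropy loss is just the negative log of the probability the model assigns to the true class; Keras reports its average over the samples. The sketch below uses a hypothetical softmax output and our own helper function. It also reproduces the number of steps per epoch shown in the training log.

```python
import math


def sparse_categorical_crossentropy(probs, label, eps=1e-7):
    # "Sparse" means the target is an integer class index, not a one-hot vector.
    # The loss is the negative log of the probability assigned to the true class;
    # clipping by eps avoids log(0).
    return -math.log(max(probs[label], eps))


# Hypothetical softmax output for one image over the 10 Fashion MNIST classes
probs = [0.01, 0.01, 0.02, 0.01, 0.02, 0.03, 0.02, 0.05, 0.03, 0.80]

print(round(sparse_categorical_crossentropy(probs, 9), 4))  # 0.2231 (true class Ankle boot)
print(round(sparse_categorical_crossentropy(probs, 7), 4))  # 2.9957 (true class Sneaker)

# 60,000 training images in batches of 256: the last batch is partial
steps_per_epoch = math.ceil(60000 / 256)
print(steps_per_epoch)  # 235, matching "235/235" in the training log
```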

4. Evaluate Model Performance

In this section, we have evaluated the performance of our network by calculating accuracy, the confusion matrix, and a classification report (precision, recall, and f1-score per class). We can notice from the results that our model does well on all categories except 'Shirt', 'Coat', 'Pullover', and 'T-shirt/top'. These items look very similar, which makes them hard for the model to tell apart in grayscale images. For example, 125 'Shirt' images are predicted as 'T-shirt/top', and 99 'Coat' images are predicted as 'Pullover'.

We have used functions from scikit-learn's metrics module to calculate these metrics. We also have a separate tutorial that covers the majority of ML metrics available from sklearn.

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

Y_test_preds = model.predict(X_test)
Y_test_preds = np.argmax(Y_test_preds, axis=1)

print("Test Accuracy : {}".format(accuracy_score(Y_test, Y_test_preds)))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, Y_test_preds))
print("\nClassification Report :")
print(classification_report(Y_test, Y_test_preds, target_names=class_labels))
Test Accuracy : 0.9044

Confusion Matrix :
[[875   0  17  16   3   1  82   0   6   0]
 [  4 975   1  13   3   0   3   0   1   0]
 [ 20   1 877   8  37   0  57   0   0   0]
 [ 16   3  10 942   8   0  19   0   2   0]
 [  2   0  99  49 774   0  76   0   0   0]
 [  0   0   0   1   0 979   0  12   1   7]
 [125   1  67  36  43   0 719   0   9   0]
 [  0   0   0   0   0   7   0 964   1  28]
 [  6   1   1   2   4   3   6   3 974   0]
 [  1   0   0   0   0   8   0  26   0 965]]

Classification Report :
              precision    recall  f1-score   support

 T-shirt/top       0.83      0.88      0.85      1000
     Trouser       0.99      0.97      0.98      1000
    Pullover       0.82      0.88      0.85      1000
       Dress       0.88      0.94      0.91      1000
        Coat       0.89      0.77      0.83      1000
      Sandal       0.98      0.98      0.98      1000
       Shirt       0.75      0.72      0.73      1000
     Sneaker       0.96      0.96      0.96      1000
         Bag       0.98      0.97      0.98      1000
  Ankle boot       0.96      0.96      0.96      1000

    accuracy                           0.90     10000
   macro avg       0.90      0.90      0.90     10000
weighted avg       0.90      0.90      0.90     10000
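The report can be reproduced by hand from the confusion matrix printed above, which makes the definitions explicit: precision divides a diagonal entry by its column sum, recall divides it by its row sum. The sketch below uses plain Python lists and our own helper functions rather than scikit-learn.

```python
# Confusion matrix printed above: rows = actual class, columns = predicted class
cm = [
    [875,   0,  17,  16,   3,   1,  82,   0,   6,   0],
    [  4, 975,   1,  13,   3,   0,   3,   0,   1,   0],
    [ 20,   1, 877,   8,  37,   0,  57,   0,   0,   0],
    [ 16,   3,  10, 942,   8,   0,  19,   0,   2,   0],
    [  2,   0,  99,  49, 774,   0,  76,   0,   0,   0],
    [  0,   0,   0,   1,   0, 979,   0,  12,   1,   7],
    [125,   1,  67,  36,  43,   0, 719,   0,   9,   0],
    [  0,   0,   0,   0,   0,   7,   0, 964,   1,  28],
    [  6,   1,   1,   2,   4,   3,   6,   3, 974,   0],
    [  1,   0,   0,   0,   0,   8,   0,  26,   0, 965],
]
class_labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def accuracy(cm):
    # correct predictions lie on the diagonal
    return sum(cm[i][i] for i in range(len(cm))) / sum(map(sum, cm))


def precision(cm, k):
    # of all images predicted as class k (column k), the fraction that truly are k
    return cm[k][k] / sum(row[k] for row in cm)


def recall(cm, k):
    # of all images that truly are class k (row k), the fraction predicted as k
    return cm[k][k] / sum(cm[k])


shirt = class_labels.index("Shirt")
print(accuracy(cm))  # 0.9044
print(round(precision(cm, shirt), 2), round(recall(cm, shirt), 2))  # 0.75 0.72

# The largest off-diagonal cell is the most frequent confusion
count, actual, predicted = max(
    (cm[i][j], i, j) for i in range(len(cm)) for j in range(len(cm)) if i != j
)
print(count, class_labels[actual], "->", class_labels[predicted])  # 125 Shirt -> T-shirt/top
```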

5. SHAP Partition Explainer

In this section, we have used the SHAP partition explainer to explain the predictions made by our network. We explain both correct and wrong predictions of our model to see which parts of an image contribute to each prediction.

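We first call the initjs() function, which loads SHAP's JavaScript visualization code into the notebook. It is only needed for SHAP's interactive JavaScript plots (such as force plots); the image plots used in this tutorial are drawn with matplotlib.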

Then, we have created an instance of the Partition explainer using the Explainer() constructor. We have given three values to the constructor.

  • Our trained model
  • A masker that hides parts of an image by inpainting or blurring them. It takes the masking method ("inpaint_telea" here) and the shape of a single image.
  • The list of target class labels (output_names)

With the default algorithm ("auto"), the Explainer() constructor chooses the Partition explainer for an image masker, as the printed object below confirms. We can generate SHAP values for images by calling this explainer object.
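SHAP's documentation describes the Partition explainer as computing Shapley values recursively through a hierarchy of features, which yields Owen values. The sketch below is our own simplified illustration of that idea on a fixed, hand-written binary hierarchy over four hypothetical "pixels". As we understand it, the real explainer takes its hierarchy from the masker and approximates within a max_evals budget; this exact version doubles its contexts at every level, so it only suits tiny inputs.

```python
def owen_values(value_fn, tree):
    """Exact Owen values for a binary feature hierarchy (tiny inputs only).

    `tree` is a feature index (leaf) or a (left, right) tuple of subtrees.
    `value_fn(features)` returns the model output when only `features` are visible.
    """
    phi = {}

    def leaves(node):
        if isinstance(node, int):
            return frozenset([node])
        return leaves(node[0]) | leaves(node[1])

    def recurse(node, contexts):
        if isinstance(node, int):
            # average marginal contribution of this feature over all contexts
            phi[node] = sum(value_fn(c | {node}) - value_fn(c) for c in contexts) / len(contexts)
            return
        left, right = node
        # each child group is evaluated with its sibling group fully masked and fully visible
        recurse(left, [c for ctx in contexts for c in (ctx, ctx | leaves(right))])
        recurse(right, [c for ctx in contexts for c in (ctx, ctx | leaves(left))])

    recurse(tree, [frozenset()])
    return [phi[i] for i in sorted(phi)]


# Hypothetical 4-"pixel" model: additive effects plus an interaction between pixels 0 and 1
weights = [1, 2, 3, 4]


def toy_model(visible):
    score = sum(weights[i] for i in visible)
    if 0 in visible and 1 in visible:
        score += 2
    return score


# Hierarchy: neighbouring pixels (0, 1) and (2, 3) are grouped first
print(owen_values(toy_model, ((0, 1), (2, 3))))  # [2.0, 3.0, 3.0, 4.0]
```

The interaction bonus of 2 is shared only between pixels 0 and 1, and the values still add up to f(all) - f(none) = 12.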

shap.initjs()

masker = shap.maskers.Image("inpaint_telea", X_train[0].shape)

explainer = shap.Explainer(model, masker, output_names=class_labels)

explainer
<shap.explainers._partition.Partition at 0x7fba36892310>

Visualize SHAP Values For Correct Predictions

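In this section, we have explained the correct predictions made by our model. We have first taken 4 sample images from the test set and generated SHAP values for them using the partition explainer created earlier. The outputs=shap.Explanation.argsort.flip[:5] argument restricts the explanation to the five classes with the highest model output for each image, which is why the last dimension of shap_values is 5. Then, we have printed the actual labels, the predicted labels, and the model's predicted probabilities.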

shap_values = explainer(X_test[:4], outputs=shap.Explanation.argsort.flip[:5])

shap_values.shape
(4, 28, 28, 1, 5)
print("Actual Labels    : {}".format([mapping[i] for i in Y_test[:4]]))
probs = model.predict(X_test[:4])
print("Predicted Labels : {}".format([mapping[i] for i in np.argmax(probs, axis=1)]))
print("Probabilities : {}".format(np.max(probs, axis=1)))
Actual Labels    : ['Ankle boot', 'Pullover', 'Trouser', 'Trouser']
Predicted Labels : ['Ankle boot', 'Pullover', 'Trouser', 'Trouser']
Probabilities : [0.99812347 0.9995265  0.9999877  0.99999714]
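The top-5 selection requested through outputs=shap.Explanation.argsort.flip[:5] can be sketched in plain Python: sort the class indices by model output in descending order and keep the first five. The helper top_k_classes() and the probabilities below are hypothetical, for illustration only.

```python
def top_k_classes(probs, k=5):
    # class indices sorted by probability, highest first, truncated to k
    # (mirrors argsort -> flip -> [:k])
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


class_labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Hypothetical probabilities for one image
probs = [0.0, 0.0, 0.01, 0.0, 0.02, 0.1, 0.03, 0.2, 0.04, 0.6]
top5 = top_k_classes(probs)
print(top5)                             # [9, 7, 5, 8, 6]
print([class_labels[i] for i in top5])  # ['Ankle boot', 'Sneaker', 'Sandal', 'Bag', 'Shirt']
```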

Image Plot

In this section, we have generated image plots that visualize the SHAP values generated by the explainer object.

Below, we have generated the first image plot using the SHAP values computed in the previous cells. The chart shows each actual image followed by copies of it with regions highlighted in shades of red and blue. Red regions contributed positively and blue regions contributed negatively to the prediction of the category shown above that panel. The panels cover the five categories with the highest predicted probability for each image, in descending order.
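To read the colours, it helps to know that they encode the sign and relative size of each pixel's SHAP value. If we read SHAP's plotting code correctly, image_plot uses one symmetric colour range for the whole figure, so intensities can be compared across panels. The sketch below is a hypothetical helper (not a SHAP function) that maps values onto such a symmetric red/blue scale; it uses the maximum absolute value as the range, a simplification of what SHAP does.

```python
def shap_to_color(value, max_abs):
    """Map a SHAP value to an (R, G, B, alpha) tuple on a symmetric red/blue scale."""
    t = max(-1.0, min(1.0, value / max_abs))  # normalize and clip into [-1, 1]
    if t >= 0:
        return (1.0, 0.0, 0.0, t)  # red: pushes the prediction towards the class
    return (0.0, 0.0, 1.0, -t)     # blue: pushes the prediction away from the class


# Hypothetical SHAP values for a 2x3 patch of pixels, for one class
patch = [[0.02, -0.01, 0.0],
         [0.04, -0.04, 0.08]]
# One shared scale for every pixel (simplified: the maximum absolute value)
max_abs = max(abs(v) for row in patch for v in row)
colors = [[shap_to_color(v, max_abs) for v in row] for row in patch]

print(colors[1][2])  # (1.0, 0.0, 0.0, 1.0): strongest positive pixel, fully red
print(colors[1][1])  # (0.0, 0.0, 1.0, 0.5): half-strength negative pixel
```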

shap.image_plot(shap_values)

Below, we have generated another image plot using a different masker named inpaint_ns (Navier-Stokes based inpainting, whereas inpaint_telea uses Telea's method). We have created new masker and explainer instances for this.

masker = shap.maskers.Image("inpaint_ns", X_train[0].shape)

explainer = shap.Explainer(model, masker, output_names=class_labels)

shap_values = explainer(X_test[:4], outputs=shap.Explanation.argsort.flip[:5])

shap.image_plot(shap_values)

Below, we have generated another image plot using a blurring-based masker. From the plots, the blurring masker seems to give clearer explanations than the two inpainting maskers. The tuple of integers inside the masker string, (28,28) here, is the size of the kernel used for blurring.
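Conceptually, a blur masker hides a pixel by replacing it with a blurred version of the image, so the model sees a smooth region instead of a hard black hole. The pure-Python sketch below illustrates the idea with a simple mean (box) filter. It is not SHAP's implementation, and the border handling (a truncated window) is our own choice. With a (28,28) kernel on a 28x28 image, the window spans nearly the whole image, so hidden pixels become close to a heavily smoothed average.

```python
def box_blur(img, k):
    """Mean filter with a k x k window; near the borders the window is truncated."""
    h, w = len(img), len(img[0])
    lo_off = (k - 1) // 2  # window covers rows y - lo_off ... y - lo_off + k - 1
    out = []
    for y in range(h):
        row = []
        for x in range(w):
            vals = [img[yy][xx]
                    for yy in range(max(0, y - lo_off), min(h, y - lo_off + k))
                    for xx in range(max(0, x - lo_off), min(w, x - lo_off + k))]
            row.append(sum(vals) / len(vals))
        out.append(row)
    return out


def apply_mask(img, keep, blurred):
    # visible pixels keep their value; hidden pixels take the blurred value
    return [[img[y][x] if keep[y][x] else blurred[y][x] for x in range(len(img[0]))]
            for y in range(len(img))]


# Hypothetical 4x4 grayscale image with a bright 2x2 square in the middle
img = [[0.0, 0.0, 0.0, 0.0],
       [0.0, 1.0, 1.0, 0.0],
       [0.0, 1.0, 1.0, 0.0],
       [0.0, 0.0, 0.0, 0.0]]
keep = [[x < 2 for x in range(4)] for _ in range(4)]  # hide the right half

masked = apply_mask(img, keep, box_blur(img, 3))
print(masked[1][1])            # 1.0: visible pixel is unchanged
print(round(masked[1][2], 3))  # 0.444: hidden pixel replaced by its 3x3 neighbourhood mean
```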

masker = shap.maskers.Image("blur(28,28)", X_train[0].shape)

explainer = shap.Explainer(model, masker, output_names=class_labels)

shap_values = explainer(X_test[:4], outputs=shap.Explanation.argsort.flip[:5])

shap.image_plot(shap_values)

Visualize SHAP Values For Incorrect Predictions

In this section, we have explained wrong predictions using the explainer object. To do so, we have first retrieved the indexes of all wrong predictions in the test set. Then, we have used the first four of these indexes to retrieve the samples and run the model on them again to get its predicted probabilities.

wrong_preds_idx = np.argwhere(Y_test!=Y_test_preds)

X_batch = X_test[wrong_preds_idx.flatten()[:4]]
Y_batch = Y_test[wrong_preds_idx.flatten()[:4]]

print("Actual Labels    : {}".format([mapping[i] for i in Y_batch]))
probs = model.predict(X_batch)
print("Predicted Labels : {}".format([mapping[i] for i in np.argmax(probs, axis=1)]))
print("Probabilities : {}".format(np.max(probs, axis=1)))
Actual Labels    : ['Sneaker', 'Coat', 'Ankle boot', 'Coat']
Predicted Labels : ['Bag', 'Pullover', 'Sandal', 'Shirt']
Probabilities : [0.59116197 0.53697795 0.99997365 0.671842  ]
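Taking the first four mistakes is the simplest choice. An alternative, sketched below with hypothetical data and our own helper, is to rank mistakes by the model's confidence: a confidently wrong prediction, like the Ankle boot predicted as Sandal with probability 0.99997 above, is often the most informative one to explain.

```python
def confident_mistakes(actual, predicted, confidence, top=3):
    """Indices of misclassified samples, ordered from most to least confident."""
    wrong = [i for i, (a, p) in enumerate(zip(actual, predicted)) if a != p]
    return sorted(wrong, key=lambda i: confidence[i], reverse=True)[:top]


# Hypothetical labels, predictions, and max-probabilities for six test images
actual     = [9, 2, 1, 1, 6, 7]
predicted  = [9, 2, 3, 1, 0, 5]
confidence = [0.99, 0.98, 0.60, 0.97, 0.85, 0.999]

print(confident_mistakes(actual, predicted, confidence))  # [5, 4, 2]
```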

Below, we have generated SHAP values for the wrong predictions and visualized them with an image plot. In the cell after that, we show that the same visualization can also be created with the image() function of the shap.plots sub-module.

masker = shap.maskers.Image("blur(28,28)", X_train[0].shape)

explainer = shap.Explainer(model, masker, output_names=class_labels)

shap_values = explainer(X_batch, outputs=shap.Explanation.argsort.flip[:5])

shap.image_plot(shap_values)

shap.plots.image(shap_values)

6. SHAP Permutation Explainer

In this section, we try another explainer available from SHAP: the permutation explainer. It can be created using the PermutationExplainer() constructor, which accepts the same parameters as the partition explainer. The permutation explainer estimates SHAP values by iterating over random orderings (permutations) of the input features and averaging the change in model output as each feature is unmasked.
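The idea can be sketched in plain Python: walk through the features in a random order, unmask them one at a time, and credit each feature with the change in output; each ordering is also walked in reverse (antithetic sampling). This is our own illustration on a hypothetical toy model, not SHAP's code. Each walk costs one evaluation per feature plus one. If we read SHAP's source correctly, its PermutationExplainer requires max_evals of at least 2 * M + 1 for M features, i.e. 2 * 784 + 1 = 1569 for a 28x28x1 image, so the max_evals=1600 used in the next section buys roughly one forward-and-reverse permutation per image.

```python
import random


def permutation_shap(value_fn, n_features, n_permutations, seed=0):
    """Estimate Shapley values by averaging marginal contributions along random
    feature orderings. Each ordering is also walked in reverse (antithetic sampling)."""
    rng = random.Random(seed)
    phi = [0.0] * n_features
    for _ in range(n_permutations):
        order = list(range(n_features))
        rng.shuffle(order)
        for perm in (order, order[::-1]):
            visible = set()
            prev = value_fn(frozenset(visible))  # everything masked
            for i in perm:
                visible.add(i)
                cur = value_fn(frozenset(visible))
                phi[i] += cur - prev  # marginal contribution of feature i
                prev = cur
    return [p / (2 * n_permutations) for p in phi]


# Hypothetical toy model: features 0 and 1 interact, feature 2 has no effect
def toy_model(visible):
    score = 0.0
    if 0 in visible:
        score += 2.0
    if 1 in visible:
        score += 1.0
    if 0 in visible and 1 in visible:
        score += 1.0
    return score


print(permutation_shap(toy_model, 3, n_permutations=1))  # [2.5, 1.5, 0.0]
```

Because every walk adds up to f(all) - f(none), the estimates always satisfy efficiency. For a model whose features interact only in pairs, pairing an ordering with its reverse already recovers the exact Shapley values, since feature j precedes feature i in exactly one of the two walks; higher-order interactions need more permutations.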

Below, we have first created a permutation explainer using model and masker objects.

masker = shap.maskers.Image("blur(28,28)", X_train[0].shape)

explainer = shap.PermutationExplainer(model, masker, output_names=class_labels)

explainer
<shap.explainers._permutation.Permutation.__init__.<locals>.Permutation at 0x7fb94886a790>

Visualize SHAP Values For Correct Predictions

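In this section, we have generated SHAP values for 4 test images using the permutation explainer object. Note the shape of the result, (4, 28, 28, 1, 10): unlike the partition explainer, the permutation explainer did not apply the outputs argument and returned SHAP values for all 10 classes, in class-index order (T-shirt/top first, Ankle boot last). In the next cell, we have printed the actual labels, predicted labels, and probabilities of those 4 images. After that, we have sorted the class labels of each image by the model's 10 predicted probabilities, from most to least likely; the SHAP values need to be reordered the same way before plotting so that each panel gets the correct label.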

shap_values = explainer(X_test[:4], max_evals=1600, outputs=shap.Explanation.argsort.flip[:5])

shap_values.shape
(4, 28, 28, 1, 10)
print("Actual Labels    : {}".format([mapping[i] for i in Y_test[:4]]))
probs = model.predict(X_test[:4])
print("Predicted Labels : {}".format([mapping[i] for i in np.argmax(probs, axis=1)]))
print("Probabilities : {}".format(np.max(probs, axis=1)))
Actual Labels    : ['Ankle boot', 'Pullover', 'Trouser', 'Trouser']
Predicted Labels : ['Ankle boot', 'Pullover', 'Trouser', 'Trouser']
Probabilities : [0.99812347 0.9995265  0.9999877  0.99999714]
Y_preds = model.predict(X_test[:4])

Y_preds = Y_preds.argsort()[:, ::-1]
Y_labels = [[class_labels[val] for val in row] for row in Y_preds]
Y_labels=np.array(Y_labels)

Y_labels
array([['Ankle boot', 'Sneaker', 'Sandal', 'Bag', 'Shirt', 'Pullover',
        'T-shirt/top', 'Coat', 'Dress', 'Trouser'],
       ['Pullover', 'T-shirt/top', 'Shirt', 'Coat', 'Bag', 'Dress',
        'Ankle boot', 'Trouser', 'Sneaker', 'Sandal'],
       ['Trouser', 'T-shirt/top', 'Dress', 'Bag', 'Shirt', 'Pullover',
        'Coat', 'Sandal', 'Ankle boot', 'Sneaker'],
       ['Trouser', 'Dress', 'Coat', 'T-shirt/top', 'Shirt', 'Ankle boot',
        'Pullover', 'Sandal', 'Bag', 'Sneaker']], dtype='<U11')

In the cell below, we have plotted the SHAP values for all 10 classes of each image, and then only for the five most probable classes.

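# shap_values has its class axis in class-index order (0-9), while Y_labels is sorted by
# predicted probability. Reorder the SHAP values with the same sorted class indices
# (Y_preds) so that every panel is paired with its correct label.
sorted_vals = np.take_along_axis(shap_values.values, Y_preds[:, None, None, None, :], axis=-1)
sorted_vals = [sorted_vals[..., i] for i in range(sorted_vals.shape[-1])]  # one array per output

shap.image_plot(sorted_vals, X_test[:4], labels=Y_labels)

shap.image_plot(sorted_vals[:5], X_test[:4], labels=Y_labels[:,:5])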

Visualize SHAP Values For Incorrect Predictions

In this section, we have generated SHAP values for wrong predictions using the permutation explainer. Most of the code repeats earlier sections, so we have not repeated the explanations.

wrong_preds_idx = np.argwhere(Y_test!=Y_test_preds)

X_batch = X_test[wrong_preds_idx.flatten()[:4]]
Y_batch = Y_test[wrong_preds_idx.flatten()[:4]]

print("Actual Labels    : {}".format([mapping[i] for i in Y_batch]))
probs = model.predict(X_batch)
print("Predicted Labels : {}".format([mapping[i] for i in np.argmax(probs, axis=1)]))
print("Probabilities : {}".format(np.max(probs, axis=1)))
Actual Labels    : ['Sneaker', 'Coat', 'Ankle boot', 'Coat']
Predicted Labels : ['Bag', 'Pullover', 'Sandal', 'Shirt']
Probabilities : [0.59116197 0.53697795 0.99997365 0.671842  ]
masker = shap.maskers.Image("blur(28,28)", X_train[0].shape)

explainer = shap.PermutationExplainer(model, masker, output_names=class_labels)

shap_values = explainer(X_batch, max_evals=1600, outputs=shap.Explanation.argsort.flip[:5])

shap_values.shape
(4, 28, 28, 1, 10)
Y_preds = model.predict(X_batch)

Y_preds = Y_preds.argsort()[:, ::-1]
Y_labels = [[class_labels[val] for val in row] for row in Y_preds]
Y_labels=np.array(Y_labels)

Y_labels
array([['Bag', 'Sandal', 'Sneaker', 'Dress', 'Pullover', 'Shirt',
        'T-shirt/top', 'Ankle boot', 'Coat', 'Trouser'],
       ['Pullover', 'Shirt', 'Coat', 'T-shirt/top', 'Bag', 'Dress',
        'Sandal', 'Trouser', 'Sneaker', 'Ankle boot'],
       ['Sandal', 'Ankle boot', 'Bag', 'Sneaker', 'Shirt', 'T-shirt/top',
        'Pullover', 'Dress', 'Coat', 'Trouser'],
       ['Shirt', 'Pullover', 'Coat', 'T-shirt/top', 'Dress', 'Trouser',
        'Bag', 'Ankle boot', 'Sandal', 'Sneaker']], dtype='<U11')
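# As before, the class axis of shap_values is in class-index order while Y_labels is
# sorted by predicted probability, so reorder the SHAP values with the sorted class
# indices (Y_preds) before plotting.
sorted_vals = np.take_along_axis(shap_values.values, Y_preds[:, None, None, None, :], axis=-1)
sorted_vals = [sorted_vals[..., i] for i in range(sorted_vals.shape[-1])]  # one array per output

shap.image_plot(sorted_vals, X_batch, labels=Y_labels)

shap.image_plot(sorted_vals[:5], X_batch, labels=Y_labels[:,:5])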

This ends our small tutorial explaining how to generate SHAP values for an image classification network created with Keras in order to explain its predictions. Please feel free to let us know your views in the comments section.

Sunny Solanki
