Explaining a Model’s Predictions

The core component of this library is the explain() function, which executes any AI explaining method and technique. Think of it as the keras.Model.fit() or keras.Model.predict() loops of Keras models: the execution graph of the operations contained in the model is compiled (subject to Model.run_eagerly and Model.jit_compile) and the explaining maps are computed according to the method’s strategy.

Just like keras.Model.predict(), explain() accepts various types of input data and retrieves the model’s associated distribution strategy in order to distribute the workload across multiple GPUs and/or workers.
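
For instance, a model created inside a tf.distribute scope keeps a reference to that strategy, which explain() can then retrieve. A minimal sketch, assuming a MirroredStrategy (the choice of strategy is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Any model built here is associated with the strategy above.
  model = tf.keras.applications.Xception(weights='imagenet')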

We demonstrate below how predictions can be explained using the Xception network trained on ImageNet, over a few image samples. First, we load the network:

import tensorflow as tf
import keras_explainable as ke

model = tf.keras.applications.Xception(
  classifier_activation=None,
  weights='imagenet',
)

print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}")
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5
91884032/91884032 [==============================] - 1s 0us/step
Spatial map sizes: (None, 10, 10, 2048)
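
The shape printed above shows that the last spatial activations are 10×10, which is the resolution at which the CAM-based maps below are computed before being resized to the input size. The examples that follow assume images holds a small batch of 299×299 RGB samples with pixel values in [0, 255]. A minimal sketch of loading such a batch (the file names are hypothetical):

import numpy as np
from tensorflow.keras.utils import load_img, img_to_array

# Hypothetical sample files; any 299x299 RGB images will do.
files = ['samples/hummingbird.jpg', 'samples/fox.jpg']
images = np.stack([
  img_to_array(load_img(f, target_size=(299, 299)))
  for f in files
])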

We can run the samples through the network once and retrieve the predicted class for each one. Besides making sure the model outputs the expected classes, this step is required to determine the most activating units in the logits layer, which improves the performance of the explaining methods.

from tensorflow.keras.applications.imagenet_utils import decode_predictions

inputs = images / 127.5 - 1  # Xception expects inputs in [-1, 1].
logits = model.predict(inputs, verbose=0)

indices = np.argsort(logits, axis=-1)[:, ::-1]  # unit indices sorted by logit, descending.
probs = tf.nn.softmax(logits).numpy()
predictions = decode_predictions(probs, top=1)

ke.utils.visualize(
  images=images,
  titles=[
    ", ".join(f"{klass} {prob:.0%}" for code, klass, prob in p)
    for p in predictions
  ]
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
35363/35363 [==============================] - 0s 1us/step

[Figure: input samples annotated with their top-1 predicted class and probability.]

Finally, we can simply run all available explaining methods:

explaining_units = indices[:, :1]  # The most likely class for each sample.

# Gradient Back-propagation
_, g_maps = ke.gradients(model, inputs, explaining_units)

# Full-Gradient
logits_layer = ke.inspection.get_logits_layer(model)
inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits_layer])
model_exp = ke.inspection.expose(model, inters, logits_layer)
_, fg_maps = ke.full_gradients(model_exp, inputs, explaining_units, biases=biases)

# CAM-Based
model_exp = ke.inspection.expose(model)
_, c_maps = ke.cam(model_exp, inputs, explaining_units)
_, gc_maps = ke.gradcam(model_exp, inputs, explaining_units)
_, gcpp_maps = ke.gradcampp(model_exp, inputs, explaining_units)
_, sc_maps = ke.scorecam(model_exp, inputs, explaining_units)
[Figure: explaining maps produced by each method for the input samples.]

The functions above are simply shortcuts for explain(), using each method’s conventional hyper-parameters and post-processing functions. For more flexibility, you can use the regular form:

logits, cams = ke.explain(
  ke.methods.cams.gradcam,
  model_exp,
  inputs,
  explaining_units,
  batch_size=32,
  postprocessing=ke.filters.positive_normalize,
)
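
The postprocessing argument accepts any callable over the computed maps, so custom filters can be swapped in. A minimal sketch, assuming the same call signature as above (the absolute-value filter is illustrative, not part of the library):

logits, cams = ke.explain(
  ke.methods.cams.gradcam,
  model_exp,
  inputs,
  explaining_units,
  batch_size=32,
  postprocessing=tf.abs,  # hypothetical custom filter over the maps
)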

While the explain() function is a convenient wrapper, transparently distributing the workload based on the distribution strategy associated with the model, it is not a necessary component of the library. Alternatively, you can call any explaining method directly:

logits, cams = ke.methods.cams.gradcam(model, inputs, explaining_units)

# Or the following, which reuses a compiled graph across calls and is more efficient:
gradcam = tf.function(ke.methods.cams.gradcam, reduce_retracing=True)
logits, cams = gradcam(model, inputs, explaining_units)

cams = ke.filters.positive_normalize(cams)
cams = tf.image.resize(cams, (299, 299)).numpy()  # upsample the maps to the input resolution.
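
To inspect the result, the normalized maps can be overlaid on the input samples. A minimal sketch using matplotlib (not part of the library):

import matplotlib.pyplot as plt

for i, (image, cam) in enumerate(zip(images, cams)):
  plt.subplot(1, len(images), i + 1)
  plt.imshow(image.astype('uint8'))               # original sample
  plt.imshow(cam[..., 0], cmap='jet', alpha=0.5)  # explaining map on top
  plt.axis('off')
plt.show()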