Explaining Model’s Predictions
This library has the function explain() as its core component, which is used to execute any AI explaining method or technique. Think of it as the keras.Model.fit() or keras.Model.predict() loops of Keras models, in which the execution graph of the operations contained in a model is compiled (depending on Model.run_eagerly and Model.jit_compile) and the explaining maps are computed according to the method’s strategy.
Just like keras.Model.predict(), explain() accepts various types of input data and retrieves the model’s associated distribution strategy in order to distribute the workload across multiple GPUs and/or workers.
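To make the analogy with predict() concrete, here is a minimal, single-device sketch of such a loop. The names explain_sketch and toy_gradients, and the linear "model" given as a weight matrix, are hypothetical illustrations, not part of the library; the real explain() additionally handles input types, graph compilation and distribution strategies, none of which are reproduced here.

```python
import numpy as np


def explain_sketch(method, model, inputs, units, batch_size=32, postprocessing=None):
    """Hypothetical single-device stand-in for an explain()-like loop.

    Splits the inputs into batches, runs the explaining method on each batch,
    optionally post-processes the maps, and concatenates the results.
    """
    all_logits, all_maps = [], []
    for start in range(0, len(inputs), batch_size):
        x = inputs[start:start + batch_size]
        u = units[start:start + batch_size]
        logits, maps = method(model, x, u)
        if postprocessing is not None:
            maps = postprocessing(maps)
        all_logits.append(logits)
        all_maps.append(maps)
    return np.concatenate(all_logits), np.concatenate(all_maps)


def toy_gradients(model, x, units):
    """Hypothetical method: input gradients of a linear model (logits = x @ W)."""
    logits = x @ model
    # For a linear model, d logits[:, u] / d x is the column W[:, u].
    maps = model[:, units[:, 0]].T
    return logits, maps


# Usage: 5 toy samples with 3 features, 2 classes, processed in batches of 2.
W = np.arange(6.0).reshape(3, 2)
x = np.ones((5, 3))
units = np.array([[0], [1], [0], [1], [1]])
logits, maps = explain_sketch(toy_gradients, W, x, units, batch_size=2)
```

Each explaining method only needs to know how to handle one batch; the loop around it takes care of splitting the inputs and gathering the outputs.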
We demonstrate below how to explain the predictions of the Xception network, trained on ImageNet, for a few image samples. First, we load the network:
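import tensorflow as tf
# Keep the raw logits (no softmax) as outputs, so units can be explained directly.
model = tf.keras.applications.Xception(
    classifier_activation=None,
    weights='imagenet',
)
# Shape of the feature maps fed into the global average pooling layer.
print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}")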
Spatial map sizes: (None, 10, 10, 2048)
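The 10x10 spatial size follows from the strides in Xception. The sketch below traces the side length of a square input through the strided layers, assuming (from our reading of the Keras Xception definition) a 3x3 stride-2 'valid' convolution, a 3x3 'valid' convolution, and four stride-2 'same'-padded max-pooling stages. The function name xception_spatial_size is hypothetical.

```python
import math


def xception_spatial_size(size=299):
    """Side length of Xception's last feature maps for a square input (assumed layout)."""
    # block1_conv1: 3x3 kernel, stride 2, 'valid' padding.
    size = (size - 3) // 2 + 1
    # block1_conv2: 3x3 kernel, stride 1, 'valid' padding.
    size = size - 3 + 1
    # Blocks 2, 3, 4 and 13 each downsample with a stride-2, 'same'-padded max-pooling.
    for _ in range(4):
        size = math.ceil(size / 2)
    return size


print(xception_spatial_size(299))  # 299 -> 149 -> 147 -> 74 -> 37 -> 19 -> 10
```

Each cell of these 10x10 maps therefore summarizes a region of roughly 30x30 input pixels, which is why CAM-based maps are coarse before being resized.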
We can feed the samples forward once and obtain the predicted classes for each sample. Besides confirming that the model outputs the expected classes, this step is required to determine the most activating units in the logits layer. Explaining only those units, rather than all 1,000 ImageNet classes, improves the performance of the explaining methods.
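import numpy as np
import keras_explainable as ke
from tensorflow.keras.applications.imagenet_utils import decode_predictions

# `images`: a batch of RGB samples in [0, 255], resized to 299x299 beforehand.
# Xception expects inputs scaled to [-1, 1].
inputs = images / 127.5 - 1
logits = model.predict(inputs, verbose=0)
# Unit indices sorted from most to least activating, for each sample.
indices = np.argsort(logits, axis=-1)[:, ::-1]
probs = tf.nn.softmax(logits).numpy()
predictions = decode_predictions(probs, top=1)

ke.utils.visualize(
    images=images,
    titles=[
        ", ".join(f"{klass} {prob:.0%}" for code, klass, prob in p)
        for p in predictions
    ]
)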
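The ranking step can be illustrated with plain NumPy on made-up logits. The sketch below sorts the units of each sample from most to least activating, keeps the top-1 unit in the same (N, 1) layout later used for explaining_units, and computes a numerically stable softmax, which should match tf.nn.softmax up to floating-point error.

```python
import numpy as np

# Toy logits for 2 samples over 4 classes (made-up values).
logits = np.array([[0.5, 2.0, -1.0, 1.0],
                   [3.0, 0.0, 1.5, -0.5]])

# Class indices sorted from most to least activating, per sample.
indices = np.argsort(logits, axis=-1)[:, ::-1]
explaining_units = indices[:, :1]  # top-1 unit per sample, shape (2, 1)

# Numerically stable softmax: subtract the row max before exponentiating.
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
```

Slicing indices[:, :k] instead of indices[:, :1] would select the k most likely classes per sample.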
Finally, we can simply run all available explaining methods:
explaining_units = indices[:, :1]  # First most likely class.

# Gradient Back-propagation
_, g_maps = ke.gradients(model, inputs, explaining_units)

# Full-Gradient: expose the intermediate layers with biases alongside the logits.
# (`logits_layer` avoids overwriting the `logits` predictions computed above.)
logits_layer = ke.inspection.get_logits_layer(model)
inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits_layer])
model_exp = ke.inspection.expose(model, inters, logits_layer)
_, fg_maps = ke.full_gradients(model_exp, inputs, explaining_units, biases=biases)

# CAM-Based: expose the model so it also outputs its spatial activation maps.
model_exp = ke.inspection.expose(model)
_, c_maps = ke.cam(model_exp, inputs, explaining_units)
_, gc_maps = ke.gradcam(model_exp, inputs, explaining_units)
_, gcpp_maps = ke.gradcampp(model_exp, inputs, explaining_units)
_, sc_maps = ke.scorecam(model_exp, inputs, explaining_units)
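As a reminder of what the Grad-CAM maps contain, the sketch below computes Grad-CAM for a single sample from plain arrays (it is not the library's implementation): the gradients of the explained unit with respect to the last feature maps are averaged over space to weight each channel, and the weighted sum of the maps is passed through a ReLU. It also checks a known property: for a global-average-pooling plus dense head, like Xception's, Grad-CAM equals CAM (after ReLU) divided by H·W. The function name gradcam_from_arrays and the toy head weights v are hypothetical.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients):
    """Grad-CAM for one sample and one unit.

    activations: (H, W, K) feature maps A^k of the last spatial layer.
    gradients:   (H, W, K) gradients dy_c / dA^k for the explained unit c.
    """
    # Channel weights: global-average-pool the gradients over space.
    weights = gradients.mean(axis=(0, 1))  # (K,)
    # Weighted sum of the feature maps, then ReLU.
    cam = np.tensordot(activations, weights, axes=([2], [0]))  # (H, W)
    return np.maximum(cam, 0.0)


rng = np.random.default_rng(0)
H, W, K = 10, 10, 4
A = rng.normal(size=(H, W, K))
v = np.array([1.0, -2.0, 0.5, 3.0])  # toy dense weights of unit c

# For y_c = sum_k v_k * mean_ij(A^k_ij), the gradient dy_c/dA^k_ij is v_k / (H * W).
grads = np.broadcast_to(v / (H * W), (H, W, K))
cam_grad = gradcam_from_arrays(A, grads)
cam_ref = np.maximum(A @ v, 0.0) / (H * W)  # CAM with ReLU, rescaled
```

For deeper or non-linear heads the two maps differ, which is where the gradient-based weighting becomes useful.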
The functions above are simply shortcuts for explain(), using their conventional hyper-parameters and post-processing functions.
For more flexibility, you can use the regular form:
# Explicitly choose the method, batch size and post-processing filter.
logits, cams = ke.explain(
    ke.methods.cams.gradcam,
    model_exp,
    inputs,
    explaining_units,
    batch_size=32,
    postprocessing=ke.filters.positive_normalize,
)
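The post-processing filter is applied to the raw maps. Judging from its name, positive_normalize keeps the positive evidence and rescales each map; the sketch below assumes that behavior (ReLU followed by per-map min-max normalization over the spatial axes of an (N, H, W, C) batch) and is not the library's code. The name positive_normalize_sketch and the eps guard are hypothetical.

```python
import numpy as np


def positive_normalize_sketch(maps, eps=1e-8):
    """Keep positive values and rescale each map to [0, 1] (assumed behavior)."""
    maps = np.maximum(maps, 0.0)
    # Per-sample, per-channel min/max over the spatial axes (H, W).
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    # eps avoids dividing by zero when a map is constant (e.g. all negative).
    return (maps - lo) / (hi - lo + eps)


m = np.array([[[[-1.0], [2.0]], [[4.0], [0.0]]]])  # shape (1, 2, 2, 1)
out = positive_normalize_sketch(m)
```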
While the explain() function is a convenient wrapper that transparently distributes the workload according to the distribution strategy associated with the model, it is not required for the rest of the library to work. Alternatively, one can call any explaining method directly:
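# CAM-based methods expect the exposed model, which also outputs the spatial maps.
logits, cams = ke.methods.cams.gradcam(model_exp, inputs, explaining_units)
# Or the following, which is more efficient, as the method is traced into a graph:
gradcam = tf.function(ke.methods.cams.gradcam, reduce_retracing=True)
logits, cams = gradcam(model_exp, inputs, explaining_units)
# Post-process manually: normalize the positive values, then upsample
# the 10x10 maps to the 299x299 input resolution.
cams = ke.filters.positive_normalize(cams)
cams = tf.image.resize(cams, (299, 299)).numpy()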
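The final resize maps each coarse cell back onto the input image. The sketch below upsamples a single 2-D map with separable bilinear interpolation, using half-pixel centers and edge clamping; it illustrates the idea and is not guaranteed to match tf.image.resize numerically. The function name bilinear_resize is hypothetical.

```python
import numpy as np


def bilinear_resize(cam, out_h, out_w):
    """Upsample a 2-D map with bilinear interpolation (half-pixel centers)."""
    in_h, in_w = cam.shape

    def axis_weights(in_size, out_size):
        # Map each output pixel center back onto the input grid, clamped to its edges.
        src = (np.arange(out_size) + 0.5) * (in_size / out_size) - 0.5
        src = np.clip(src, 0, in_size - 1)
        lo = np.floor(src).astype(int)
        hi = np.minimum(lo + 1, in_size - 1)
        return lo, hi, src - lo

    r0, r1, fr = axis_weights(in_h, out_h)
    c0, c1, fc = axis_weights(in_w, out_w)
    # Interpolate along rows first, then along columns.
    rows = cam[r0] * (1 - fr)[:, None] + cam[r1] * fr[:, None]
    return rows[:, c0] * (1 - fc)[None, :] + rows[:, c1] * fc[None, :]


cam = np.arange(100, dtype=float).reshape(10, 10)
up = bilinear_resize(cam, 299, 299)
```

Because interpolation only blends neighboring cells, the upsampled map stays within the range of the original one, so a [0, 1]-normalized map remains in [0, 1].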