Exposing Intermediate Signals

This page details the exposure procedure, which is required by most AI explaining methods and can be simplified with the expose() function.

Simple Exposure Examples

Many explaining techniques require us to expose intermediate tensors so their respective signals can be used, or so the gradient of the output can be computed with respect to their signals. For example, Grad-CAM computes the gradient of an output unit with respect to the activation signal produced by the last positional layer in the model:

with tf.GradientTape() as tape:
  # `model` must output both the logits and the activations
  # of the last positional layer (see below).
  logits, activations = model(x)

# Jacobian of each logit with respect to the activations, per sample.
gradients = tape.batch_jacobian(logits, activations)

This evidently means the activations signal, a tensor of shape (batch, height, width, ..., kernels), must be available at runtime. For that to happen, we must redefine the model, setting its outputs to contain the KerasTensor objects that reference both the logits and activations tensors:

import numpy as np
import tensorflow as tf
from keras import Input, Model, Sequential
from keras.applications import ResNet50V2
from keras.layers import Activation, Dense, GlobalAveragePooling2D

import keras_explainable as ke

rn50 = ResNet50V2(weights=None, classifier_activation=None)
# activations_tensor = rn50.get_layer("avg_pool").input  # or...
activations_tensor = rn50.get_layer("post_relu").output

model = Model(rn50.input, [rn50.output, activations_tensor])

print(model.name)
print(f"  input: {model.input}")
print("  outputs:")
for o in model.outputs:
  print(f"    {o}")
model
  input: KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'")
  outputs:
    KerasTensor(type_spec=TensorSpec(shape=(None, 1000), dtype=tf.float32, name=None), name='predictions/BiasAdd:0', description="created by layer 'predictions'")
    KerasTensor(type_spec=TensorSpec(shape=(None, 7, 7, 2048), dtype=tf.float32, name=None), name='post_relu/Relu:0', description="created by layer 'post_relu'")

This can be simplified with:

model = ke.inspection.expose(rn50)

The expose() function inspects the model, searching for the logits layer (the last layer containing a kernel property) and the global pooling layer (an instance of the GlobalPooling or Flatten layer classes). The output of the former and the input of the latter are collected, and a new model is defined.
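
A rough sketch of this search, following the description above (this is not the library's actual implementation; only GlobalAveragePooling2D and Flatten are checked here, for concreteness):

from keras.layers import Flatten, GlobalAveragePooling2D

def find_endpoints(model):
  # Last layer holding a `kernel` property: the logits layer.
  logits_layer = next(l for l in reversed(model.layers) if hasattr(l, "kernel"))
  # Last global pooling (or flatten) layer: its input carries the positional signal.
  pooling_layer = next(
    l for l in reversed(model.layers)
    if isinstance(l, (GlobalAveragePooling2D, Flatten))
  )
  return logits_layer.output, pooling_layer.input

logits, activations = find_endpoints(rn50)
model = Model(rn50.input, [logits, activations])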

You can also manually indicate the names of the argument and output layers. Each entry can be a layer name or a dictionary describing the layer's name, the link (whether to collect the layer's input or output tensor), and the inbound node index. All options below are equivalent:

model = ke.inspection.expose(rn50, "post_relu", "predictions")
model = ke.inspection.expose(
  rn50,
  {"name": "post_relu", "link": "output"},
  {"name": "predictions"},
)
model = ke.inspection.expose(
  rn50,
  {"name": "post_relu", "link": "output", "node": 0},
  {"name": "predictions", "link": "output", "node": 0},
)
model = ke.inspection.expose(
  rn50,
  {"name": "avg_pool", "link": "input"},
  "predictions",
)

Grad-CAM (or Grad-CAM++) can be called immediately after that:

inputs = np.random.normal(size=(4, 224, 224, 3))
indices = np.asarray([[4], [9], [0], [2]])

scores, cams = ke.gradcam(model, inputs, indices)

print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
scores:(4, 1) in [-0.1541651487350464, -0.0010260604321956635]
cams:(4, 224, 224, 1) in [0.0, 0.9939659237861633]
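
The resulting maps can be overlaid onto the inputs for inspection. A minimal visualization sketch, assuming matplotlib is available (not part of keras_explainable):

import matplotlib.pyplot as plt

for i, (image, cam) in enumerate(zip(inputs, cams)):
  plt.subplot(1, len(inputs), i + 1)
  plt.imshow(image.clip(0, 1))                    # input, clipped for display
  plt.imshow(cam[..., 0], cmap="jet", alpha=0.5)  # CAM overlay
  plt.axis("off")
plt.tight_layout()
plt.show()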

Exposing Nested Models

Unfortunately, some model topologies can make exposure a little tricky. An example of this is when nesting multiple models, which produces more than one Input object and multiple conceptual graphs at once. If one naively collects KerasTensor objects from the model, disconnected nodes may be retrieved, resulting in the exception ValueError: Graph disconnected being raised:

rn50 = ResNet50V2(weights=None, include_top=False)

x = Input([224, 224, 3], name="input_images")
y = rn50(x)
y = GlobalAveragePooling2D(name="avg_pool")(y)
y = Dense(10, name="logits")(y)
y = Activation("softmax", name="predictions", dtype="float32")(y)

rn50_clf = Model(x, y, name="resnet50v2_clf")
rn50_clf.summary()

logits = rn50_clf.get_layer("logits").output
activations = rn50_clf.get_layer("resnet50v2").output

model = tf.keras.Model(rn50_clf.input, [logits, activations])
scores, cams = ke.gradcam(model, inputs, indices)

print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
Model: "resnet50v2_clf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_images (InputLayer)   [(None, 224, 224, 3)]     0         
                                                                 
 resnet50v2 (Functional)     (None, None, None, 2048)  23564800  
                                                                 
 avg_pool (GlobalAveragePool  (None, 2048)             0         
 ing2D)                                                          
                                                                 
 logits (Dense)              (None, 10)                20490     
                                                                 
 predictions (Activation)    (None, 10)                0         
                                                                 
=================================================================
Total params: 23,585,290
Trainable params: 23,539,850
Non-trainable params: 45,440
_________________________________________________________________
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 15
     12 logits = rn50_clf.get_layer("logits").output
     13 activations = rn50_clf.get_layer("resnet50v2").output
---> 15 model = tf.keras.Model(rn50_clf.input, [logits, activations])
     16 scores, cams = ke.gradcam(model, inputs, indices)
     18 print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    203 self._self_setattr_tracking = False  # pylint: disable=protected-access
    204 try:
--> 205   result = method(self, *args, **kwargs)
    206 finally:
    207   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:165, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
    156     if not all(
    157         [
    158             functional_utils.is_input_keras_tensor(t)
    159             for t in tf.nest.flatten(inputs)
    160         ]
    161     ):
    162         inputs, outputs = functional_utils.clone_graph_nodes(
    163             inputs, outputs
    164         )
--> 165 self._init_graph_network(inputs, outputs)

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    203 self._self_setattr_tracking = False  # pylint: disable=protected-access
    204 try:
--> 205   result = method(self, *args, **kwargs)
    206 finally:
    207   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:264, in Functional._init_graph_network(self, inputs, outputs)
    261     self._input_coordinates.append((layer, node_index, tensor_index))
    263 # Keep track of the network's nodes and layers.
--> 264 nodes, nodes_by_depth, layers, _ = _map_graph_network(
    265     self.inputs, self.outputs
    266 )
    267 self._network_nodes = nodes
    268 self._nodes_by_depth = nodes_by_depth

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:1128, in _map_graph_network(inputs, outputs)
   1126 for x in tf.nest.flatten(node.keras_inputs):
   1127     if id(x) not in computable_tensors:
-> 1128         raise ValueError(
   1129             f"Graph disconnected: cannot obtain value for "
   1130             f'tensor {x} at layer "{layer.name}". '
   1131             "The following previous layers were accessed "
   1132             f"without issue: {layers_with_complete_input}"
   1133         )
   1134 for x in tf.nest.flatten(node.outputs):
   1135     computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "conv1_pad". The following previous layers were accessed without issue: []

The operations in rn50 appear in two conceptual graphs. The first, defined when ResNet50V2(...) was invoked, contains all operations associated with the layers in the ResNet50 architecture. The second is defined when Layer.__call__() is invoked for each layer in the classifier (rn50, GAP, Dense, and Activation).

When calling rn50_clf.get_layer("resnet50v2").output (which is equivalent to rn50_clf.get_layer("resnet50v2").get_output_at(0)), the Node from the first graph is retrieved. This Node is not associated with rn50_clf.input or logits, and thus the error is raised.

There are multiple ways to correctly access the Node from the second graph. One of them is to retrieve the input of the GAP layer, as it appears in only one graph. Manually, that corresponds to something like the following sketch:
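
logits = rn50_clf.get_layer("logits").output
activations = rn50_clf.get_layer("avg_pool").input  # the nested model's output in the second graph

model = tf.keras.Model(rn50_clf.input, [logits, activations])

The same exposure is performed by expose():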

model = ke.inspection.expose(
  rn50_clf, {"name": "avg_pool", "link": "input"}, "predictions"
)
scores, cams = ke.gradcam(model, inputs, indices)

print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
scores:(4, 1) in [0.09053877741098404, 0.11063539236783981]
cams:(4, 224, 224, 1) in [0.0, 0.994608461856842]

Note

The alternatives ke.inspection.expose(rn50_clf, "resnet50v2", "predictions") and ke.inspection.expose(rn50_clf) would work as well. In the former, the last output node is retrieved. In the latter, the last input node (there’s only one) associated with the GAP layer is retrieved.
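
In plain Keras terms, retrieving the last output node of the nested model corresponds to something like the following sketch (node index 1 is the node created when rn50 was called inside rn50_clf):

logits = rn50_clf.get_layer("logits").output
activations = rn50_clf.get_layer("resnet50v2").get_output_at(1)

model = tf.keras.Model(rn50_clf.input, [logits, activations])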

Accessing Nested Layer Signals

Another problem occurs when the global pooling layer is not part of the layer set of the outermost model. While you can still reference it using a name composition, collecting its output produces a ValueError: Graph disconnected.

This problem occurs because Keras does not create Nodes for the inner layers of a nested model when that model is reused. Instead, the model is treated as a single operation in the conceptual graph, and a single new Node is created to represent it. Calling keras_explainable.inspection.expose() over the model will expand the parameter arguments into {"name": ("resnet50v2", "avg_pool"), "link": "input", "node": "last"}, but because no new nodes were created for the GAP layer, the KerasTensor associated with the first conceptual graph is retrieved, and the error ensues.

rn50 = ResNet50V2(weights=None, include_top=False, pooling="avg")
rn50_clf = Sequential([
  Input([224, 224, 3], name="input_images"),
  rn50,
  Dense(10, name="logits"),
  Activation("softmax", name="predictions", dtype="float32"),
])

model = ke.inspection.expose(rn50_clf)
scores, cams = ke.gradcam(model, inputs, indices)

print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 9
      1 rn50 = ResNet50V2(weights=None, include_top=False, pooling="avg")
      2 rn50_clf = Sequential([
      3   Input([224, 224, 3], name="input_images"),
      4   rn50,
      5   Dense(10, name="logits"),
      6   Activation("softmax", name="predictions", dtype="float32"),
      7 ])
----> 9 model = ke.inspection.expose(rn50_clf)
     10 scores, cams = ke.gradcam(model, inputs, indices)
     12 print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras_explainable/inspection.py:263, in expose(model, arguments, outputs)
    259 arguments = tolist(arguments)
    261 tensors = endpoints(model, outputs + arguments)
--> 263 return Model(
    264     inputs=model.inputs,
    265     outputs=tensors,
    266 )

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    203 self._self_setattr_tracking = False  # pylint: disable=protected-access
    204 try:
--> 205   result = method(self, *args, **kwargs)
    206 finally:
    207   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:165, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
    156     if not all(
    157         [
    158             functional_utils.is_input_keras_tensor(t)
    159             for t in tf.nest.flatten(inputs)
    160         ]
    161     ):
    162         inputs, outputs = functional_utils.clone_graph_nodes(
    163             inputs, outputs
    164         )
--> 165 self._init_graph_network(inputs, outputs)

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    203 self._self_setattr_tracking = False  # pylint: disable=protected-access
    204 try:
--> 205   result = method(self, *args, **kwargs)
    206 finally:
    207   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:264, in Functional._init_graph_network(self, inputs, outputs)
    261     self._input_coordinates.append((layer, node_index, tensor_index))
    263 # Keep track of the network's nodes and layers.
--> 264 nodes, nodes_by_depth, layers, _ = _map_graph_network(
    265     self.inputs, self.outputs
    266 )
    267 self._network_nodes = nodes
    268 self._nodes_by_depth = nodes_by_depth

File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:1128, in _map_graph_network(inputs, outputs)
   1126 for x in tf.nest.flatten(node.keras_inputs):
   1127     if id(x) not in computable_tensors:
-> 1128         raise ValueError(
   1129             f"Graph disconnected: cannot obtain value for "
   1130             f'tensor {x} at layer "{layer.name}". '
   1131             "The following previous layers were accessed "
   1132             f"without issue: {layers_with_complete_input}"
   1133         )
   1134 for x in tf.nest.flatten(node.outputs):
   1135     computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_3'), name='input_3', description="created by layer 'input_3'") at layer "conv1_pad". The following previous layers were accessed without issue: []
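
We can confirm the claim above by counting inbound nodes (a diagnostic sketch using the private _inbound_nodes attribute mentioned in the warning below; counts assume the models were built exactly as above):

inner_gap = rn50.get_layer("avg_pool")

print(len(rn50._inbound_nodes))       # 2: one from creation, one from the reuse in rn50_clf
print(len(inner_gap._inbound_nodes))  # 1: only the node from rn50's own graph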

Warning

Since TensorFlow 2, nodes are no longer stacked in _inbound_nodes for the layers of nested models, which obstructs access to the intermediate signals contained in a nested model and makes the remainder of this document obsolete. To avoid this problem, it is recommended to flatten the model before explaining it, or to avoid nesting models altogether.

For more information, see the GitHub issue #16123.
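
For the example above, flattening can be done by building the classification head directly on the backbone's own graph, so that a single conceptual graph contains every layer and no nesting occurs. A sketch of this strategy (not a keras_explainable utility):

rn50 = ResNet50V2(
  weights=None, include_top=False, pooling="avg", input_shape=(224, 224, 3)
)
y = Dense(10, name="logits")(rn50.output)
y = Activation("softmax", name="predictions", dtype="float32")(y)
rn50_clf = Model(rn50.input, y, name="resnet50v2_clf")

model = ke.inspection.expose(rn50_clf)  # avg_pool is now a direct layer
scores, cams = ke.gradcam(model, inputs, indices)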

If you are using TensorFlow < 2.0, nodes are created for each operation in the inner model, and you can collect their internal signals simply with:

model = ke.inspection.expose(rn50_clf)
# ... or: ke.inspection.expose(rn50_clf, ("resnet50v2", "post_relu"))
# ... or: ke.inspection.expose(
#  rn50_clf, {"name": ("resnet50v2", "avg_pool"), "link": "input"}
# )

scores, cams = ke.gradcam(model, inputs, indices)

Note

The above works because expose() will recursively search for a GAP layer within the nested models.
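
A rough sketch of that recursive search, under the assumptions above (this is not the library's actual implementation):

from keras import Model
from keras.layers import Flatten, GlobalAveragePooling2D

def find_gap(model):
  # Prefer the last matching layer; descend into nested models as needed.
  for layer in reversed(model.layers):
    if isinstance(layer, (GlobalAveragePooling2D, Flatten)):
      return layer
    if isinstance(layer, Model):  # nested model: search inside it
      nested = find_gap(layer)
      if nested is not None:
        return nested
  return None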