Exposing Intermediate Signals
This page details the exposure procedure, which most AI explanation methods require and which the expose() function simplifies.
Simple Exposition Examples
Many explanation techniques require us to expose intermediate tensors so that their signals can be used, or so that the gradient of the output can be computed with respect to them. For example, Grad-CAM computes the gradient of an output unit with respect to the activation signal produced by the last positional layer in the model:
with tf.GradientTape() as tape:
    logits, activations = model(x)
gradients = tape.batch_jacobian(logits, activations)
This means the activations signal, a tensor of shape (batch, height, width, ..., kernels), must be available at runtime.
For that to happen, we must redefine the model, setting its outputs to contain the KerasTensor objects that reference both the logits and activations tensors:
import numpy as np
import tensorflow as tf
from keras import Input, Model, Sequential
from keras.applications import ResNet50V2
from keras.layers import Activation, Dense, GlobalAveragePooling2D
import keras_explainable as ke
rn50 = ResNet50V2(weights=None, classifier_activation=None)
# activations_tensor = rn50.get_layer("avg_pool").input # or...
activations_tensor = rn50.get_layer("post_relu").output
model = Model(rn50.input, [rn50.output, activations_tensor])
print(model.name)
print(f" input: {model.input}")
print(" outputs:")
for o in model.outputs:
    print(f" {o}")
model
input: KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'")
outputs:
KerasTensor(type_spec=TensorSpec(shape=(None, 1000), dtype=tf.float32, name=None), name='predictions/BiasAdd:0', description="created by layer 'predictions'")
KerasTensor(type_spec=TensorSpec(shape=(None, 7, 7, 2048), dtype=tf.float32, name=None), name='post_relu/Relu:0', description="created by layer 'post_relu'")
The same model can be built more concisely with:
model = ke.inspection.expose(rn50)
The expose() function inspects the model, searching for the logits layer (the last layer that has a kernel property) and the global pooling layer (an instance of a GlobalPooling or Flatten layer class). The output of the former and the input of the latter are collected, and a new model is defined with them as outputs.
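The search described above can be pictured with a small, library-free sketch. ToyLayer and find_endpoints are hypothetical names introduced only for illustration; they are not part of Keras or keras_explainable, and picking the last pooling layer is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a Keras layer (not a Keras class)."""

    name: str
    kind: str                 # class name, e.g. "Dense" or "Flatten"
    has_kernel: bool = False  # mimics the presence of a `kernel` property


def is_pooling(layer: ToyLayer) -> bool:
    """True for global pooling layers and for Flatten."""
    is_global_pool = layer.kind.startswith("Global") and "Pooling" in layer.kind
    return is_global_pool or layer.kind == "Flatten"


def find_endpoints(layers: List[ToyLayer]) -> Tuple[Optional[str], Optional[str]]:
    """Return (pooling_layer_name, logits_layer_name), or None when absent."""
    # Logits layer: the last layer that owns a kernel.
    logits = next((l.name for l in reversed(layers) if l.has_kernel), None)
    # Pooling layer: assumed here to be the last pooling/flatten layer.
    pooling = next((l.name for l in reversed(layers) if is_pooling(l)), None)
    return pooling, logits


layers = [
    ToyLayer("conv1_conv", "Conv2D", has_kernel=True),
    ToyLayer("post_relu", "Activation"),
    ToyLayer("avg_pool", "GlobalAveragePooling2D"),
    ToyLayer("predictions", "Dense", has_kernel=True),
]
pooling_name, logits_name = find_endpoints(layers)
print(pooling_name, logits_name)
```

In the real function, the input of the pooling layer and the output of the logits layer would then be collected as the outputs of the new model.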
You can also manually indicate the names of the argument and output layers. All options below are equivalent:
model = ke.inspection.expose(rn50, "post_relu", "predictions")
model = ke.inspection.expose(
    rn50,
    {"name": "post_relu", "link": "output"},
    {"name": "predictions"},
)
model = ke.inspection.expose(
    rn50,
    {"name": "post_relu", "link": "output", "node": 0},
    {"name": "predictions", "link": "output", "node": 0},
)
model = ke.inspection.expose(
    rn50,
    {"name": "avg_pool", "link": "input"},
    "predictions",
)
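The equivalence between the short and long forms above can be pictured as a normalization step. The sketch below uses a hypothetical normalize helper, not the library's code. The defaults it assumes (link "output", node "last") are inferred from the examples on this page; when a layer has a single node, "last" and 0 refer to the same one.

```python
from typing import Dict, Tuple, Union

# A layer reference: a name, a nested name (tuple), or a partial dict.
Spec = Union[str, Tuple[str, ...], Dict[str, object]]


def normalize(spec: Spec) -> Dict[str, object]:
    """Expand a layer reference into a full {"name", "link", "node"} dict."""
    if not isinstance(spec, dict):
        spec = {"name": spec}
    # Assumed defaults (not confirmed by the library's source).
    return {"link": "output", "node": "last", **spec}


a = normalize("post_relu")
b = normalize({"name": "post_relu", "link": "output"})
c = normalize({"name": "avg_pool", "link": "input"})
print(a)
print(a == b)
print(c)
```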
Once the model is exposed, Grad-CAM (or Grad-CAM++) can be called directly:
inputs = np.random.normal(size=(4, 224, 224, 3))
indices = np.asarray([[4], [9], [0], [2]])
scores, cams = ke.gradcam(model, inputs, indices)
print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
scores:(4, 1) in [-0.1541651487350464, -0.0010260604321956635]
cams:(4, 224, 224, 1) in [0.0, 0.9939659237861633]
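For intuition on why both signals must be exposed, here is a minimal NumPy sketch of how Grad-CAM combines activations with their gradients. The function name gradcam_from_signals is hypothetical, the random arrays stand in for real model signals, and the final scaling to [0, 1] is an assumption of this sketch. Resizing the map to the input resolution is omitted.

```python
import numpy as np


def gradcam_from_signals(activations, gradients):
    """Combine exposed signals into a class activation map.

    activations: (batch, height, width, kernels)
    gradients:   same shape, d(logit) / d(activations) for one unit
    """
    # One weight per kernel: the spatial average of its gradient.
    weights = gradients.mean(axis=(1, 2), keepdims=True)
    # Weighted sum over kernels, keeping only positive evidence (ReLU).
    cam = np.maximum((weights * activations).sum(axis=-1), 0.0)
    # Scale each map to [0, 1]; the epsilon guards against all-zero maps.
    peak = cam.max(axis=(1, 2), keepdims=True)
    return cam / np.maximum(peak, 1e-8)


rng = np.random.default_rng(0)
activations = np.abs(rng.normal(size=(4, 7, 7, 16)))  # stand-in signal
gradients = rng.normal(size=(4, 7, 7, 16))            # stand-in gradients
cams = gradcam_from_signals(activations, gradients)
print(cams.shape)
```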
Exposing Nested Models
Unfortunately, some model topologies can make exposure a little tricky. One example is nesting multiple models, which produces more than one Input object and multiple conceptual graphs at once. If one then naively collects KerasTensor objects from the model, disconnected nodes may be retrieved, raising the exception ValueError: Graph disconnected:
rn50 = ResNet50V2(weights=None, include_top=False)
x = Input([224, 224, 3], name="input_images")
y = rn50(x)
y = GlobalAveragePooling2D(name="avg_pool")(y)
y = Dense(10, name="logits")(y)
y = Activation("softmax", name="predictions", dtype="float32")(y)
rn50_clf = Model(x, y, name="resnet50v2_clf")
rn50_clf.summary()
logits = rn50_clf.get_layer("logits").output
activations = rn50_clf.get_layer("resnet50v2").output
model = tf.keras.Model(rn50_clf.input, [logits, activations])
scores, cams = ke.gradcam(model, inputs, indices)
print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
Model: "resnet50v2_clf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_images (InputLayer)   [(None, 224, 224, 3)]     0
 resnet50v2 (Functional)     (None, None, None, 2048)  23564800
 avg_pool (GlobalAveragePool (None, 2048)              0
 ing2D)
 logits (Dense)              (None, 10)                20490
 predictions (Activation)    (None, 10)                0
=================================================================
Total params: 23,585,290
Trainable params: 23,539,850
Non-trainable params: 45,440
_________________________________________________________________
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[3], line 15
12 logits = rn50_clf.get_layer("logits").output
13 activations = rn50_clf.get_layer("resnet50v2").output
---> 15 model = tf.keras.Model(rn50_clf.input, [logits, activations])
16 scores, cams = ke.gradcam(model, inputs, indices)
18 print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
203 self._self_setattr_tracking = False # pylint: disable=protected-access
204 try:
--> 205 result = method(self, *args, **kwargs)
206 finally:
207 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:165, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
156 if not all(
157 [
158 functional_utils.is_input_keras_tensor(t)
159 for t in tf.nest.flatten(inputs)
160 ]
161 ):
162 inputs, outputs = functional_utils.clone_graph_nodes(
163 inputs, outputs
164 )
--> 165 self._init_graph_network(inputs, outputs)
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
203 self._self_setattr_tracking = False # pylint: disable=protected-access
204 try:
--> 205 result = method(self, *args, **kwargs)
206 finally:
207 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:264, in Functional._init_graph_network(self, inputs, outputs)
261 self._input_coordinates.append((layer, node_index, tensor_index))
263 # Keep track of the network's nodes and layers.
--> 264 nodes, nodes_by_depth, layers, _ = _map_graph_network(
265 self.inputs, self.outputs
266 )
267 self._network_nodes = nodes
268 self._nodes_by_depth = nodes_by_depth
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:1128, in _map_graph_network(inputs, outputs)
1126 for x in tf.nest.flatten(node.keras_inputs):
1127 if id(x) not in computable_tensors:
-> 1128 raise ValueError(
1129 f"Graph disconnected: cannot obtain value for "
1130 f'tensor {x} at layer "{layer.name}". '
1131 "The following previous layers were accessed "
1132 f"without issue: {layers_with_complete_input}"
1133 )
1134 for x in tf.nest.flatten(node.outputs):
1135 computable_tensors.add(id(x))
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "conv1_pad". The following previous layers were accessed without issue: []
The operations in rn50 appear in two conceptual graphs. The first, defined when ResNet50V2(...) was invoked, contains all operations associated with the layers in the ResNet50V2 architecture. The second is defined when Layer.__call__() is invoked on each layer (rn50, GAP, Dense and Activation).
Calling rn50_clf.get_layer("resnet50v2").output (which is equivalent to rn50_clf.get_layer("resnet50v2").get_output_at(0)) retrieves the Node from the first graph. This Node is not connected to rn50_clf.input or logits, and thus the error is raised.
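The situation can be mimicked with a toy model of nodes. ToyTensor, ToyLayer and depends_on are hypothetical stand-ins written for this illustration; they only imitate the bookkeeping described above and are not Keras classes.

```python
class ToyTensor:
    """Hypothetical symbolic tensor that remembers what it was computed from."""

    def __init__(self, name, source=None):
        self.name = name
        self.source = source  # upstream tensor, or None for an input


class ToyLayer:
    """Hypothetical layer that records one output per call, like a node list."""

    def __init__(self, name):
        self.name = name
        self.outputs = []

    def __call__(self, x):
        y = ToyTensor(f"{self.name}_out_{len(self.outputs)}", source=x)
        self.outputs.append(y)
        return y

    def get_output_at(self, index):
        return self.outputs[index]


def depends_on(tensor, root):
    """Walk back through sources to check whether `tensor` derives from `root`."""
    while tensor is not None:
        if tensor is root:
            return True
        tensor = tensor.source
    return False


backbone = ToyLayer("resnet50v2")
inner_input = ToyTensor("input_2")
backbone(inner_input)   # first graph: created together with the backbone

outer_input = ToyTensor("input_images")
backbone(outer_input)   # second graph: backbone reused inside the classifier

first = backbone.get_output_at(0)
last = backbone.get_output_at(-1)
print(depends_on(first, outer_input))  # False: belongs to the first graph
print(depends_on(last, outer_input))   # True: reachable from the outer input
```

Only the tensor from the last call can be traced back to the outer input, which is why retrieving the first one leaves the graph disconnected.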
There are multiple ways to correctly access the Node from the second graph. One of them is to retrieve the input of the GAP layer, as it only appears in one graph:
model = ke.inspection.expose(
    rn50_clf, {"name": "avg_pool", "link": "input"}, "predictions"
)
scores, cams = ke.gradcam(model, inputs, indices)
print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
scores:(4, 1) in [0.09053877741098404, 0.11063539236783981]
cams:(4, 224, 224, 1) in [0.0, 0.994608461856842]
Note
The alternatives ke.inspection.expose(rn50_clf, "resnet50v2", "predictions") and ke.inspection.expose(rn50_clf) would work as well. In the former, the last output node is retrieved. In the latter, the last input node (there’s only one) associated with the GAP layer is retrieved.
Access Nested Layer Signals
Another problem occurs when the global pooling layer is not part of the layer set of the outermost model. While its output can still be collected using a name composition, doing so raises a ValueError: Graph disconnected.
This problem occurs because Keras does not create new Nodes for the inner layers of a nested model when that model is reused. Instead, the model is treated as a single operation in the conceptual graph, with a single new Node being created to represent it.
Calling keras_explainable.inspection.expose() on the model will expand the arguments parameter into {"name": ("resnet50v2", "avg_pool"), "link": "input", "node": "last"}, but because no new nodes were created for the GAP layer, the KerasTensor associated with the first conceptual graph is retrieved, and the error ensues.
rn50 = ResNet50V2(weights=None, include_top=False, pooling="avg")
rn50_clf = Sequential([
    Input([224, 224, 3], name="input_images"),
    rn50,
    Dense(10, name="logits"),
    Activation("softmax", name="predictions", dtype="float32"),
])
model = ke.inspection.expose(rn50_clf)
scores, cams = ke.gradcam(model, inputs, indices)
print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]")
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 9
1 rn50 = ResNet50V2(weights=None, include_top=False, pooling="avg")
2 rn50_clf = Sequential([
3 Input([224, 224, 3], name="input_images"),
4 rn50,
5 Dense(10, name="logits"),
6 Activation("softmax", name="predictions", dtype="float32"),
7 ])
----> 9 model = ke.inspection.expose(rn50_clf)
10 scores, cams = ke.gradcam(model, inputs, indices)
12 print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]")
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras_explainable/inspection.py:263, in expose(model, arguments, outputs)
259 arguments = tolist(arguments)
261 tensors = endpoints(model, outputs + arguments)
--> 263 return Model(
264 inputs=model.inputs,
265 outputs=tensors,
266 )
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
203 self._self_setattr_tracking = False # pylint: disable=protected-access
204 try:
--> 205 result = method(self, *args, **kwargs)
206 finally:
207 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:165, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
156 if not all(
157 [
158 functional_utils.is_input_keras_tensor(t)
159 for t in tf.nest.flatten(inputs)
160 ]
161 ):
162 inputs, outputs = functional_utils.clone_graph_nodes(
163 inputs, outputs
164 )
--> 165 self._init_graph_network(inputs, outputs)
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
203 self._self_setattr_tracking = False # pylint: disable=protected-access
204 try:
--> 205 result = method(self, *args, **kwargs)
206 finally:
207 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:264, in Functional._init_graph_network(self, inputs, outputs)
261 self._input_coordinates.append((layer, node_index, tensor_index))
263 # Keep track of the network's nodes and layers.
--> 264 nodes, nodes_by_depth, layers, _ = _map_graph_network(
265 self.inputs, self.outputs
266 )
267 self._network_nodes = nodes
268 self._nodes_by_depth = nodes_by_depth
File /opt/hostedtoolcache/Python/3.10.8/x64/lib/python3.10/site-packages/keras/engine/functional.py:1128, in _map_graph_network(inputs, outputs)
1126 for x in tf.nest.flatten(node.keras_inputs):
1127 if id(x) not in computable_tensors:
-> 1128 raise ValueError(
1129 f"Graph disconnected: cannot obtain value for "
1130 f'tensor {x} at layer "{layer.name}". '
1131 "The following previous layers were accessed "
1132 f"without issue: {layers_with_complete_input}"
1133 )
1134 for x in tf.nest.flatten(node.outputs):
1135 computable_tensors.add(id(x))
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_3'), name='input_3', description="created by layer 'input_3'") at layer "conv1_pad". The following previous layers were accessed without issue: []
Warning
Since TensorFlow 2, nodes are no longer stacked in _inbound_nodes for layers in nested models, which obstructs access to intermediate signals contained in a nested model and makes the remainder of this document obsolete. To avoid this problem, it is recommended to flatten the model before explaining it, or to avoid nesting models altogether. For more information, see the GitHub issue #16123.
If you are using TensorFlow < 2.0, nodes are created for each operation in the inner model, and you can collect their internal signals simply with:
model = ke.inspection.expose(rn50_clf)
# ... or: ke.inspection.expose(rn50_clf, ("resnet50v2", "post_relu"))
# ... or: ke.inspection.expose(
#     rn50_clf, {"name": ("resnet50v2", "avg_pool"), "link": "input"}
# )
scores, cams = ke.gradcam(model, inputs, indices)
Note
The above works because expose() will recursively search for a GAP layer within the nested models.