import warnings
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import tensorflow as tf
from keras import callbacks as callbacks_module
from keras.callbacks import Callback
from keras.engine import data_adapter
from keras.engine.training import _is_tpu_multi_host
from keras.engine.training import _minimum_control_deps
from keras.engine.training import potentially_ragged_concat
from keras.engine.training import reduce_per_replica
from keras.utils import tf_utils
from tensorflow.python.eager import context
from keras_explainable.inspection import SPATIAL_AXIS
def explain_step(
    model: tf.keras.Model,
    method: Callable,
    data: Tuple[tf.Tensor],
    spatial_axis: Tuple[int, int] = SPATIAL_AXIS,
    postprocessing: Callable = None,
    resizing: Optional[Union[bool, tf.Tensor]] = True,
    **params,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Run an explaining ``method`` over one batch and post-process its maps.

    Args:
        model (tf.keras.Model): forwarded to ``method`` as ``model``.
        method (Callable): called with the keyword arguments ``model``,
            ``inputs``, ``indices``, ``spatial_axis`` and ``**params``. It must
            return a ``(logits, maps)`` pair.
        data (Tuple[tf.Tensor]): the batch, unpacked with
            ``data_adapter.unpack_x_y_sample_weight``. The first element is used
            as ``inputs``, the second as ``indices``; the third is discarded.
        spatial_axis (Tuple[int, int], optional): forwarded to ``method`` and,
            as ``axis``, to ``postprocessing``. Defaults to ``SPATIAL_AXIS``.
        postprocessing (Callable, optional): if not None, applied to the maps as
            ``postprocessing(maps, axis=spatial_axis)``. Defaults to None.
        resizing (Optional[Union[bool, tf.Tensor]], optional): ``None`` or
            ``False`` skips resizing; ``True`` resizes the maps to
            ``tf.shape(inputs)[1:-1]``; any other value is passed as the target
            size to ``tf.image.resize``. Defaults to True.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: the logits returned by ``method`` and the
        (optionally post-processed and resized) maps.
    """
    inputs, indices, _ = data_adapter.unpack_x_y_sample_weight(data)

    method_kwargs = dict(
        model=model,
        inputs=inputs,
        indices=indices,
        spatial_axis=spatial_axis,
    )
    logits, maps = method(**method_kwargs, **params)

    if postprocessing is not None:
        maps = postprocessing(maps, axis=spatial_axis)

    # Identity checks on purpose: `resizing` may be a tensor holding the size.
    if resizing is None or resizing is False:
        return logits, maps

    # NOTE(review): `[1:-1]` presumably assumes channels-last inputs with a
    # single leading batch axis -- confirm against callers.
    target_size = tf.shape(inputs)[1:-1] if resizing is True else resizing
    return logits, tf.image.resize(maps, target_size)
def make_explain_function(
    model: tf.keras.Model,
    method: Callable,
    params: Dict[str, Any],
    force: bool = False,
):
    """Build, or fetch from ``model``, the function that explains one batch.

    The built function is stored on ``model.explain_function``. An existing
    one is returned as is unless ``force`` is true, regardless of the
    ``method`` and ``params`` passed to this call.

    Args:
        model (tf.keras.Model): the model being explained. Its
            ``_explain_counter`` variable is incremented on every step, so it
            must exist by the time the returned function runs.
        method (Callable): the explaining method, forwarded to
            :func:`explain_step`.
        params (Dict[str, Any]): keyword arguments forwarded to
            :func:`explain_step`.
        force (bool, optional): rebuild the function even if ``model`` already
            holds one. Defaults to False.

    Returns:
        A callable taking a data iterator and returning the outputs of one
        explain step, concatenated across replicas.
    """
    cached = getattr(model, "explain_function", None)
    if not force and cached is not None:
        return cached

    def explain_function(iterator):
        """Runs a single explain step."""

        def run_step(data):
            step_outputs = explain_step(model, method, data, **params)
            # Only bump the counter once the explain step produced its outputs.
            with tf.control_dependencies(_minimum_control_deps(step_outputs)):
                model._explain_counter.assign_add(1)
            return step_outputs

        step_fn = (
            tf.function(run_step, jit_compile=True, reduce_retracing=True)
            if model._jit_compile
            else run_step
        )

        batch = next(iterator)
        per_replica = model.distribute_strategy.run(step_fn, args=(batch,))
        return reduce_per_replica(
            per_replica, model.distribute_strategy, reduction="concat"
        )

    if not model.run_eagerly:
        explain_function = tf.function(explain_function, reduce_retracing=True)

    model.explain_function = explain_function
    return explain_function
def make_data_handler(
    model,
    x,
    y,
    batch_size=None,
    steps=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
):
    """Create the Keras data handler used for a single explaining pass.

    When the model runs in multi-worker mode (or on a multi-host TPU) and ``x``
    is a dataset, the dataset is first switched to the ``DATA`` auto-shard
    policy. If that fails with a ``ValueError``, a warning is issued and the
    dataset is used unchanged.

    Args:
        model: the model being explained. Its distribution strategy and
            ``_steps_per_execution`` are read.
        x: the input data.
        y: the targets forwarded to the data handler.
        batch_size: forwarded to the data handler. Defaults to None.
        steps: forwarded as ``steps_per_epoch``. Defaults to None.
        max_queue_size: forwarded to the data handler. Defaults to 10.
        workers: forwarded to the data handler. Defaults to 1.
        use_multiprocessing: forwarded to the data handler. Defaults to False.

    Returns:
        The object returned by ``data_adapter.get_data_handler``, configured
        for exactly one epoch starting at epoch 0.
    """
    multi_host = model._in_multi_worker_mode() or _is_tpu_multi_host(
        model.distribute_strategy
    )

    if multi_host and isinstance(x, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
        try:
            shard_options = tf.data.Options()
            shard_options.experimental_distribute.auto_shard_policy = (
                tf.data.experimental.AutoShardPolicy.DATA
            )
            x = x.with_options(shard_options)
        except ValueError:
            warnings.warn(
                "Using evaluate with MultiWorkerMirroredStrategy "
                "or TPUStrategy and AutoShardPolicy.FILE might lead to "
                "out-of-order result. Consider setting it to "
                "AutoShardPolicy.DATA.",
                stacklevel=2,
            )

    handler_kwargs = dict(
        x=x,
        y=y,
        batch_size=batch_size,
        steps_per_epoch=steps,
        initial_epoch=0,
        epochs=1,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=model,
        steps_per_execution=model._steps_per_execution,
    )
    return data_adapter.get_data_handler(**handler_kwargs)
def explain(
    method: Callable,
    model: tf.keras.Model,
    x: Union[np.ndarray, tf.Tensor, tf.data.Dataset],
    y: Optional[Union[np.ndarray, tf.Tensor]] = None,
    batch_size: Optional[int] = None,
    verbose: Union[str, int] = "auto",
    steps: Optional[int] = None,
    callbacks: Optional[List[Callback]] = None,
    max_queue_size: int = 10,
    workers: int = 1,
    use_multiprocessing: bool = False,
    force: bool = True,
    **method_params,
) -> Tuple[np.ndarray, np.ndarray]:
    """Explain the outputs of ``model`` with respect to the inputs or an intermediate
    signal, using an AI explaining method.

    Usage:

    .. code-block:: python

        x = np.random.normal(size=(1, 224, 224, 3))
        y = np.asarray([[16, 32]])

        model = tf.keras.applications.ResNet50V2(classifier_activation=None)

        scores, maps = ke.explain(
            ke.methods.gradient.gradients,
            model, x, y,
            postprocessing=filters.absolute_normalize,
        )

    Args:
        method (Callable): An AI explaining function, as the ones contained in
            `methods` module.
        model (tf.keras.Model): The model whose predictions should be explained.
        x (Union[np.ndarray, tf.Tensor, tf.data.Dataset]): the input data for the model.
        y (Optional[Union[np.ndarray, tf.Tensor]], optional): the indices in the output
            tensor that should be explained. If none, an activation map is computed
            for each unit. Defaults to None.
        batch_size (Optional[int], optional): the batch size used by ``method``.
            Defaults to None, in which case the choice is left to the Keras
            data handler.
        verbose (Union[str, int], optional): whether to show a progress bar during
            the calculation of the explaining maps. ``0`` is silent; ``"auto"``
            is treated as ``1``. Defaults to "auto".
        steps (Optional[int], optional): the number of steps, if ``x`` is a
            ``tf.data.Dataset`` of unknown cardinality. Defaults to None.
        callbacks (Optional[List[Callback]], optional): list of callbacks called
            during the explaining procedure. Their ``on_predict_*`` hooks are
            the ones invoked. Defaults to None.
        max_queue_size (int, optional): the queue size when retrieving inputs.
            Used if ``x`` is a generator. Defaults to 10.
        workers (int, optional): the number of workers used when retrieving inputs.
            Defaults to 1.
        use_multiprocessing (bool, optional): whether to employ multi-process or
            multi-threading when retrieving inputs, when ``x`` is a generator.
            Defaults to False.
        force (bool, optional): to force the creation of the explaining function.
            Can be set to False if the same function is always applied to a model,
            avoiding retracing. Defaults to True.

    Besides the parameters described above, any named parameters passed to this function
    will be collected into ``method_params`` and passed onto the :func:`explain_step`
    and ``method`` functions. The most common ones are:

    - **indices_batch_dims** (int): The dimensions marked as ``batch`` when gathering
      units described by ``y``. Ignore if ``y`` is None.
    - **indices_axis** (int): The axes from which to gather units described by ``y``.
      Ignore if ``y`` is None.
    - **spatial_axis** (Tuple[int]): The axes containing the positional visual info.
      We assume `inputs` to contain 2D images or videos in the shape
      `(B1, B2, ..., BN, H, W, 3)`.
      For 3D image data, set `spatial_axis` to `(1, 2, 3)` or `(-4, -3, -2)`.
    - **postprocessing** (Callable): A function to process the activation maps before
      normalization (most commonly adopted being `maximum(x, 0)` and `abs`).

    Raises:
        ValueError: if the explaining function produced no outputs at all (no
            batch was processed).

    Returns:
        Tuple[np.ndarray, np.ndarray]: logits and explaining maps tensors.
    """
    if not hasattr(model, "_explain_counter"):
        agg = tf.VariableAggregation.ONLY_FIRST_REPLICA
        model._explain_counter = tf.Variable(0, dtype="int64", aggregation=agg)

    # The Keras progress bar only draws for the integer levels 1 and 2, so an
    # unresolved "auto" would silently display nothing.
    if verbose == "auto":
        verbose = 1

    outputs = None
    with model.distribute_strategy.scope():
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = make_data_handler(
            model,
            x,
            y,
            batch_size=batch_size,
            steps=steps,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                model=model,
                verbose=verbose,
                epochs=1,
                steps=data_handler.inferred_steps,
            )

        explain_function = make_explain_function(model, method, method_params, force)
        model._explain_counter.assign(0)
        callbacks.on_predict_begin()
        batch_outputs = None

        for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
            with data_handler.catch_stop_iteration():
                for step in data_handler.steps():
                    callbacks.on_predict_batch_begin(step)
                    tmp_batch_outputs = explain_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
                    batch_outputs = tmp_batch_outputs  # No error, now safe to assign.

                    # Accumulate each output of the structure in its own list;
                    # they are concatenated once the epoch is over.
                    if outputs is None:
                        outputs = tf.nest.map_structure(
                            lambda batch_output: [batch_output],
                            batch_outputs,
                        )
                    else:
                        tf.__internal__.nest.map_structure_up_to(
                            batch_outputs,
                            lambda output, batch_output: output.append(batch_output),
                            outputs,
                            batch_outputs,
                        )

                    end_step = step + data_handler.step_increment
                    callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})

        if batch_outputs is None:
            raise ValueError(
                "Unexpected result of `explain_function` "
                "(Empty batch_outputs). Please use "
                "`Model.compile(..., run_eagerly=True)`, or "
                "`tf.config.run_functions_eagerly(True)` for more "
                "information of where went wrong, or file a "
                "issue/bug to `keras-explainable`."
            )
        callbacks.on_predict_end()

    all_outputs = tf.__internal__.nest.map_structure_up_to(
        batch_outputs, potentially_ragged_concat, outputs
    )
    return tf_utils.sync_to_numpy_or_python_type(all_outputs)
def partial_explain(method: Callable, **default_params):
    """Wrapper for explaining methods.

    Binds ``method`` (and, optionally, default parameters) to :func:`explain`,
    returning a function that only needs the remaining ``explain`` arguments.

    Args:
        method (Callable): the explaining method being wrapped by ``explain``.
        **default_params: default keyword arguments for ``explain``. Keyword
            arguments given when calling the returned function take precedence.

    Returns:
        A function named ``"<method name>_explain"`` that calls
        ``explain(method, *args, **params)``.
    """

    def _partial_method_explain(*args, **params):
        # Call-time parameters override the defaults bound at wrapping time.
        params = {**default_params, **params}
        return explain(method, *args, **params)

    # `functools.partial` objects and callable instances carry no `__name__`:
    # fall back to the wrapped function's name, then to the class name.
    base_name = getattr(method, "__name__", None)
    if base_name is None:
        base_name = getattr(getattr(method, "func", None), "__name__", None)
    if base_name is None:
        base_name = type(method).__name__

    _partial_method_explain.__name__ = f"{base_name}_explain"
    return _partial_method_explain