Interpretability Callbacks for Tensorflow 2.0

tf-explain

tf-explain implements interpretability methods as TensorFlow 2.x callbacks to make neural networks easier to understand. See the blog post Introducing tf-explain, Interpretability for Tensorflow 2.0.

Documentation: https://tf-explain.readthedocs.io

Installation

tf-explain is available on PyPI as an alpha release. To install it:

virtualenv venv -p python3.8
source venv/bin/activate
pip install tf-explain

tf-explain is compatible with TensorFlow 2.x. TensorFlow is not declared as a dependency, so that you can choose between the full and the standalone-CPU versions. In addition to the previous install, run:

# For CPU or GPU
pip install tensorflow==2.6.0

OpenCV is also a dependency. To install it, run:

# For CPU or GPU
pip install opencv-python

Quickstart

tf-explain offers 2 ways to apply interpretability methods. The full list of methods is in the Available Methods section.

On trained model

The most straightforward option is to load a trained model and apply the methods to it.

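import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM

IMAGE_PATH = "cat.jpg"  # placeholder: path to your own image

# Load a pretrained model (or your own)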
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

# Load a sample image (or multiple ones)
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
data = ([img], None)

# Start explainer
explainer = GradCAM()
grid = explainer.explain(data, model, class_index=281)  # 281 is the tabby cat index in ImageNet

explainer.save(grid, ".", "grad_cam.png")

During training

If you want to monitor your model during training, you can also use tf-explain as a Keras callback and see the results directly in TensorBoard.

from tf_explain.callbacks.grad_cam import GradCAMCallback

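# Your compiled Keras model; x_train, y_train, x_val, y_val and output_dir
# are assumed to be defined elsewhere
model = [...]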

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Available Methods

  1. Activations Visualization
  2. Vanilla Gradients
  3. Gradients*Inputs
  4. Occlusion Sensitivity
  5. Grad CAM (Class Activation Maps)
  6. SmoothGrad
  7. Integrated Gradients

Activations Visualization

Visualize the output of a specific activation layer for a given input.

from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
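To give an idea of what such a visualization shows, here is a minimal NumPy sketch that tiles the channels of one layer's output into a single grid image. It is not tf-explain's own code: `activations_grid` is a hypothetical helper that expects the layer output for one image as an array of shape (H, W, C), and the per-channel min-max normalization and square-ish layout are assumptions of this sketch.

```python
import numpy as np

def activations_grid(activations):
    """Tile the channels of one layer's output (H, W, C) into a 2D grid.

    Hypothetical helper: each channel is min-max normalized to [0, 1]
    independently so weak and strong feature maps are visible side by side.
    """
    h, w, c = activations.shape
    cols = int(np.ceil(np.sqrt(c)))
    rows = int(np.ceil(c / cols))
    grid = np.zeros((rows * h, cols * w), dtype=np.float32)
    for k in range(c):
        fmap = activations[:, :, k].astype(np.float32)
        span = fmap.max() - fmap.min()
        # Guard against constant channels (zero span)
        fmap = (fmap - fmap.min()) / span if span > 0 else np.zeros_like(fmap)
        r, col = divmod(k, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = fmap
    # Unused cells at the end of the grid stay at 0
    return grid

# Example: a fake layer output with 5 channels of 4x4 feature maps
acts = np.random.default_rng(0).normal(size=(4, 4, 5))
grid = activations_grid(acts)
print(grid.shape)  # (8, 12)
```

With a real Keras model, the (H, W, C) array would come from a model whose output is the chosen layer, for one image of the validation batch.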

Vanilla Gradients

Visualize the gradients of the class score with respect to the input image, as a measure of pixel importance.

from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback

model = [...]

callbacks = [
    VanillaGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
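Vanilla gradients are the derivative of the class score with respect to each input pixel. The sketch below illustrates this with a hypothetical toy model whose gradient is known in closed form, so it runs with NumPy alone; with a real Keras model you would compute the same gradient with `tf.GradientTape`. Collapsing the color channels with the maximum absolute gradient is one common choice, not necessarily the one tf-explain makes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical toy "model" standing in for a Keras network:
# class score = sum(W * tanh(x)) for an 8x8 RGB image x.
W = rng.normal(size=(8, 8, 3))

def class_score(x):
    return float(np.sum(W * np.tanh(x)))

def score_gradient(x):
    # Analytical d(class score)/dx; with Keras you would use tf.GradientTape
    return W * (1.0 - np.tanh(x) ** 2)

def vanilla_gradients_map(x):
    grads = score_gradient(x)
    # Keep the gradient magnitude and collapse the color channels
    saliency = np.max(np.abs(grads), axis=-1)
    # Rescale to [0, 1] for display
    return saliency / saliency.max()

image = rng.uniform(size=(8, 8, 3))
saliency = vanilla_gradients_map(image)
print(saliency.shape)  # (8, 8)
```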

Gradients*Inputs

Variant of Vanilla Gradients that weights the gradients by the input values.

from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback

model = [...]

callbacks = [
    GradientsInputsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
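As a rough sketch of the idea, Gradients*Inputs multiplies the gradient element-wise by the input before producing a 2D map. The toy model below (class score `sum(W * tanh(x))`) is hypothetical and only chosen so the gradient has a closed form; summing over channels is likewise an assumption of this sketch rather than tf-explain's documented behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8, 3))  # hypothetical toy model weights

def score_gradient(x):
    # Gradient of the toy class score sum(W * tanh(x)) with respect to x
    return W * (1.0 - np.tanh(x) ** 2)

def gradients_inputs_map(x):
    # Element-wise product: gradients weighted by the input values
    attribution = score_gradient(x) * x
    # Collapse the color channels into a single 2D map
    return np.sum(attribution, axis=-1)

image = rng.uniform(size=(8, 8, 3))
attr = gradients_inputs_map(image)
```

Unlike vanilla gradients, pixels with a zero input value always get a zero attribution here, whatever their gradient.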

Occlusion Sensitivity

Visualize how parts of the image affect the neural network's confidence by iteratively occluding them.

from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Figure: Occlusion Sensitivity for the tabby class (stripes differentiate the tabby cat from other ImageNet cat classes).
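Occlusion sensitivity only needs forward passes. The following NumPy sketch slides a square patch over the image, fills it with a constant, and records how much the probability of `class_index` drops. `occlusion_sensitivity`, `toy_predict` and the zero `fill_value` are hypothetical stand-ins; tf-explain's own patch value and output format may differ.

```python
import numpy as np

def occlusion_sensitivity(predict, image, class_index, patch_size, fill_value=0.0):
    """Confidence drop for `class_index` when each patch is occluded.

    `predict` maps a batch of images (N, H, W, C) to probabilities (N, classes).
    """
    h, w, _ = image.shape
    reference = predict(image[None])[0, class_index]
    rows = int(np.ceil(h / patch_size))
    cols = int(np.ceil(w / patch_size))
    sensitivity = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            occluded = image.copy()
            occluded[i * patch_size:(i + 1) * patch_size,
                     j * patch_size:(j + 1) * patch_size, :] = fill_value
            score = predict(occluded[None])[0, class_index]
            # A large drop means the occluded patch mattered for this class
            sensitivity[i, j] = reference - score
    return sensitivity

def toy_predict(batch):
    # Hypothetical 2-class classifier: the class 0 logit grows with the
    # brightness of the top-left 4x4 corner; the class 1 logit is fixed at 0.
    logit0 = 10.0 * batch[:, :4, :4, :].mean(axis=(1, 2, 3))
    logits = np.stack([logit0, np.zeros_like(logit0)], axis=1)
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

image = np.ones((8, 8, 3))
sens = occlusion_sensitivity(toy_predict, image, class_index=0, patch_size=4)
```

Only the top-left patch changes the toy model's prediction, so it is the only cell of `sens` with a large value. With a real model, `predict` would be the Keras model's `predict` method, and `patch_size` plays the same role as in the callback above.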

Grad CAM

Visualize how parts of the image affect the neural network's output by looking into the activation maps.

From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
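The core Grad-CAM computation from the paper can be sketched in a few lines: each channel of a convolutional layer gets a weight equal to the spatial average of the class-score gradient, and the map is the ReLU of the weighted sum of the feature maps. This NumPy sketch takes the feature maps and their gradients as precomputed arrays (with Keras, both could be obtained inside `tf.GradientTape` from a model that outputs the chosen layer alongside the predictions). Upsampling to the image size and the overlay are left out, and tf-explain's own implementation may add refinements; `grad_cam` and the toy arrays are hypothetical.

```python
import numpy as np

def grad_cam(feature_maps, gradients):
    """Grad-CAM heatmap for one image and one convolutional layer.

    feature_maps: activations A of shape (H, W, K).
    gradients: d(class score)/dA, same shape as feature_maps.
    """
    # Channel weights: global-average-pool the gradients over space
    weights = gradients.mean(axis=(0, 1))  # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence
    cam = np.maximum(feature_maps @ weights, 0.0)  # shape (H, W)
    # Normalize to [0, 1] for display (guard against an all-zero map)
    return cam / cam.max() if cam.max() > 0 else cam

# Toy example: channel 0 fires in the top-left corner and supports the class,
# channel 1 fires in the bottom-right corner and works against it.
A = np.zeros((4, 4, 2))
A[:2, :2, 0] = 1.0
A[2:, 2:, 1] = 1.0
dY_dA = np.stack([np.full((4, 4), 0.5), np.full((4, 4), -0.5)], axis=-1)
heatmap = grad_cam(A, dY_dA)
```

Because of the ReLU, the bottom-right region, whose channel has a negative weight, is set to zero instead of appearing as negative evidence.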

SmoothGrad

Visualize stabilized gradients on the inputs towards the decision, obtained by averaging the gradients of several noisy copies of the input.

From SmoothGrad: removing noise by adding noise

from tf_explain.callbacks.smoothgrad import SmoothGradCallback

model = [...]

callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
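A minimal NumPy sketch of SmoothGrad: the input is perturbed `num_samples` times with Gaussian noise and the resulting gradients are averaged. It reuses a hypothetical toy model with a closed-form gradient, and it treats `noise` as the standard deviation of the perturbation; that reading of the callback's `noise` parameter is an assumption of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8, 3))  # hypothetical toy model weights

def score_gradient(x):
    # Gradient of the toy class score sum(W * tanh(x)) with respect to x
    return W * (1.0 - np.tanh(x) ** 2)

def smoothgrad(x, num_samples=20, noise=1.0, seed=0):
    sample_rng = np.random.default_rng(seed)
    total = np.zeros_like(x)
    for _ in range(num_samples):
        # Perturb the input with Gaussian noise (standard deviation `noise`)
        noisy = x + sample_rng.normal(scale=noise, size=x.shape)
        total += score_gradient(noisy)
    # Average the gradients over the noisy copies
    return total / num_samples

image = rng.uniform(size=(8, 8, 3))
smoothed = smoothgrad(image, num_samples=20, noise=1.0)
```

With `noise=0.0` the average collapses to the plain vanilla gradient, which is a quick way to sanity-check such an implementation.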

Integrated Gradients

Visualize an average of the gradients computed along a path of inputs interpolated between a baseline and the actual input, towards the decision.

From Axiomatic Attribution for Deep Networks

from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
    IntegratedGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        n_steps=20,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
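Integrated gradients scale the input's difference from a baseline by the average gradient along the straight path between them. This NumPy sketch approximates that path integral with a midpoint Riemann sum over `n_steps` points, assumes an all-zero (black) baseline, and uses a hypothetical toy model with a closed-form gradient; tf-explain's own baseline and step scheme are not described here and may differ. The last lines check the completeness property from the paper: the attributions should sum approximately to f(input) - f(baseline).

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8, 3))  # hypothetical toy model weights

def class_score(x):
    return float(np.sum(W * np.tanh(x)))

def score_gradient(x):
    # Gradient of the toy class score with respect to x
    return W * (1.0 - np.tanh(x) ** 2)

def integrated_gradients(x, n_steps=20, baseline=None):
    if baseline is None:
        baseline = np.zeros_like(x)  # black image as the reference input
    # Midpoint Riemann sum: interpolation coefficients strictly inside (0, 1)
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    avg_grad = np.zeros_like(x)
    for alpha in alphas:
        avg_grad += score_gradient(baseline + alpha * (x - baseline))
    avg_grad /= n_steps
    # Scale the averaged gradients by the difference from the baseline
    return (x - baseline) * avg_grad

image = rng.uniform(size=(8, 8, 3))
attributions = integrated_gradients(image, n_steps=20)

# Completeness check: sum of attributions vs. f(input) - f(baseline)
gap = attributions.sum() - (class_score(image) - class_score(np.zeros_like(image)))
```

Increasing `n_steps` makes the approximation tighter at the cost of more gradient evaluations, which is the trade-off the callback's `n_steps` parameter controls.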

Roadmap

Contributing

To contribute to the project, please read the dedicated section.

Citation

A citation file is available for citing this work.


Download files

Source Distribution

tf-explain-0.3.1.tar.gz (24.7 kB)

Built Distribution

tf_explain-0.3.1-py3-none-any.whl (43.6 kB)

File details

Details for the file tf-explain-0.3.1.tar.gz.

File metadata

  • Download URL: tf-explain-0.3.1.tar.gz
  • Size: 24.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.8

File hashes

Hashes for tf-explain-0.3.1.tar.gz
Algorithm Hash digest
SHA256 508b53afc1fca3ab26d8e5d92d720a5ce3c1d37cf4479580b2ea7b9f4b7d7ee9
MD5 46be21785336c6576522a9f48888034a
BLAKE2b-256 1dadd9467ccd256ff5764f08bdd76da859539b35f01a9059fc8c2e412e6b912a

File details

Details for the file tf_explain-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: tf_explain-0.3.1-py3-none-any.whl
  • Size: 43.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.8

File hashes

Hashes for tf_explain-0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 91586bfce130a3e2ff6956f86fcda05d034e6eec466ce0adfc6ac5f9d52fb3ad
MD5 a4b7441caf4840e6bdbec6054467c123
BLAKE2b-256 43754611078380c5f5933b71378be6733e43f10667b4ca15de6e09e12a2d8025
