Interpretability Callbacks for TensorFlow 2.0
Project description
tf-explain
tf-explain implements interpretability methods as TensorFlow 2.0 callbacks to ease the understanding of neural networks.
See Introducing tf-explain, Interpretability for Tensorflow 2.0
Documentation: https://tf-explain.readthedocs.io
Installation
tf-explain is available on PyPI as an alpha release. To install it:
```bash
virtualenv venv -p python3.6
pip install tf-explain
```
tf-explain is compatible with TensorFlow 2. TensorFlow is not declared as a dependency so that you can choose between the CPU and GPU versions. In addition to the previous install, run:
```bash
# For CPU version
pip install tensorflow==2.0.0
# For GPU version
pip install tensorflow-gpu==2.0.0
```
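A quick sanity check that both packages import correctly (a minimal sketch; the `tf_explain.__version__` attribute is an assumption):

```python
# Sanity check after installation
import tensorflow as tf
import tf_explain

print(tf.__version__)          # expected: 2.0.0
print(tf_explain.__version__)  # assumed attribute; expected: 0.2.0
```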
Available Methods
- Activations Visualization
- Vanilla Gradients
- Gradients*Inputs
- Occlusion Sensitivity
- Grad CAM (Class Activation Maps)
- SmoothGrad
- Integrated Gradients
Activations Visualization
Visualize how a given input comes out of a specific activation layer
```python
from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
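The `layers_name` argument must match layer names in your model. A quick way to list the candidates (standard Keras API, shown here as a hint):

```python
# Print layer names to pick the one to pass as layers_name
for layer in model.layers:
    print(layer.name)
```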
Vanilla Gradients
Visualize the gradients' importance on the input image
```python
from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback

model = [...]

callbacks = [
    VanillaGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
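Conceptually, Vanilla Gradients is a single backward pass of the target class score with respect to the input pixels. A minimal sketch with `tf.GradientTape` (illustrative, not tf-explain's internal implementation):

```python
import tensorflow as tf

def vanilla_gradients(model, images, class_index):
    """Gradient of the class score with respect to the input pixels."""
    inputs = tf.cast(images, tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(inputs)
        scores = model(inputs)[:, class_index]
    return tape.gradient(scores, inputs)
```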
Gradients*Inputs
Variant of Vanilla Gradients that weights the gradients by the input values
```python
from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback

model = [...]

callbacks = [
    GradientsInputsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
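The weighting is simply an element-wise product of the input with its gradients; an illustrative sketch (not the library's code):

```python
import tensorflow as tf

def gradients_inputs(model, images, class_index):
    """Element-wise product of the input with its class-score gradients."""
    inputs = tf.cast(images, tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(inputs)
        scores = model(inputs)[:, class_index]
    return inputs * tape.gradient(scores, inputs)
```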
Occlusion Sensitivity
Visualize how parts of the image affect the neural network's confidence by occluding them iteratively
```python
from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
Occlusion Sensitivity for the Tabby class (stripes differentiate the tabby cat from other ImageNet cat classes)
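Under the hood, the method slides an occluding patch over the image and records how the target class confidence changes; an illustrative sketch (not the library's implementation):

```python
import numpy as np

def occlusion_map(model, image, class_index, patch_size=4, baseline=0.0):
    """Record the target class confidence for each occluded patch position."""
    h, w, _ = image.shape
    heatmap = np.zeros((h // patch_size, w // patch_size))
    for i, top in enumerate(range(0, h - patch_size + 1, patch_size)):
        for j, left in enumerate(range(0, w - patch_size + 1, patch_size)):
            occluded = image.copy()
            occluded[top:top + patch_size, left:left + patch_size, :] = baseline
            confidence = model.predict(occluded[np.newaxis, ...])[0][class_index]
            heatmap[i, j] = confidence  # low confidence = important region
    return heatmap
```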
Grad CAM
Visualize how parts of the image affect the neural network's output by looking into the activation maps
From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
```python
from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
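For a complete, runnable example, the sketch below trains a small CNN on MNIST with the Grad CAM callback attached (the model and hyperparameters are illustrative):

```python
import numpy as np
import tensorflow as tf
from tf_explain.callbacks.grad_cam import GradCAMCallback

# Load and normalize MNIST; add a channel dimension for the conv layers
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0
x_val = x_val[..., np.newaxis] / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,      # explain predictions for digit 0
        output_dir="logs",  # heatmaps are written here for TensorBoard
    )
]

model.fit(x_train, y_train, batch_size=32, epochs=2, callbacks=callbacks)
```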
SmoothGrad
Visualize stabilized gradients on the inputs towards the decision
From SmoothGrad: removing noise by adding noise
```python
from tf_explain.callbacks.smoothgrad import SmoothGradCallback

model = [...]

callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
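SmoothGrad averages vanilla gradients over several noisy copies of the input, which is what the `num_samples` and `noise` arguments control; an illustrative sketch (not the library's code):

```python
import tensorflow as tf

def smoothgrad(model, images, class_index, num_samples=20, noise=1.0):
    """Average vanilla gradients over Gaussian-perturbed copies of the input."""
    inputs = tf.cast(images, tf.float32)
    total = tf.zeros_like(inputs)
    for _ in range(num_samples):
        noisy = inputs + tf.random.normal(tf.shape(inputs), stddev=noise)
        with tf.GradientTape() as tape:
            tape.watch(noisy)
            scores = model(noisy)[:, class_index]
        total += tape.gradient(scores, noisy)
    return total / num_samples
```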
Integrated Gradients
Visualize an average of the gradients accumulated while the input is gradually constructed from a baseline towards the decision
From Axiomatic Attribution for Deep Networks
```python
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
    IntegratedGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        n_steps=20,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```
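Integrated Gradients accumulates gradients at `n_steps` points along a straight path from a baseline (a black image in the paper) to the actual input; an illustrative sketch (not the library's code):

```python
import tensorflow as tf

def integrated_gradients(model, images, class_index, n_steps=20, baseline=None):
    """Average gradients along a straight path from a baseline to the input."""
    inputs = tf.cast(images, tf.float32)
    if baseline is None:
        baseline = tf.zeros_like(inputs)  # black image baseline, as in the paper
    total = tf.zeros_like(inputs)
    for step in range(1, n_steps + 1):
        interpolated = baseline + (inputs - baseline) * step / n_steps
        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            scores = model(interpolated)[:, class_index]
        total += tape.gradient(scores, interpolated)
    # Scale by the input difference (Riemann approximation of the path integral)
    return (inputs - baseline) * total / n_steps
```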
Visualizing the results
When you use the callbacks, the output files are created in the `logs` directory. You can view them in TensorBoard with the following command:

```bash
tensorboard --logdir logs
```
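To keep runs separate in TensorBoard, one common pattern is a timestamped subdirectory per run (the naming scheme is just a suggestion):

```python
import datetime

# One subdirectory per training run keeps TensorBoard tidy
output_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
```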
Roadmap
- Subclassing API Support
- Additional Methods
- Auto-generated API Documentation & Documentation Testing
Contributing
To contribute to the project, please read the dedicated section.
Project details
Download files
Download the file for your platform.
Source Distribution: tf-explain-0.2.0.tar.gz (19.7 kB)

Built Distribution: tf_explain-0.2.0-py3-none-any.whl (41.4 kB)
File details
Details for the file tf-explain-0.2.0.tar.gz.
File metadata
- Download URL: tf-explain-0.2.0.tar.gz
- Upload date:
- Size: 19.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.7.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | f85e25082684e049757569eb31d77f95c14f2762ca21f7ae7f4dfd69d996d264
MD5 | 814aaa70bb92c2b3db8f18f13375785e
BLAKE2b-256 | 67b80ec0f308457bda9133d91a745385d0b3000e8fff6b1e4d22d9174f31e4dd
File details
Details for the file tf_explain-0.2.0-py3-none-any.whl.
File metadata
- Download URL: tf_explain-0.2.0-py3-none-any.whl
- Upload date:
- Size: 41.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.7.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2e2efeae1b4da17bc9028c28d81089a5212f630c8a89b29b2ca5483b6b08e60f
MD5 | c228dd3861a7d8004b2791d7047517e3
BLAKE2b-256 | 20ce993daad6523c5dfe250e55746eb137677b66ac085670180d9fa17783170f