
Interpretability Callbacks for TensorFlow 2.0

Project description

tf-explain


tf-explain implements interpretability methods as TensorFlow 2.0 callbacks to ease the understanding of neural networks.

Installation

tf-explain is available on PyPI as an alpha release. To install it:

virtualenv venv -p python3.6
source venv/bin/activate
pip install tf-explain

tf-explain is compatible with TensorFlow 2. TensorFlow is not declared as a dependency so that you can choose between the CPU and GPU versions. In addition to the previous install, run:

# For CPU version
pip install tensorflow==2.0.0-beta1
# For GPU version
pip install tensorflow-gpu==2.0.0-beta1
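
A quick optional sanity check is to import both packages; a minimal snippet, assuming TensorFlow was installed as above:

# Optional sanity check: both packages should import without error,
# and the reported TensorFlow version should start with "2.".
import tensorflow as tf
import tf_explain

print(tf.__version__)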

Available Methods

  1. Activations Visualization
  2. Occlusion Sensitivity
  3. Grad CAM (Class Activation Maps)

Activations Visualization

Visualize how a given input comes out of a specific activation layer

from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
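
The snippet above leaves the model and data as placeholders. Below is a hypothetical minimal setup they could stand for: a small Keras CNN whose ReLU layer is explicitly named "activation_1" so the callback can find it, plus dummy data and an output directory. None of these names come from tf-explain itself.

# Hypothetical stand-in for `model = [...]` and the surrounding variables.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, input_shape=(28, 28, 1)),
    tf.keras.layers.Activation("relu", name="activation_1"),  # target layer
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Dummy data with the expected image shape; substitute a real dataset.
x_train = np.random.random((8, 28, 28, 1)).astype("float32")
y_train = np.random.randint(0, 10, size=(8,))
x_val, y_val = x_train[:2], y_train[:2]
output_dir = "./logs"  # where the callback writes its outputs

With these definitions in place, the model.fit call above runs end to end.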

Occlusion Sensitivity

Visualize how parts of the image affect the network's confidence by occluding regions iteratively

from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        patch_size=4,
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
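
For intuition, here is a schematic sketch of the occlusion-sensitivity idea, not tf-explain's internal implementation: slide a grey patch over the image and record how much the confidence for class_index drops when each region is hidden. model, image (a single preprocessed array) and class_index are assumed to be defined as in the previous example.

# Schematic occlusion-sensitivity sketch (illustration only).
import numpy as np

def occlusion_map(model, image, class_index, patch_size=4):
    h, w, _ = image.shape
    heatmap = np.zeros((int(np.ceil(h / patch_size)), int(np.ceil(w / patch_size))))
    for i in range(0, h, patch_size):
        for j in range(0, w, patch_size):
            occluded = image.copy()
            occluded[i:i + patch_size, j:j + patch_size, :] = 0.5  # grey patch
            confidence = model.predict(occluded[np.newaxis, ...])[0][class_index]
            # A large drop in confidence means the hidden region mattered.
            heatmap[i // patch_size, j // patch_size] = 1.0 - confidence
    return heatmap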

Occlusion Sensitivity for Tabby class (stripes differentiate tabby cat from other ImageNet cat classes)

Grad CAM

Visualize how parts of the image affect the network's output by looking at the activation maps

From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        layer_name="activation_1",
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
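
As a rough illustration of what Grad-CAM computes (following the paper above, not tf-explain's internal code): take the gradient of the class score with respect to the chosen activation maps, average it spatially to weight each map, and keep the positive part of the weighted sum. model, layer_name, class_index and a single preprocessed image are assumed to be defined as before.

# Schematic Grad-CAM sketch (illustration only).
import numpy as np
import tensorflow as tf

def grad_cam(model, image, layer_name, class_index):
    # Auxiliary model exposing both the target layer's output and the predictions.
    grad_model = tf.keras.Model(
        model.inputs, [model.get_layer(layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_output, predictions = grad_model(image[np.newaxis, ...])
        class_score = predictions[:, class_index]
    grads = tape.gradient(class_score, conv_output)
    # Spatially averaged gradients act as weights for each feature map.
    weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.reduce_sum(conv_output * weights[:, tf.newaxis, tf.newaxis, :], axis=-1)
    return tf.nn.relu(cam)[0].numpy()  # keep only positive contributions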

Roadmap

Upcoming features are listed as issues with the roadmap label.

Download files

Download the file for your platform.

Source Distribution

tf-explain-0.0.1a0.tar.gz (10.3 kB)


Built Distribution

tf_explain-0.0.1a0-py3-none-any.whl (19.6 kB)


File details

Details for the file tf-explain-0.0.1a0.tar.gz.

File metadata

  • Download URL: tf-explain-0.0.1a0.tar.gz
  • Upload date:
  • Size: 10.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for tf-explain-0.0.1a0.tar.gz

  • SHA256: 15a46bfa5ac15fc7670238bd8d5aeadbe99822215610fd87a0b99e8022262307
  • MD5: 66acba3cbf9f6095fac1294df5ae371d
  • BLAKE2b-256: 0a6f7753104e8db45355ae91fa1dd764e6c479590c8c989101d652e6a71e0f93



File details

Details for the file tf_explain-0.0.1a0-py3-none-any.whl.

File metadata

  • Download URL: tf_explain-0.0.1a0-py3-none-any.whl
  • Upload date:
  • Size: 19.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for tf_explain-0.0.1a0-py3-none-any.whl

  • SHA256: e8d86ee2f70d09aa8660d079c715ab7fd3d1808e6fb66f2b81443db719285607
  • MD5: 7161612ce1595138fbf92aafd3869dd8
  • BLAKE2b-256: 268f3005fc0a7dc1ffbe29ef794e09cf86e60158142c955f9ee9cdc4193a6aeb


