Skip to main content

Framework-agnostic saliency methods

Project description

Saliency Library

Updates

🔴   Now framework-agnostic! (Example core notebook)  🔴

🔗   For further explanation of the methods and more examples of the resulting maps, see our Github Pages website  🔗

If upgrading from an older version, update old imports to import saliency.tf1 as saliency. We provide wrappers to make the framework-agnostic version compatible with TF1 models. (Example TF1 notebook)

🔴   Added Performance Information Curve (PIC) - a human independent metric for evaluating the quality of saliency methods. (Example notebook)  🔴

Saliency Methods

This repository contains code for the following saliency techniques:

*Developed by PAIR.

This list is by no means comprehensive. We are accepting pull requests to add new methods!

Evaluation of Saliency Methods

The repository provides an implementation of Performance Information Curve (PIC) - a human independent metric for evaluating the quality of saliency methods (paper, poster, code, notebook).

Download

# To install the core subpackage:
pip install saliency

# To install core and tf1 subpackages:
pip install saliency[tf1]

or for the development version:

git clone https://github.com/pair-code/saliency
cd saliency

Usage

The saliency library has two subpackages:

  • core uses a generic call_model_function which can be used with any ML framework.
  • tf1 accepts input/output tensors directly, and sets up the necessary graph operations for each method.

Core

Each saliency mask class extends from the CoreSaliency base class. This class contains the following methods:

  • GetMask(x_value, call_model_function, call_model_args=None): Returns a mask of the shape of non-batched x_value given by the saliency technique.
  • GetSmoothedMask(x_value, call_model_function, call_model_args=None, stdev_spread=.15, nsamples=25, magnitude=True): Returns a mask smoothed of the shape of non-batched x_value with the SmoothGrad technique.

The visualization module contains two methods for saliency visualization:

  • VisualizeImageGrayscale(image_3d, percentile): Marginalizes across the absolute value of each channel to create a 2D single channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between 0 to 1.
  • VisualizeImageDiverging(image_3d, percentile): Marginalizes across the value of each channel to create a 2D single channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between -1 to 1 where zero remains unchanged.

If the sign of the value given by the saliency mask is not important, then use VisualizeImageGrayscale, otherwise use VisualizeImageDiverging. See the SmoothGrad paper for more details on which visualization method to use.

call_model_function

call_model_function is how we pass inputs to a given model and receive the outputs necessary to compute saliency masks. The description of this method and expected output format is in the CoreSaliency description, as well as separately for each method.

Examples

This example iPython notebook showing these techniques is a good starting place.

Here is a condensed example of using IG+SmoothGrad with TensorFlow 2:

import saliency.core as saliency
import tensorflow as tf

...

# call_model_function construction here.
def call_model_function(x_value_batched, call_model_args, expected_keys):
	tape = tf.GradientTape()
	grads = np.array(tape.gradient(output_layer, images))
	return {saliency.INPUT_OUTPUT_GRADIENTS: grads}

...

# Load data.
image = GetImagePNG(...)

# Compute IG+SmoothGrad.
ig_saliency = saliency.IntegratedGradients()
smoothgrad_ig = ig_saliency.GetSmoothedMask(image, 
											call_model_function, 
                                            call_model_args=None)

# Compute a 2D tensor for visualization.
grayscale_visualization = saliency.VisualizeImageGrayscale(
    smoothgrad_ig)

TF1

Each saliency mask class extends from the TF1Saliency base class. This class contains the following methods:

  • __init__(graph, session, y, x): Constructor of the SaliencyMask. This can modify the graph, or sometimes create a new graph. Often this will add nodes to the graph, so this shouldn't be called continuously. y is the output tensor to compute saliency masks with respect to, x is the input tensor with the outer most dimension being batch size.
  • GetMask(x_value, feed_dict): Returns a mask of the shape of non-batched x_value given by the saliency technique.
  • GetSmoothedMask(x_value, feed_dict): Returns a mask smoothed of the shape of non-batched x_value with the SmoothGrad technique.

The visualization module contains two visualization methods:

  • VisualizeImageGrayscale(image_3d, percentile): Marginalizes across the absolute value of each channel to create a 2D single channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between 0 to 1.
  • VisualizeImageDiverging(image_3d, percentile): Marginalizes across the value of each channel to create a 2D single channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between -1 to 1 where zero remains unchanged.

If the sign of the value given by the saliency mask is not important, then use VisualizeImageGrayscale, otherwise use VisualizeImageDiverging. See the SmoothGrad paper for more details on which visualization method to use.

Examples

This example iPython notebook shows these techniques is a good starting place.

Another example of using GuidedBackprop with SmoothGrad from TensorFlow:

from saliency.tf1 import GuidedBackprop
from saliency.tf1 import VisualizeImageGrayscale
import tensorflow.compat.v1 as tf

...
# Tensorflow graph construction here.
y = logits[5]
x = tf.placeholder(...)
...

# Compute guided backprop.
# NOTE: This creates another graph that gets cached, try to avoid creating many
# of these.
guided_backprop_saliency = GuidedBackprop(graph, session, y, x)

...
# Load data.
image = GetImagePNG(...)
...

smoothgrad_guided_backprop =
    guided_backprop_saliency.GetMask(image, feed_dict={...})

# Compute a 2D tensor for visualization.
grayscale_visualization = visualization.VisualizeImageGrayscale(
    smoothgrad_guided_backprop)

Conclusion/Disclaimer

If you have any questions or suggestions for improvements to this library, please contact the owners of the PAIR-code/saliency repository.

This is not an official Google product.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

saliency-0.2.0.tar.gz (52.4 kB view details)

Uploaded Source

Built Distribution

saliency-0.2.0-py2.py3-none-any.whl (86.2 kB view details)

Uploaded Python 2 Python 3

File details

Details for the file saliency-0.2.0.tar.gz.

File metadata

  • Download URL: saliency-0.2.0.tar.gz
  • Upload date:
  • Size: 52.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.7.13

File hashes

Hashes for saliency-0.2.0.tar.gz
Algorithm Hash digest
SHA256 af724bd8eeae15353a166b55daa6865fb1455a3b47bb92d04fc6b7b98e8b919c
MD5 dbab8783b4abf1d73ff5b37f352b0b8c
BLAKE2b-256 a0e0c842add0207e21b29bfd1e508f947ac20a4f6adee1070ef044146fce84f5

See more details on using hashes here.

File details

Details for the file saliency-0.2.0-py2.py3-none-any.whl.

File metadata

  • Download URL: saliency-0.2.0-py2.py3-none-any.whl
  • Upload date:
  • Size: 86.2 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.7.13

File hashes

Hashes for saliency-0.2.0-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 eea22bc7b7ba1d93307a620ba29b3f1d21340c7def1adeda9fbae399e184afab
MD5 c6f356b353a8d5e596066e5327ca5d7a
BLAKE2b-256 a4d3da39e3615293e3dc141e45ce1c0f01e4e1c71d32d1a73619e3f1576e17a6

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page