
Distrax: Probability distributions in JAX.


Distrax


Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP), with some new features and emphasis on extensibility.

Design Principles

The general design principles for the DeepMind JAX Ecosystem are addressed in a DeepMind blog post. Additionally, Distrax places emphasis on the following:

  1. Readability. Distrax implementations are intended to be self-contained and read as close to the underlying math as possible.
  2. Extensibility. We have made it as simple as possible for users to define their own distribution or bijector. This is useful, for example, in reinforcement learning, where users may wish to define custom behavior for probabilistic agent policies.
  3. Compatibility. Distrax is not intended as a replacement for TFP, and TFP contains many advanced features that we do not intend to replicate. To this end, we have made the APIs for distributions and bijectors as cross-compatible as possible, and provide utilities for transforming between equivalent Distrax and TFP classes.

Features

Distributions

Distributions in Distrax are simple to define and use, particularly if you're used to TFP. Let's compare the two side-by-side:

import distrax
import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key)

# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))

In addition to behaving consistently, Distrax distributions and TFP distributions are cross-compatible. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)

# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))
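For diagonal Gaussians the KL divergence has a closed form: the sum over dimensions of the univariate Gaussian KL. The plain-NumPy sketch below is an illustration of that formula, not Distrax's own implementation. It reproduces the value printed above:

```python
import numpy as np

def kl_mvn_diag(mu_0, sigma_0, mu_1, sigma_1):
    """KL(N(mu_0, diag(sigma_0**2)) || N(mu_1, diag(sigma_1**2)))."""
    mu_0, sigma_0 = np.asarray(mu_0), np.asarray(sigma_0)
    mu_1, sigma_1 = np.asarray(mu_1), np.asarray(sigma_1)
    # Per-dimension KL between univariate Gaussians...
    per_dim = (np.log(sigma_1 / sigma_0)
               + (sigma_0**2 + (mu_0 - mu_1)**2) / (2 * sigma_1**2)
               - 0.5)
    # ...summed over the event dimension.
    return per_dim.sum(axis=-1)

kl = kl_mvn_diag([-1., 0., 1.], [0.1, 0.2, 0.3],
                 [1., 2., 3.], [0.2, 0.3, 0.4])
print(round(float(kl), 3))  # 85.237
```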

Distrax distributions implement the method sample_and_log_prob, which returns samples and their log-probability in a single call. For some distributions, this is more efficient than calling sample and log_prob separately:

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)

# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())
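One way a joint method can save work is by reusing intermediate values. For a reparameterized Gaussian sample x = mu + sigma * eps, the standardized value (x - mu) / sigma is just eps, so the log-density does not have to recompute it. The NumPy sketch below illustrates the idea. The function names are hypothetical and are not Distrax internals:

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)

def normal_log_prob(x, mu, sigma):
    # Standalone log-density: must recompute z = (x - mu) / sigma.
    z = (x - mu) / sigma
    return np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * LOG_2PI, axis=-1)

def normal_sample_and_log_prob(rng, mu, sigma):
    # Reparameterized sample: x = mu + sigma * eps, with eps ~ N(0, I).
    eps = rng.standard_normal(np.shape(mu))
    x = mu + sigma * eps
    # Reuse eps, which equals (x - mu) / sigma, instead of recomputing it.
    log_prob = np.sum(-0.5 * eps**2 - np.log(sigma) - 0.5 * LOG_2PI, axis=-1)
    return x, log_prob

rng = np.random.default_rng(0)
mu = np.array([-1., 0., 1.])
sigma = np.array([0.1, 0.2, 0.3])
x, lp = normal_sample_and_log_prob(rng, mu, sigma)
```

Both paths give the same log-probability. The joint version simply skips one subtraction and one division per element.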

TFP distributions can be passed to Distrax meta-distributions as inputs. For example:

key = jax.random.PRNGKey(1234)

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.Normal(mu, sigma)

metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)
samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples))  # Prints 0.38871175
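Independent reinterprets batch dimensions as event dimensions. Its log_prob is therefore the sum of the component log-probabilities over those dimensions. For three independent Normals this matches a diagonal multivariate Normal. The NumPy sketch below shows that bookkeeping for the example above. The evaluation point x is arbitrary and chosen only for illustration:

```python
import numpy as np

def normal_log_prob(x, mu, sigma):
    # Element-wise log-density of independent univariate Normals.
    return -0.5 * ((x - mu) / sigma)**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

mu = np.array([-1., 0., 1.])
sigma = np.array([0.2, 0.3, 0.4])
x = np.array([-0.9, 0.1, 1.2])  # Arbitrary point, for illustration only.

batch_lp = normal_log_prob(x, mu, sigma)  # Shape (3,): one value per Normal.
event_lp = batch_lp.sum(axis=-1)          # Shape (): reinterpreted_batch_ndims=1.
```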

To use Distrax distributions in TFP meta-distributions, Distrax provides the wrapper to_tfp. A wrapped Distrax distribution can be directly used in TFP:

key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)
metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])

samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples))  # Prints -3.3409896

Bijectors

A "bijector" in Distrax is an invertible function that knows how to compute its Jacobian determinant. Bijectors can be used to create complex distributions by transforming simpler ones. Distrax bijectors are functionally similar to TFP bijectors, with a few API differences. Here is an example comparing the two:

import distrax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
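Both objects evaluate densities with the change-of-variables formula log p_Y(y) = log p_X(x) - log|dy/dx|, where x = atanh(y). The NumPy sketch below writes this out for the Tanh-transformed standard Normal above. As a sanity check, it confirms numerically that the resulting density integrates to about 1 on (-1, 1). The stable form of log(1 - tanh(x)^2) is a standard identity, not a quote of either library's code:

```python
import numpy as np

def tanh_forward_log_det(x):
    # log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in a numerically stable form.
    return 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))

def standard_normal_log_prob(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

def transformed_log_prob(y):
    # Change of variables: log p_Y(y) = log p_X(x) - log|dy/dx|, with x = atanh(y).
    x = np.arctanh(y)
    return standard_normal_log_prob(x) - tanh_forward_log_det(x)

# Trapezoidal integration of the density over (almost all of) (-1, 1).
y = np.linspace(-1 + 1e-6, 1 - 1e-6, 200_001)
density = np.exp(transformed_log_prob(y))
total = np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(y))
print(total)  # Close to 1.0.
```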

Additionally, Distrax bijectors can be composed and inverted:

bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()

# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)

# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])
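A chain applies its bijectors from last to first, so Chain([f, g]) computes f(g(x)). Its log-determinant is the sum of the log-determinants of the pieces. For a bijector chained with its own inverse, the mapping is the identity and the log-determinants cancel. The NumPy sketch below checks this for the Tanh example. It works on inputs in (-1, 1), because the inverse (atanh) is applied first:

```python
import numpy as np

def tanh_fldj(x):
    # log|d tanh(x)/dx| = log(1 - tanh(x)^2), in a numerically stable form.
    return 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))

def atanh_fldj(y):
    # log|d atanh(y)/dy| = -log(1 - y^2), valid for -1 < y < 1.
    return -np.log1p(-y**2)

x = np.linspace(-0.9, 0.9, 7)

# In Chain([Tanh, Inverse(Tanh)]), the inverse (atanh) runs first, then tanh.
u = np.arctanh(x)
z = np.tanh(u)

# Log-determinants of chained bijectors add up; here they cancel.
total_log_det = atanh_fldj(x) + tanh_fldj(u)
```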

All TFP bijectors can be passed to Distrax, and can be freely composed with Distrax bijectors. For example, all of the following will work:

distrax.Inverse(tfb.Tanh())

distrax.Chain([tfb.Tanh(), distrax.Tanh()])

distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())

Distrax bijectors can also be passed to TFP, but first they must be wrapped with to_tfp:

bij_distrax = distrax.to_tfp(distrax.Tanh())

tfb.Invert(bij_distrax)

tfb.Chain([tfb.Tanh(), bij_distrax])

tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)

Distrax also comes with Lambda, a convenient wrapper for turning simple JAX functions into bijectors. Here are a few Lambda examples with their TFP equivalents:

distrax.Lambda(lambda x: x)
# tfb.Identity()

distrax.Lambda(lambda x: 2*x + 3)
# tfb.Chain([tfb.Shift(3), tfb.Scale(2)])

distrax.Lambda(jnp.sinh)
# tfb.Sinh()

distrax.Lambda(lambda x: jnp.sinh(2*x + 3))
# tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])
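For a scalar bijector, the forward log-determinant is log|f'(x)|. A Lambda and its TFP chain should therefore agree on it. As a numerical check of the last pairing above, the NumPy sketch below adds up the log-determinants of the chain's pieces (Scale, then Shift, then Sinh). It compares the sum with log|f'(x)| for f(x) = sinh(2x + 3), estimated by central finite differences. Finite differences are used here only as an independent check; this is not how Distrax computes a Lambda's log-determinant:

```python
import numpy as np

def f(x):
    # The function wrapped by distrax.Lambda in the last example.
    return np.sinh(2 * x + 3)

def chain_fldj(x):
    # Chain([Sinh(), Shift(3), Scale(2)]) applies Scale, then Shift, then Sinh.
    # Scale(2) contributes log(2), Shift(3) contributes 0,
    # and Sinh at u = 2x + 3 contributes log(cosh(u)).
    u = 2 * x + 3
    return np.log(2.0) + 0.0 + np.log(np.cosh(u))

x = np.linspace(-2.0, 0.0, 5)
h = 1e-6
# Central finite-difference estimate of log|f'(x)|.
fd_log_det = np.log(np.abs((f(x + h) - f(x - h)) / (2 * h)))
```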

Unlike TFP, bijectors in Distrax do not take event_ndims as an argument when they compute the Jacobian determinant. Instead, Distrax assumes that the number of event dimensions is statically known to every bijector, and uses Block to lift bijectors to a different number of dimensions. For example:

x = jnp.zeros([2, 3, 4])

# In TFP, `event_ndims` can be passed to the bijector.
bij_tfp = tfb.Tanh()
ld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0)  # Shape = [2, 3, 4]

# Distrax assumes `Tanh` is a scalar bijector by default.
bij_distrax = distrax.Tanh()
ld_2 = bij_distrax.forward_log_det_jacobian(x)  # ld_1 == ld_2

# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.
ld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2)  # Shape = [2]

# Distrax treats the number of dimensions statically.
bij_distrax = distrax.Block(bij_distrax, ndims=2)
ld_4 = bij_distrax.forward_log_det_jacobian(x)  # ld_3 == ld_4
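Lifting a scalar bijector to ndims event dimensions amounts to summing its element-wise log-determinants over the trailing ndims axes. This is what TFP does with event_ndims=2 and what Block(..., ndims=2) does above. In the example above x is all zeros, so every log-determinant is 0. The NumPy sketch below uses random inputs so that the summation is visible. The helper block_fldj is hypothetical:

```python
import numpy as np

def tanh_fldj(x):
    # Element-wise log|d tanh(x)/dx| for a scalar (event_ndims=0) bijector.
    return 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))

def block_fldj(elementwise_fldj, x, ndims):
    # Hypothetical helper: sum element-wise log-dets over the last `ndims` axes.
    ld = elementwise_fldj(x)
    if ndims == 0:
        return ld
    return ld.sum(axis=tuple(range(-ndims, 0)))

x = np.random.default_rng(0).normal(size=(2, 3, 4))
ld_scalar = block_fldj(tanh_fldj, x, ndims=0)  # Shape (2, 3, 4), like ld_1 and ld_2.
ld_block = block_fldj(tanh_fldj, x, ndims=2)   # Shape (2,), like ld_3 and ld_4.
```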

Distrax bijectors implement the method forward_and_log_det (some bijectors additionally implement inverse_and_log_det), which returns the forward mapping and its log Jacobian determinant in a single call. For some bijectors, this is more efficient than calling forward and forward_log_det_jacobian separately. (Analogously, when available, inverse_and_log_det can be more efficient than calling inverse and inverse_log_det_jacobian separately.)

x = jnp.zeros([2, 3, 4])
bij_distrax = distrax.Tanh()

y = bij_distrax.forward(x)
ld = bij_distrax.forward_log_det_jacobian(x)

# A one-line equivalent of the above is:
y, ld = bij_distrax.forward_and_log_det(x)

Jitting Distrax

Distrax distributions and bijectors can be passed as arguments to jitted functions. User-defined distributions and bijectors get this property for free by subclassing distrax.Distribution and distrax.Bijector respectively. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)

jitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))

# Both print 85.237
print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))

Subclassing Distributions and Bijectors

User-defined distributions can be created by subclassing distrax.Distribution. This can be achieved by implementing only a few methods:

class MyDistribution(distrax.Distribution):

  def __init__(self, ...):  # Constructor parameters are up to you.
    ...

  def _sample_n(self, key, n):
    # Draw `n` samples using the JAX PRNG key `key`.
    samples = ...
    return samples

  def log_prob(self, value):
    log_prob = ...
    return log_prob

  @property
  def event_shape(self):
    event_shape = ...
    return event_shape

  def _sample_n_and_log_prob(self, key, n):
    # Optional. Implement only when a more efficient joint computation is possible.
    samples, log_prob = ...
    return samples, log_prob

Similarly, more complicated bijectors can be created by subclassing distrax.Bijector. This can be achieved by implementing only one or two methods:

class MyBijector(distrax.Bijector):

  def __init__(self, ...):
    super().__init__(...)

  def forward_and_log_det(self, x):
    y = ...
    logdet = ...
    return y, logdet

  def inverse_and_log_det(self, y):
    # Optional. Can be omitted if inverse methods are not needed.
    x = ...
    logdet = ...
    return x, logdet
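Implementing forward_and_log_det alone can be enough because forward and forward_log_det_jacobian can be derived from it, and likewise for the inverse pair. The NumPy sketch below illustrates that pattern with a hypothetical MiniBijector base class and a hypothetical AffineScalar subclass. Neither is part of Distrax, and the sketch assumes, rather than quotes, how distrax.Bijector is organized internally:

```python
import numpy as np

class MiniBijector:
    """Hypothetical base class: derives the single-output methods from the joint ones."""

    def forward_and_log_det(self, x):
        raise NotImplementedError

    def inverse_and_log_det(self, y):
        raise NotImplementedError

    def forward(self, x):
        return self.forward_and_log_det(x)[0]

    def forward_log_det_jacobian(self, x):
        return self.forward_and_log_det(x)[1]

    def inverse(self, y):
        return self.inverse_and_log_det(y)[0]

    def inverse_log_det_jacobian(self, y):
        return self.inverse_and_log_det(y)[1]


class AffineScalar(MiniBijector):
    """Hypothetical element-wise bijector y = scale * x + shift."""

    def __init__(self, shift, scale):
        self.shift = shift
        self.scale = scale

    def forward_and_log_det(self, x):
        y = self.scale * x + self.shift
        logdet = np.full(np.shape(x), np.log(np.abs(self.scale)))
        return y, logdet

    def inverse_and_log_det(self, y):
        x = (y - self.shift) / self.scale
        logdet = np.full(np.shape(y), -np.log(np.abs(self.scale)))
        return x, logdet


bij = AffineScalar(shift=3.0, scale=2.0)
x = np.array([0.0, 1.0, -1.5])
y = bij.forward(x)  # [3.0, 5.0, 0.0]
```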

Acknowledgements

We greatly appreciate the ongoing support of the TensorFlow Probability authors in assisting with the design and cross-compatibility of Distrax.

Citing Distrax

To cite this repository:

@software{distrax2021github,
  author = {Jake Bruce and David Budden and Matteo Hessel and George Papamakarios and Francisco Ruiz},
  title = {Distrax: Probability distributions in {JAX}},
  url = {http://github.com/deepmind/distrax},
  version = {0.0.1},
  year = {2021},
}

