
Flexible and fast inference in Python


BlackJAX


What is BlackJAX?

BlackJAX is a library of samplers for JAX that works on CPU as well as GPU.

It is not a probabilistic programming library. However, it integrates well with PPLs, as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX.
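In practice, a log-probability density "compatible with JAX" is a pure function, written with jax.numpy, that maps a position (an array or a pytree such as a dict) to a scalar, so that JAX can trace, compile and differentiate it. The sketch below uses a hypothetical standard-normal target named logdensity. It drops the normalizing constant because samplers only need the density up to a constant. Note that the Quickstart example passes the negative of such a function, a potential, to the NUTS kernel.

```python
import jax
import jax.numpy as jnp


# Hypothetical target, for illustration only: an unnormalized standard normal
# over a dict-valued position. The -0.5 * log(2 * pi) constant is dropped.
def logdensity(position):
    x = position["x"]
    return -0.5 * jnp.sum(x**2)


position = {"x": jnp.array([0.5, -1.0])}
value = logdensity(position)

# Gradients, which HMC and NUTS need, come from jax.grad; the result has
# the same dict structure as the position.
grad = jax.grad(logdensity)(position)
print(value, grad)
```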

Who should use BlackJAX?

BlackJAX should appeal to those who:

  • Have a logpdf and just need a sampler;
  • Need more than a general-purpose sampler;
  • Want to sample on GPU;
  • Want to build upon robust elementary blocks for their research;
  • Are building a PPL;
  • Want to learn how sampling algorithms work.

Quickstart

Installation

BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX installation depends on your CUDA version, BlackJAX does not list JAX as a dependency. If you simply want to use JAX on CPU, install it with:

pip install jax jaxlib

Follow the JAX installation instructions to install JAX with the relevant hardware acceleration support.

Then install BlackJAX:

pip install blackjax
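Once both packages are installed, a quick way to check which hardware JAX will run on (and therefore where BlackJAX's samplers will run) is to ask JAX for its default backend and the devices it can see. A minimal check, using only standard JAX functions:

```python
import jax

# Devices visible to JAX: a CPU device on a CPU-only install,
# GPU devices when a GPU-enabled jaxlib is installed.
devices = jax.devices()

# Name of the platform JAX uses by default, typically "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(backend, devices)
```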

Example

Let us look at a simple, self-contained example of sampling with NUTS:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax.nuts as nuts

# Synthetic data: 1,000 draws from a normal with mean 10 and standard deviation 20
observed = np.random.normal(10, 20, size=1_000)

def potential_fn(loc, scale, observed=observed):
    # NUTS works with a potential, i.e. the negative log-density of the target
    logpdf = stats.norm.logpdf(observed, loc, scale)
    return -jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
potential = lambda x: potential_fn(**x)
kernel = nuts.kernel(potential, step_size, inverse_mass_matrix)
kernel = jax.jit(kernel)  # try without to see the speedup

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.new_state(initial_position, potential)

# Iterate, splitting off a fresh key for every transition so that no key is reused
rng_key = jax.random.PRNGKey(0)
positions = []
for _ in range(1_000):
    rng_key, sample_key = jax.random.split(rng_key)
    state, _ = kernel(sample_key, state)
    positions.append(state.position)  # a dict with "loc" and "scale"

See this notebook for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
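As a rough, library-agnostic sketch of such an inference loop (not the helper used in the notebook), any kernel with the kernel(rng_key, state) -> (new_state, info) signature can be driven by jax.lax.scan, which compiles the whole loop, and several chains can be run in parallel with jax.vmap. The inference_loop function and the toy_kernel random-walk step are hypothetical names introduced so the example runs on its own; in principle, a kernel such as the jitted NUTS one from the Quickstart could be passed instead, together with its state.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper, not part of BlackJAX: run `kernel` for `num_samples`
# steps with jax.lax.scan and return the stacked states.
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, step_key):
        state, _ = kernel(step_key, state)
        return state, state

    # One fresh key per transition.
    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# Hypothetical toy kernel with the (rng_key, state) -> (state, info) signature:
# a Gaussian random-walk Metropolis step targeting a standard normal.
def toy_kernel(rng_key, state):
    proposal_key, accept_key = jax.random.split(rng_key)
    proposal = state + 0.5 * jax.random.normal(proposal_key, state.shape)
    log_ratio = -0.5 * jnp.sum(proposal**2 - state**2)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    new_state = jnp.where(accept, proposal, state)
    return new_state, accept


# A single chain.
samples = inference_loop(jax.random.PRNGKey(0), toy_kernel, jnp.zeros(()), 1_000)

# Several chains: map over one key and one initial state per chain.
num_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(1), num_chains)
initial_states = jnp.zeros(num_chains)
run_chain = lambda key, init: inference_loop(key, toy_kernel, init, 1_000)
multi = jax.vmap(run_chain)(chain_keys, initial_states)
print(samples.shape, multi.shape)
```

Each chain receives its own key and its own initial state, which is why both are mapped over while the kernel and the number of samples are fixed by the closure in run_chain.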

Philosophy

What is BlackJAX?

BlackJAX bridges the gap between "one-liner" frameworks and modular, customizable libraries.

Users can import the library and interact with robust, well-tested and performant samplers with a few lines of code. These samplers are aimed at PPL developers, or people who have a logpdf and just need a sampler that works.

But the true strength of BlackJAX lies in its internals and how they can be used to experiment quickly with existing or new sampling schemes. This lower level exposes the building blocks of inference algorithms (integrators, proposals, momentum generators, etc.) and makes it easy to combine them into new algorithms. By providing robust, performant and reusable code, it can help accelerate research on sampling algorithms.
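To make "building block" concrete, here is a minimal sketch of a leapfrog (velocity Verlet) integrator, the integrator typically used in HMC and NUTS, written against plain JAX. The names velocity_verlet and one_step, and the assumption of a diagonal inverse mass matrix, are illustrative; this is not BlackJAX's internal API.

```python
import jax
import jax.numpy as jnp


# Hypothetical building block, not BlackJAX's API: a leapfrog integrator
# specialized by closure over the potential, step size and (diagonal)
# inverse mass matrix.
def velocity_verlet(potential_fn, step_size, inverse_mass_matrix):
    potential_grad = jax.grad(potential_fn)

    def one_step(position, momentum):
        # Half step for the momentum, full step for the position, half step again.
        momentum = momentum - 0.5 * step_size * potential_grad(position)
        position = position + step_size * inverse_mass_matrix * momentum
        momentum = momentum - 0.5 * step_size * potential_grad(position)
        return position, momentum

    return one_step


# Usage on a standard normal target, whose potential is 0.5 * |x|^2.
potential = lambda x: 0.5 * jnp.sum(x**2)
step = velocity_verlet(potential, step_size=0.1, inverse_mass_matrix=jnp.ones(2))


def total_energy(position, momentum):
    # Potential plus kinetic energy (unit mass matrix here).
    return potential(position) + 0.5 * jnp.sum(momentum**2)


position, momentum = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])
energy0 = total_energy(position, momentum)
for _ in range(100):
    position, momentum = step(position, momentum)
energy1 = total_energy(position, momentum)
print(energy0, energy1)
```

Because the integrator is specialized by closure, it could be swapped for another one without changing the code that calls it. The total energy stays nearly constant along the trajectory, which is the property HMC relies on to keep acceptance rates high.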

Why BlackJAX?

Sampling algorithms are too often integrated into PPLs and not decoupled from the rest of the framework, making them hard to use for people who do not need the modeling language to build their logpdf. Their implementations are also usually monolithic, so parts of an algorithm cannot be reused to build custom kernels. BlackJAX solves both problems.

How does it work?

BlackJAX allows users to build arbitrarily complex algorithms because it is built around a very general pattern. Everything that takes a state and returns a state is a transition kernel, and is implemented as:

new_state, info = kernel(rng_key, state)

Kernels are stateless functions that all follow the same API; the new state and the information related to the transition are returned separately. They can thus easily be composed and exchanged. We specialize these kernels by closure instead of passing parameters to them at every call.
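As an illustration of this pattern (a sketch, not BlackJAX's actual code), the example below builds a random-walk Metropolis kernel whose parameters, the target log-density and the proposal scale sigma, are fixed by closure. A hypothetical compose helper then chains two such kernels into a new function with the same (rng_key, state) signature. RWState, RWInfo, rmh_kernel and compose are illustrative names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical state and info containers, for illustration only.
class RWState(NamedTuple):
    position: jnp.ndarray
    logdensity: jnp.ndarray


class RWInfo(NamedTuple):
    is_accepted: jnp.ndarray


def rmh_kernel(logdensity_fn, sigma):
    """Random-walk Metropolis kernel specialized by closure (hypothetical)."""

    def kernel(rng_key, state):
        proposal_key, accept_key = jax.random.split(rng_key)
        noise = jax.random.normal(proposal_key, state.position.shape)
        position = state.position + sigma * noise
        logdensity = logdensity_fn(position)
        # Symmetric proposal: the acceptance ratio is the density ratio.
        accept = jnp.log(jax.random.uniform(accept_key)) < logdensity - state.logdensity
        new_state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(accept, new, old),
            RWState(position, logdensity),
            state,
        )
        return new_state, RWInfo(accept)

    return kernel


def compose(kernel_a, kernel_b):
    """Apply kernel_a then kernel_b; the result is again a kernel (hypothetical)."""

    def kernel(rng_key, state):
        key_a, key_b = jax.random.split(rng_key)
        state, info_a = kernel_a(key_a, state)
        state, info_b = kernel_b(key_b, state)
        return state, (info_a, info_b)

    return kernel


# Two specializations of the same kernel (small and large steps), composed and jitted.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
kernel = jax.jit(compose(rmh_kernel(logdensity_fn, 0.1), rmh_kernel(logdensity_fn, 1.0)))

position = jnp.zeros(2)
state = RWState(position, logdensity_fn(position))
rng_key = jax.random.PRNGKey(0)
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = kernel(step_key, state)
print(state)
```

Because the composed function has the same signature as its parts, it can itself be jitted, composed again, or dropped into any loop written for a single kernel.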

Contributions

What contributions?

We value the following contributions:

  • Bug fixes
  • Documentation
  • High-level sampling algorithms from any family of algorithms: random walk, Hamiltonian Monte Carlo, sequential Monte Carlo, variational inference, inference compilation, etc.
  • New building blocks, e.g. new metrics for HMC, integrators, etc.
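As an example of what a new metric involves, here is a sketch of a diagonal Gaussian Euclidean metric: the pair of functions an HMC kernel needs to draw momenta and to evaluate their kinetic energy. gaussian_euclidean is a hypothetical name and its interface is an assumption for illustration; it does not reproduce BlackJAX's internals. The inverse_mass_matrix argument is assumed to be diagonal, like the one passed to nuts.kernel in the Quickstart.

```python
import jax
import jax.numpy as jnp


# Hypothetical metric, not BlackJAX's API: momentum sampler and kinetic
# energy for a diagonal Euclidean metric.
def gaussian_euclidean(inverse_mass_matrix):
    # Mass matrix diagonal M = 1 / inverse_mass_matrix; momenta are drawn from N(0, M).
    mass_matrix_sqrt = jnp.sqrt(1.0 / inverse_mass_matrix)

    def sample_momentum(rng_key, position):
        return mass_matrix_sqrt * jax.random.normal(rng_key, jnp.shape(position))

    def kinetic_energy(momentum):
        # K(p) = 0.5 * p^T M^{-1} p for a diagonal M^{-1}
        return 0.5 * jnp.sum(inverse_mass_matrix * momentum**2)

    return sample_momentum, kinetic_energy


sample_momentum, kinetic_energy = gaussian_euclidean(jnp.array([1.0, 4.0]))
momentum = sample_momentum(jax.random.PRNGKey(0), jnp.zeros(2))

# Many draws: the empirical variances should be close to the mass matrix
# diagonal, here [1.0, 0.25].
keys = jax.random.split(jax.random.PRNGKey(1), 10_000)
draws = jax.vmap(sample_momentum, in_axes=(0, None))(keys, jnp.zeros(2))
variances = draws.var(axis=0)
print(momentum, variances)
```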

How to contribute?

  1. Run pip install -r requirements-dev.txt to install all the dev dependencies.
  2. Run make lint and make test before pushing to the repository; CI should pass if these pass locally.

Acknowledgements

Some details of the NUTS implementation were largely inspired by NumPyro's.

Download files


Source Distribution

blackjax-0.2.1.tar.gz (35.8 kB)

Uploaded Source

Built Distribution

blackjax-0.2.1-py3-none-any.whl (48.4 kB)

Uploaded Python 3

File details

Details for the file blackjax-0.2.1.tar.gz.

File metadata

  • Download URL: blackjax-0.2.1.tar.gz
  • Upload date:
  • Size: 35.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.7.10

File hashes

Hashes for blackjax-0.2.1.tar.gz
Algorithm Hash digest
SHA256 e7fafe5ebb636537d437ace9330fb17844546727d1cf66dbefca10b2eaf7d9bb
MD5 5d9c8ac7cc696928822443d153f07349
BLAKE2b-256 0e67a3f553e6959f73d58f3e6207a3d2639565c14ec74ee7441ca80abebb570c


File details

Details for the file blackjax-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: blackjax-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 48.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.7.10

File hashes

Hashes for blackjax-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 12dcc9c5fed7b3d43b3b8aee53ae9be12c2eb7499719ab21d4d4cfc5a1c2f5b1
MD5 4ce254c00e63030bdddce483ede40c72
BLAKE2b-256 1e79ae5bcadda0fd1029e7bc2ce0d91add995440daa06871b552bf1b474167a8

