Extensions for NumPyro

This library includes a miscellaneous set of helper functions, custom distributions, and other utilities that I find useful when using NumPyro in my work.

Installation

Since NumPyro, and hence this library, are built on top of JAX, it's typically good practice to start by installing JAX following the installation instructions. Then, you can install this library using pip:

python -m pip install numpyro-ext

Usage

Since this README is checked using doctest, let's start by importing some common modules that we'll need in all our examples:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro_ext

Distributions

The convention is to import numpyro_ext.distributions as distx to distinguish it from numpyro.distributions, which is imported as dist:

>>> from numpyro import distributions as dist
>>> from numpyro_ext import distributions as distx
>>> key = jax.random.PRNGKey(0)

Angle

A uniform distribution over angles in radians. The actual sampling is performed in the two-dimensional vector space proportional to (sin(theta), cos(theta)) so that the sampler doesn't see a discontinuity at pi.

>>> angle = distx.Angle()
>>> print(angle.sample(key, (2, 3)))
[[ 0.4...]
 [ 2.4...]]
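
To see why working in a two-dimensional space avoids the discontinuity, here is a minimal plain-Python sketch of the idea. It is not the library's implementation: it assumes the underlying vector is drawn from an isotropic normal, and the helper name `angle_from_vector` is hypothetical. Two vectors that sit close together in the plane can map to angles near +pi and -pi, so a sampler moving in the plane never has to jump across the wrap.

```python
import math
import random


def angle_from_vector(s, c):
    """Map a vector proportional to (sin(theta), cos(theta)) to theta in [-pi, pi]."""
    return math.atan2(s, c)


rng = random.Random(0)
# Assumption: the underlying 2D vector is isotropic normal, which makes
# the resulting angle uniformly distributed.
angles = [angle_from_vector(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(1000)]

# Two vectors that are close together in the plane...
just_above = angle_from_vector(1e-6, -1.0)
just_below = angle_from_vector(-1e-6, -1.0)
# ...land on opposite sides of the wrap at pi when expressed as angles.
print(just_above, just_below)
```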

UnitDisk

A uniform distribution over two-dimensional points within the disk of radius 1. This means that the sum over squares of the last dimension of a random variable generated from this distribution will always be less than 1.

>>> unit_disk = distx.UnitDisk()
>>> u = unit_disk.sample(key, (5,))
>>> print(jnp.sum(u**2, axis=-1))
[0.07...]
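
The constraint can be illustrated with a small rejection-sampling sketch in plain Python. This is not how the library draws its samples (the helper `sample_unit_disk` is hypothetical); it only demonstrates the property that the squared norm of every draw is below 1.

```python
import random


def sample_unit_disk(rng):
    """Draw one point uniformly from the open unit disk by rejection sampling."""
    while True:
        x = rng.uniform(-1.0, 1.0)
        y = rng.uniform(-1.0, 1.0)
        # Keep only candidates that fall strictly inside the disk.
        if x * x + y * y < 1.0:
            return x, y


rng = random.Random(42)
points = [sample_unit_disk(rng) for _ in range(1000)]
squared_norms = [x * x + y * y for x, y in points]
print(max(squared_norms) < 1.0)
```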

NoncentralChi2

A non-central chi-squared distribution. To use this distribution, you'll need to install the optional tensorflow-probability dependency.

>>> ncx2 = distx.NoncentralChi2(df=3, nc=2.)
>>> print(ncx2.sample(key, (5,)))
[2.19...]
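
For intuition about the parameters: when df is an integer, a non-central chi-squared variable with df degrees of freedom and non-centrality nc can be constructed as the sum of squares of df unit-variance normal variables whose means have squared norm nc. The plain-Python sketch below uses that construction (it is not the tensorflow-probability implementation, and `sample_ncx2` is a hypothetical helper) and compares the sample mean with the theoretical mean df + nc.

```python
import math
import random


def sample_ncx2(rng, df, nc):
    """Sum of squares of `df` unit-variance normals; all of `nc` goes into the first mean."""
    total = rng.gauss(math.sqrt(nc), 1.0) ** 2
    for _ in range(df - 1):
        total += rng.gauss(0.0, 1.0) ** 2
    return total


rng = random.Random(1)
draws = [sample_ncx2(rng, df=3, nc=2.0) for _ in range(20000)]
sample_mean = sum(draws) / len(draws)
print(sample_mean)  # should be close to df + nc = 5
```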

MarginalizedLinear

The marginalized product of two (possibly multivariate) normal distributions with a linear relationship between them. The mathematical details of these models are discussed in this note, and this distribution implements that math in a computationally efficient way, assuming that the number of marginalized parameters is small compared to the size of the dataset.
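
As a brief summary of the kind of result involved, assume a Gaussian prior on the linear parameters w and zero-mean Gaussian noise on the data y. The symbols below are introduced here for illustration and may differ from the notation in the note: A is the design matrix, mu and Lambda are the prior mean and covariance, and Sigma is the noise covariance. Under those assumptions the marginal distribution of the data is also Gaussian:

```latex
\begin{aligned}
w &\sim \mathcal{N}(\mu, \Lambda), \qquad
y \mid w \sim \mathcal{N}(A\,w, \Sigma) \\
\Rightarrow\quad y &\sim \mathcal{N}\!\left(A\,\mu,\; A\,\Lambda\,A^\mathsf{T} + \Sigma\right)
\end{aligned}
```

When A has far fewer columns than rows and Sigma is cheap to invert, the Woodbury and matrix determinant identities allow this density to be evaluated without factorizing the full covariance, which is the regime the paragraph above refers to.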

The following is a particularly simple example of a fully marginalized model for fitting a line to data:

>>> def model(x, y=None):
...     design_matrix = jnp.vander(x, 2)
...     prior = dist.Normal(0.0, 1.0)
...     data = dist.Normal(0.0, 2.0)
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...

Things get a little more interesting when the design matrix and/or the distributions are functions of non-linear parameters. For example, if we want to find the period of a sinusoidal signal while also fitting for some unknown excess measurement uncertainty (often called "jitter"), we can use the following model:

>>> def model(x, y_err, y=None):
...     period = numpyro.sample("period", dist.Uniform(1.0, 250.0))
...     ln_jitter = numpyro.sample("ln_jitter", dist.Normal(0.0, 2.0))
...     design_matrix = jnp.stack(
...         [
...             jnp.sin(2 * jnp.pi * x / period),
...             jnp.cos(2 * jnp.pi * x / period),
...             jnp.ones_like(x),
...         ],
...         axis=-1,
...     )
...     prior = dist.Normal(0.0, 10.0).expand([3])
...     data = dist.Normal(0.0, jnp.sqrt(y_err**2 + jnp.exp(2*ln_jitter)))
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> samples = numpyro.infer.Predictive(model, num_samples=2)(key, x, 0.1)
>>> print(samples["period"])
[... ...]
>>> print(samples["y"])
[[... ... ...]
 [... ... ...]]

It's often useful to also track conditional samples of the marginalized parameters during inference. The conditional distribution can be accessed using the conditional method on MarginalizedLinear:

>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # just some fake data
>>> design_matrix = jnp.vander(x, 2)
>>> prior = dist.Normal(0.0, 1.0)
>>> data = dist.Normal(0.0, 2.0)
>>> marg = distx.MarginalizedLinear(design_matrix, prior, data)
>>> cond = marg.conditional(y)
>>> print(type(cond).__name__)
MultivariateNormal
>>> print(cond.sample(key, (3,)))
[[...]
 [...]
 [...]]
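
For reference, here is the standard result for a linear-Gaussian model, with notation introduced here for illustration: prior w ~ N(mu, Lambda), design matrix A, and zero-mean noise with covariance Sigma. That the library follows exactly this parameterization is an assumption. The conditional distribution of the marginalized parameters given data y is a multivariate normal, consistent with the type printed above:

```latex
\begin{aligned}
w \mid y &\sim \mathcal{N}(m, C), \\
C &= \left(\Lambda^{-1} + A^\mathsf{T}\,\Sigma^{-1}\,A\right)^{-1}, \\
m &= C\left(\Lambda^{-1}\,\mu + A^\mathsf{T}\,\Sigma^{-1}\,y\right)
\end{aligned}
```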

Optimization

The inference lore is a little mixed on the benefits of optimization as an initialization tool for MCMC, but I find that at least in a lot of astronomy applications, an initial optimization can make a huge difference in performance. Even if you don't want to use the optimization results as an initialization, it can still sometimes be useful to numerically search for the maximum a posteriori parameters for your model. However, the NumPyro interface for these types of optimization isn't terribly user-friendly, so this library provides some helpers to make it a little more straightforward.

By default, this optimization uses the wrappers of SciPy's optimization routines provided by the JAXopt library, so you'll need to install JAXopt:

python -m pip install jaxopt

before running these examples.

The following example shows a simple optimization of a model with a single parameter:

>>> from numpyro_ext import optim as optimx
>>>
>>> def model(y=None):
...     x = numpyro.sample("x", dist.Normal(0.0, 1.0))
...     numpyro.sample("y", dist.Normal(x, 2.0), obs=y)
...
>>> soln = optimx.optimize(model)(key, y=0.5)

By default, the optimization starts from a prior sample, but you can provide custom initial coordinates as follows:

>>> soln = optimx.optimize(model, start={"x": 12.3})(key, y=0.5)

Similarly, if you only want to optimize a subset of the parameters, you can provide a list of parameters to target:

>>> soln = optimx.optimize(model, sites=["x"])(key, y=0.5)
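
For the single-parameter model above, the maximum a posteriori value can be worked out by hand, which gives a useful sanity check on what the optimizer is looking for. The plain-Python sketch below uses simple gradient descent, not the optimizer used by the library, and `neg_log_posterior_grad` is a hypothetical helper. It minimizes the negative log posterior for x ~ Normal(0, 1) and y ~ Normal(x, 2) with y = 0.5, starting from the same x = 12.3 used above, and compares the result with the closed-form answer.

```python
def neg_log_posterior_grad(x, y, prior_sigma=1.0, obs_sigma=2.0):
    """Gradient of -log p(x | y) for x ~ Normal(0, prior_sigma), y ~ Normal(x, obs_sigma)."""
    return x / prior_sigma**2 + (x - y) / obs_sigma**2


y_obs = 0.5
x = 12.3  # same starting point as the `start` example
for _ in range(200):
    # Fixed step size; small enough to converge for this quadratic objective.
    x -= 0.5 * neg_log_posterior_grad(x, y_obs)

# Closed form: precision-weighted mean of the prior (0) and the observation (y).
x_map = (y_obs / 2.0**2) / (1.0 / 1.0**2 + 1.0 / 2.0**2)
print(x, x_map)
```

Both numbers should agree, which is the value an optimizer started from any point would be expected to approach for this model.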

Information matrix computation

The Fisher information matrix for models with Gaussian likelihoods is straightforward to compute, and this library provides a helper function for automating this computation:

>>> from numpyro_ext import information
>>>
>>> def model(x, y=None):
...     a = numpyro.sample("a", dist.Normal(0.0, 1.0))
...     b = numpyro.sample("b", dist.Normal(0.0, 1.0))
...     log_alpha = numpyro.sample("log_alpha", dist.Normal(0.0, 1.0))
...     cov = jnp.exp(log_alpha - 0.5 * (x[:, None] - x[None, :])**2)
...     cov += 0.1 * jnp.eye(len(x))
...     numpyro.sample(
...         "y",
...         dist.MultivariateNormal(loc=a * x + b, covariance_matrix=cov),
...         obs=y,
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # the input data just needs to have the right shape
>>> params = {"a": 0.5, "b": -0.2, "log_alpha": -0.5}
>>> info = information(model)(params, x, y=y)
>>> print(info)
{'a': {'a': ..., 'b': ... 'log_alpha': ...}, 'b': ...}
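
For context, the standard expression for the Fisher information of a multivariate normal likelihood with mean mu(theta) and covariance Sigma(theta) is shown below. Whether the helper evaluates exactly this expression internally is an assumption; the formula is included only to indicate what is being computed.

```latex
\mathcal{I}_{ij}
= \frac{\partial \mu^\mathsf{T}}{\partial \theta_i}\,\Sigma^{-1}\,\frac{\partial \mu}{\partial \theta_j}
+ \frac{1}{2}\,\mathrm{tr}\!\left(\Sigma^{-1}\,\frac{\partial \Sigma}{\partial \theta_i}\,\Sigma^{-1}\,\frac{\partial \Sigma}{\partial \theta_j}\right)
```

In the example above, a and b enter only through the mean and log_alpha only through the covariance, so under this expression the cross terms between log_alpha and the other two parameters vanish.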

The returned information matrix is a nested dictionary, indexed by pairs of parameter names, where the values are the corresponding blocks of the information matrix.
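
When a dense matrix is more convenient, the nested dictionary can be assembled into one. The sketch below uses plain Python and made-up scalar values in place of a real result: the numbers in `info` and the helper `to_dense` are hypothetical, and each block is assumed to be a scalar because every parameter in the example is a scalar.

```python
def to_dense(info, names):
    """Arrange info[row][col] entries into a list-of-lists matrix ordered by `names`."""
    return [[float(info[row][col]) for col in names] for row in names]


# Hypothetical values standing in for the output of `information(model)(...)`.
info = {
    "a": {"a": 4.0, "b": 0.5, "log_alpha": 0.0},
    "b": {"a": 0.5, "b": 9.0, "log_alpha": 0.0},
    "log_alpha": {"a": 0.0, "b": 0.0, "log_alpha": 2.0},
}
names = ["a", "b", "log_alpha"]
matrix = to_dense(info, names)
for row in matrix:
    print(row)
```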

