Extensions for NumPyro

This library includes a miscellaneous set of helper functions, custom distributions, and other utilities that I find useful when using NumPyro in my work.

Installation

Since NumPyro, and hence this library, is built on top of JAX, it's typically good practice to start by installing JAX following the JAX installation instructions. Then, you can install this library using pip:

python -m pip install numpyro-ext

Usage

Since this README is checked using doctest, let's start by importing some common modules that we'll need in all our examples:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro_ext

Distributions

By convention, numpyro_ext.distributions is imported as distx to distinguish it from numpyro.distributions, which is imported as dist:

>>> from numpyro import distributions as dist
>>> from numpyro_ext import distributions as distx
>>> key = jax.random.PRNGKey(0)

Angle

A uniform distribution over angles in radians. The sampling is actually performed in a two-dimensional vector space, using a vector proportional to (sin(theta), cos(theta)), so that the sampler doesn't see a discontinuity at pi.

>>> angle = distx.Angle()
>>> print(angle.sample(key, (2, 3)))
[[ 0.4...]
 [ 2.4...]]
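
The idea behind this parameterization can be sketched without the library. The following is a minimal illustration, not the library's implementation: it draws a point from an isotropic two-dimensional normal and takes its polar angle, which is uniform because the normal is rotationally symmetric. The helper name sample_angle and the use of the standard library instead of JAX are assumptions made for this example.

```python
import math
import random


def sample_angle(rng):
    # Hypothetical helper: draw an isotropic 2D Gaussian point, whose
    # direction is uniformly distributed on the circle.
    x = rng.gauss(0.0, 1.0)
    y = rng.gauss(0.0, 1.0)
    # Treat (x, y) as proportional to (sin(theta), cos(theta)), so that
    # theta = atan2(x, y) lies in [-pi, pi].
    return math.atan2(x, y)


rng = random.Random(42)
angles = [sample_angle(rng) for _ in range(10_000)]
```

The two underlying coordinates are unconstrained, so a gradient-based sampler working in that space never runs into the wrap-around at pi.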

UnitDisk

A uniform distribution over two-dimensional points within the disk of radius 1. This means that the sum of squares over the last dimension of a random variable generated from this distribution will always be less than 1.

>>> unit_disk = distx.UnitDisk()
>>> u = unit_disk.sample(key, (5,))
>>> print(jnp.sum(u**2, axis=-1))
[0.07...]
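
For intuition, here is one standard way to draw points uniformly within the unit disk, written with the standard library only. This is a sketch under my own assumptions (the helper name sample_unit_disk is hypothetical), and the library may well construct its samples differently.

```python
import math
import random


def sample_unit_disk(rng):
    # Taking the square root of a uniform draw makes the density uniform
    # in area rather than uniform in radius.
    r = math.sqrt(rng.random())
    phi = 2.0 * math.pi * rng.random()
    return (r * math.cos(phi), r * math.sin(phi))


rng = random.Random(0)
points = [sample_unit_disk(rng) for _ in range(10_000)]
# Sum of squares over the last dimension, as in the example above.
radii_sq = [x * x + y * y for x, y in points]
```

As a quick sanity check, about a quarter of the points should fall within radius 0.5, since that inner disk covers a quarter of the total area.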

NoncentralChi2

A non-central chi-squared distribution. To use this distribution, you'll need to install the optional tensorflow-probability dependency.

>>> ncx2 = distx.NoncentralChi2(df=3, nc=2.)
>>> print(ncx2.sample(key, (5,)))
[2.19...]
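
As a reminder of what this distribution represents, a non-central chi-squared variable with integer df can be built as the sum of squares of df independent unit-variance normals whose squared means add up to nc. The sketch below uses that definition directly; it is for illustration only (the helper name sample_ncx2 is hypothetical, and it only handles integer df), and it is not how the library draws its samples.

```python
import math
import random


def sample_ncx2(rng, df, nc):
    # Put the whole non-centrality on the first component: its mean is
    # sqrt(nc), and the remaining df - 1 components have zero mean.
    total = rng.gauss(math.sqrt(nc), 1.0) ** 2
    for _ in range(df - 1):
        total += rng.gauss(0.0, 1.0) ** 2
    return total


rng = random.Random(1)
draws = [sample_ncx2(rng, df=3, nc=2.0) for _ in range(20_000)]
# The mean of a non-central chi-squared variable is df + nc.
mean = sum(draws) / len(draws)
```

With df=3 and nc=2, the empirical mean should land close to 5.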

MarginalizedLinear

The marginalized product of two (possibly multivariate) normal distributions with a linear relationship between them. The mathematical details of these models are discussed in this note, and this distribution implements the math presented there in a computationally efficient way, assuming that the number of marginalized parameters is small compared to the size of the dataset.

The following shows a particularly simple example of a fully marginalized model for fitting a line to data:

>>> def model(x, y=None):
...     design_matrix = jnp.vander(x, 2)
...     prior = dist.Normal(0.0, 1.0)
...     data = dist.Normal(0.0, 2.0)
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...
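
To make the marginalization concrete: if the weights have independent zero-mean normal priors with standard deviation prior_sd, and the noise is independent with standard deviation data_sd, then integrating out the weights leaves a zero-mean multivariate normal for y with covariance prior_sd**2 * A @ A.T + data_sd**2 * I, where A is the design matrix. The following sketch builds that covariance in plain Python. The function names and the assumption that the scalar prior applies independently to each column are mine, and the library computes the equivalent quantity more efficiently.

```python
def vander2(x):
    # Same column order as jnp.vander(x, 2): [x, 1].
    return [[xi, 1.0] for xi in x]


def marginal_covariance(design, prior_sd, data_sd):
    # Hypothetical helper: dense covariance of y after marginalizing weights.
    n = len(design)
    cov = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Entry (i, j) of A @ (prior_sd**2 * I) @ A.T
            dot = sum(a * b for a, b in zip(design[i], design[j]))
            cov[i][j] = prior_sd**2 * dot
        # Independent measurement noise only adds to the diagonal.
        cov[i][i] += data_sd**2
    return cov


x = [-1.0, -0.5, 0.0, 0.5, 1.0]
cov = marginal_covariance(vander2(x), prior_sd=1.0, data_sd=2.0)
```

Building this dense covariance costs memory quadratic in the number of data points, which is why an implementation that exploits the small number of marginalized parameters is preferable in practice.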

Things get a little more interesting when the design matrix and/or the distributions are functions of non-linear parameters. For example, if we want to find the period of a sinusoidal signal while also fitting for some unknown excess measurement uncertainty (often called "jitter"), we can use the following model:

>>> def model(x, y_err, y=None):
...     period = numpyro.sample("period", dist.Uniform(1.0, 250.0))
...     ln_jitter = numpyro.sample("ln_jitter", dist.Normal(0.0, 2.0))
...     design_matrix = jnp.stack(
...         [
...             jnp.sin(2 * jnp.pi * x / period),
...             jnp.cos(2 * jnp.pi * x / period),
...             jnp.ones_like(x),
...         ],
...         axis=-1,
...     )
...     prior = dist.Normal(0.0, 10.0).expand([3])
...     data = dist.Normal(0.0, jnp.sqrt(y_err**2 + jnp.exp(2*ln_jitter)))
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> samples = numpyro.infer.Predictive(model, num_samples=2)(key, x, 0.1)
>>> print(samples["period"])
[... ...]
>>> print(samples["y"])
[[... ... ...]
 [... ... ...]]

It's often useful to also track conditional samples of the marginalized parameters during inference. The conditional distribution can be accessed using the conditional method on MarginalizedLinear:

>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # just some fake data
>>> design_matrix = jnp.vander(x, 2)
>>> prior = dist.Normal(0.0, 1.0)
>>> data = dist.Normal(0.0, 2.0)
>>> marg = distx.MarginalizedLinear(design_matrix, prior, data)
>>> cond = marg.conditional(y)
>>> print(type(cond).__name__)
MultivariateNormal
>>> print(cond.sample(key, (3,)))
[[...]
 [...]
 [...]]
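
For the two-parameter line fit above, the conditional distribution of the weights can also be worked out by hand, which is a useful cross-check. Under the assumptions of independent zero-mean normal priors on the slope and intercept and independent noise, the conditional is Gaussian with precision P = A.T @ A / data_sd**2 + I / prior_sd**2 and mean inv(P) @ A.T @ y / data_sd**2. The sketch below hard-codes the 2x2 case in plain Python; the function name is hypothetical and the conventions are my assumptions rather than a description of the library's internals.

```python
import math


def conditional_weights(x, y, prior_sd, data_sd):
    # Precision matrix P for the design matrix A = [x, 1].
    p00 = sum(xi * xi for xi in x) / data_sd**2 + 1.0 / prior_sd**2
    p01 = sum(x) / data_sd**2
    p11 = len(x) / data_sd**2 + 1.0 / prior_sd**2
    # Vector b = A.T @ y / data_sd**2.
    b0 = sum(xi * yi for xi, yi in zip(x, y)) / data_sd**2
    b1 = sum(y) / data_sd**2
    # Invert the 2x2 precision to get the conditional covariance.
    det = p00 * p11 - p01 * p01
    cov = [[p11 / det, -p01 / det], [-p01 / det, p00 / det]]
    # Conditional mean = cov @ b.
    mean = [
        cov[0][0] * b0 + cov[0][1] * b1,
        cov[1][0] * b0 + cov[1][1] * b1,
    ]
    return mean, cov


x = [-1.0, -0.5, 0.0, 0.5, 1.0]
y = [math.sin(xi) for xi in x]  # the same fake data as above
mean, cov = conditional_weights(x, y, prior_sd=1.0, data_sd=2.0)
```

Because the fake data are antisymmetric about zero, the conditional mean of the intercept is essentially zero, while the slope is pulled toward zero by the prior.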

Optimization

The inference lore is a little mixed on the benefits of optimization as an initialization tool for MCMC, but I find that at least in a lot of astronomy applications, an initial optimization can make a huge difference in performance. Even if you don't want to use the optimization results as an initialization, it can still sometimes be useful to numerically search for the maximum a posteriori parameters for your model. However, the NumPyro interface for these types of optimization isn't terribly user-friendly, so this library provides some helpers to make it a little more straightforward.

By default, this optimization uses the wrappers of SciPy's optimization routines provided by the JAXopt library, so you'll need to install JAXopt:

python -m pip install jaxopt

before running these examples.

The following example shows a simple optimization of a model with a single parameter:

>>> from numpyro_ext import optim as optimx
>>>
>>> def model(y=None):
...     x = numpyro.sample("x", dist.Normal(0.0, 1.0))
...     numpyro.sample("y", dist.Normal(x, 2.0), obs=y)
...
>>> soln = optimx.optimize(model)(key, y=0.5)

By default, the optimization starts from a prior sample, but you can provide custom initial coordinates as follows:

>>> soln = optimx.optimize(model, start={"x": 12.3})(key, y=0.5)

Similarly, if you only want to optimize a subset of the parameters, you can provide a list of parameters to target:

>>> soln = optimx.optimize(model, sites=["x"])(key, y=0.5)
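
For the single-parameter model above, the maximum a posteriori value can be checked by hand: the negative log posterior is 0.5 * x**2 + 0.5 * ((y - x) / 2)**2 up to a constant, which is minimized at x = y / 5. The following sketch finds the same value with plain gradient descent. This is only an illustration of what the optimization is searching for; it is not the algorithm used by the library, and the function names and step size are hypothetical.

```python
def neg_log_posterior_grad(x, y):
    # d/dx of 0.5 * x**2 + 0.5 * ((y - x) / 2.0)**2
    return x - (y - x) / 4.0


def find_map(y, start, step=0.5, num_steps=200):
    # Fixed-step gradient descent from a user-provided starting point.
    x = start
    for _ in range(num_steps):
        x -= step * neg_log_posterior_grad(x, y)
    return x


# Start far from the optimum, as in the custom-start example above.
x_map = find_map(y=0.5, start=12.3)
```

For y=0.5, the result should agree with the analytic value of 0.1 regardless of the starting point, since this posterior has a single mode.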

Information matrix computation

The Fisher information matrix for models with Gaussian likelihoods is straightforward to compute, and this library provides a helper function for automating this computation:

>>> from numpyro_ext import information
>>>
>>> def model(x, y=None):
...     a = numpyro.sample("a", dist.Normal(0.0, 1.0))
...     b = numpyro.sample("b", dist.Normal(0.0, 1.0))
...     log_alpha = numpyro.sample("log_alpha", dist.Normal(0.0, 1.0))
...     cov = jnp.exp(log_alpha - 0.5 * (x[:, None] - x[None, :])**2)
...     cov += 0.1 * jnp.eye(len(x))
...     numpyro.sample(
...         "y",
...         dist.MultivariateNormal(loc=a * x + b, covariance_matrix=cov),
...         obs=y,
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # the input data just needs to have the right shape
>>> params = {"a": 0.5, "b": -0.2, "log_alpha": -0.5}
>>> info = information(model)(params, x, y=y)
>>> print(info)
{'a': {'a': ..., 'b': ... 'log_alpha': ...}, 'b': ...}

The returned information matrix is a nested dictionary, indexed by pairs of parameter names, where the values are the corresponding blocks of the information matrix.
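
To show what these blocks contain, here is a hand-rolled sketch for a deliberately simplified version of the model: the same linear mean a * x + b, but with independent noise of fixed variance instead of the kernel covariance above. In that case each block is the sum over data points of the product of mean derivatives divided by the noise variance. The function name and the simplified noise model are assumptions for this example, and the helper in this library handles the general case automatically.

```python
def line_information(x, noise_var):
    # Derivatives of the mean a * x + b with respect to each parameter.
    dmu = {"a": list(x), "b": [1.0] * len(x)}
    # With covariance noise_var * I, block (i, j) reduces to
    # sum_n dmu[i][n] * dmu[j][n] / noise_var.
    return {
        i: {
            j: sum(u * v for u, v in zip(dmu[i], dmu[j])) / noise_var
            for j in dmu
        }
        for i in dmu
    }


x = [-1.0, -0.5, 0.0, 0.5, 1.0]
info = line_information(x, noise_var=0.1)
```

The result uses the same nested layout as above, so info["a"]["b"] is the off-diagonal block, which vanishes here because x is symmetric about zero.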
