Bayeux
Stitching together models and samplers
bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn't even have to be normalized), along with a single point (specified as a pytree) where that log density is finite, and let bayeux do the rest!
Installation
pip install bayeux-ml
Quickstart
We define a model by providing a log density in JAX. The log density can come from a probabilistic programming language (PPL) such as numpyro, PyMC, TFP, distrax, oryx, or coix, or be written directly in JAX.
import bayeux as bx
import jax

normal_density = bx.Model(
    log_density=lambda x: -x*x,
    test_point=1.)

seed = jax.random.PRNGKey(0)
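The log density does not need to be normalized; it only needs to be finite at the test point. A quick sanity check in plain JAX (illustrative only, not part of the bayeux API):

import jax.numpy as jnp

log_density = lambda x: -x * x
assert jnp.isfinite(log_density(1.))  # the condition bayeux needs at test_point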
Simple
Every inference algorithm in bayeux will (try to) run with just a seed as an argument:
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR!
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR!
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
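The MCMC methods return an arviz InferenceData object by default (see return_pytree in the get_kwargs output later in this page), so standard arviz tooling works on the result. A short follow-up, assuming arviz is installed:

import arviz as az

print(az.summary(idata))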
An (only rarely) optional third argument to bx.Model is transform_fn, which maps a real number to the support of the distribution. The oryx library is used to automatically compute the inverse and Jacobian determinants for changes of variables, but the user can supply these if known.
half_normal_density = bx.Model(
    lambda x: -x*x,
    test_point=1.,
    transform_fn=jax.nn.softplus)
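For intuition, here is a sketch of the change of variables that bayeux automates (via oryx) when transform_fn is given: sampling happens on an unconstrained x, the constrained value is y = softplus(x), and the log density picks up a log-Jacobian term. This is illustrative only, not bayeux's internal code:

import jax

def transformed_log_density(x):
  # y = softplus(x) > 0, and log|dy/dx| = log(sigmoid(x)) = -softplus(-x),
  # so the unconstrained-space log density is:
  y = jax.nn.softplus(x)
  return -y * y - jax.nn.softplus(-x)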
Self-descriptive
Since bayeux is built on top of other fantastic libraries, it tries not to get in their way. Each algorithm has a .get_kwargs() method that tells you how it will be called and what functions are being called:
normal_density.optimize.jaxopt_bfgs.get_kwargs()
{jaxopt._src.bfgs.BFGS: {'value_and_grad': False,
  'has_aux': False,
  'maxiter': 500,
  'tol': 0.001,
  'stepsize': 0.0,
  'linesearch': 'zoom',
  'linesearch_init': 'increase',
  'condition': None,
  'maxls': 30,
  'decrease_factor': None,
  'increase_factor': 1.5,
  'max_stepsize': 1.0,
  'min_stepsize': 1e-06,
  'implicit_diff': True,
  'implicit_diff_solve': None,
  'jit': True,
  'unroll': 'auto',
  'verbose': False},
 'extra_parameters': {'chain_method': 'vectorized',
  'num_particles': 8,
  'num_iters': 1000,
  'apply_transform': True}}
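Any of the keyword arguments shown can be overridden by passing them to the method call itself; bayeux matches them by name to the underlying function. A sketch, assuming that routing (consistent with the get_kwargs output above):

opt_results = normal_density.optimize.jaxopt_bfgs(seed=seed, maxiter=100)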
If you pass an argument into .get_kwargs(), it will also tell you what will be passed on to the actual algorithms:
normal_density.mcmc.blackjax_nuts.get_kwargs(
    num_chains=5,
    target_acceptance_rate=0.99)
{blackjax.adaptation.window_adaptation.window_adaptation: {'is_mass_matrix_diagonal': True,
  'initial_step_size': 1.0,
  'target_acceptance_rate': 0.99,
  'progress_bar': False,
  'algorithm': blackjax.mcmc.nuts.nuts},
 blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,
  'divergence_threshold': 1000,
  'integrator': blackjax.mcmc.integrators.velocity_verlet,
  'step_size': 0.01},
 'extra_parameters': {'chain_method': 'vectorized',
  'num_chains': 5,
  'num_draws': 500,
  'num_adapt_draws': 500,
  'return_pytree': False}}
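The same keywords can then be passed straight to the sampler itself:

idata = normal_density.mcmc.blackjax_nuts(
    seed=seed, num_chains=5, target_acceptance_rate=0.99)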
A full list of available algorithms and how to call them can be seen with
print(normal_density)
mcmc
.blackjax_hmc
.blackjax_nuts
.blackjax_hmc_pathfinder
.blackjax_nuts_pathfinder
.numpyro_hmc
.numpyro_nuts
optimize
.jaxopt_bfgs
.jaxopt_gradient_descent
.jaxopt_lbfgs
.jaxopt_nonlinear_cg
.optax_adabelief
.optax_adafactor
.optax_adagrad
.optax_adam
.optax_adamw
.optax_adamax
.optax_amsgrad
.optax_fromage
.optax_lamb
.optax_lion
.optax_noisy_sgd
.optax_novograd
.optax_radam
.optax_rmsprop
.optax_sgd
.optax_sm3
.optax_yogi
vi
.tfp_factored_surrogate_posterior
Helpful
Algorithms come with a built-in debug mode that attempts to fail fast, in a way that helps locate the problem. The signature for debug accepts verbosity and catch_exceptions arguments, as well as a kwargs dictionary of the arguments the user plans to pass to the algorithm itself.
normal_density.mcmc.numpyro_nuts.debug(seed=seed)
Checking test_point shape ✓
Computing test point log density ✓
Loading keyword arguments... ✓
Checking it is possible to compute an initial state ✓
Checking initial state has no NaN ✓
Computing initial state log density ✓
Transforming model to R^n ✓
Computing transformed state log density shape ✓
Comparing transformed log density to untransformed ✓
Computing gradients of transformed log density ✓
True
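Because debug returns a boolean (the True above), it can gate a full run:

if normal_density.mcmc.numpyro_nuts.debug(seed=seed):
    idata = normal_density.mcmc.numpyro_nuts(seed=seed)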
Here is an example of debugging a bad model at higher verbosity:
import jax.numpy as jnp

bad_model = bx.Model(
    log_density=jnp.sqrt,
    test_point=-1.)

bad_model.mcmc.blackjax_nuts.debug(
    seed=jax.random.PRNGKey(0),
    verbosity=3,
    kwargs={"num_chains": 17})
Checking test_point shape ✓
Test point has shape
()
✓✓✓✓✓✓✓✓✓✓
Computing test point log density ×
Test point has log density
Array(nan, dtype=float32, weak_type=True)
××××××××××
Loading keyword arguments... ✓
Keyword arguments are
{<function window_adaptation at 0x77feef9308b0>: {'algorithm': <class 'blackjax.mcmc.nuts.nuts'>,
  'initial_step_size': 1.0,
  'is_mass_matrix_diagonal': True,
  'progress_bar': False,
  'target_acceptance_rate': 0.8},
 'extra_parameters': {'chain_method': 'vectorized',
  'num_adapt_draws': 500,
  'num_chains': 17,
  'num_draws': 500,
  'return_pytree': False},
 <class 'blackjax.mcmc.nuts.nuts'>: {'divergence_threshold': 1000,
  'integrator': <function velocity_verlet at 0x77feefbf4b80>,
  'max_num_doublings': 10,
  'step_size': 0.01}}
✓✓✓✓✓✓✓✓✓✓
Checking it is possible to compute an initial state ✓
Initial state has shape
(17,)
✓✓✓✓✓✓✓✓✓✓
Checking initial state has no NaN ✓
No nans detected!
✓✓✓✓✓✓✓✓✓✓
Computing initial state log density ×
Initial state has log density
Array([1.2212421 , nan, nan, 1.4113309 , nan,
nan, nan, nan, nan, nan,
0.5912253 , nan, nan, nan, 0.65457666,
nan, nan], dtype=float32)
××××××××××
Transforming model to R^n ✓
Transformed state has shape
(17,)
✓✓✓✓✓✓✓✓✓✓
Computing transformed state log density shape ✓
Transformed state log density has shape
(17,)
✓✓✓✓✓✓✓✓✓✓
Computing gradients of transformed log density ×
The gradient contains NaNs! Initial gradients has shape
(17,)
××××××××××
False
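The report points at the root cause: jnp.sqrt is NaN for negative inputs, so the test point and most initial states fail. A sketch of a repair (the density is still improper; this only fixes the domain problems) moves the test point into the domain and uses transform_fn to keep sampling on the positive reals:

fixed_model = bx.Model(
    log_density=jnp.sqrt,
    test_point=1.,
    transform_fn=jax.nn.softplus)

fixed_model.mcmc.blackjax_nuts.debug(seed=jax.random.PRNGKey(0))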
This is not an officially supported Google product.
File details
Details for the file bayeux_ml-0.1.8.tar.gz.
File metadata
- Download URL: bayeux_ml-0.1.8.tar.gz
- Upload date:
- Size: 29.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | a603bde3f4c9d9c236c08fb81bf145feeb00bd17c40ac30ede996a7584ac079f
MD5 | af5c23f2842b8c981d4232e5d00ead35
BLAKE2b-256 | a013ab502432667d09de8cf3b9979f1a2e41f675a59146ebd44c48026d64a136
File details
Details for the file bayeux_ml-0.1.8-py3-none-any.whl.
File metadata
- Download URL: bayeux_ml-0.1.8-py3-none-any.whl
- Upload date:
- Size: 42.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1f9d67ac25b69b30c8f53aca1761018d11a1204b5dad9215ec9898253922306e
MD5 | 3327e4916cf02355265efa23fba8cab4
BLAKE2b-256 | 55cc81e0e7b331b43c0f2f25933495d3760b5376d4484c3e9b140dde143f5a69