bayeux

Stitching together models and samplers
bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn't even have to be normalized), along with a single point (specified as a pytree) where that log density is finite. Then let bayeux do the rest!
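As a minimal illustration of the test-point requirement (plain Python, no bayeux needed; the `log_density` function below is my own example, not part of the library): the log density does not need to be normalized, but it must be finite at the point you supply.

```python
import math

def log_density(x):
    """Unnormalized log density log(x) - x for x > 0
    (a Gamma(2, 1) density up to an additive constant)."""
    if x <= 0:
        return float("-inf")
    return math.log(x) - x

# x = 0.0 would be a bad test point: the log density is -inf there.
# x = 1.0 is a valid test point: the log density is finite.
print(math.isfinite(log_density(0.0)))  # False
print(math.isfinite(log_density(1.0)))  # True
```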
Installation
```
pip install bayeux-ml
```
Quickstart
We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like numpyro, PyMC, TFP, distrax, oryx, coix, or directly in JAX.
```python
import bayeux as bx
import jax

normal_density = bx.Model(
    log_density=lambda x: -x * x,
    test_point=1.)

seed = jax.random.key(0)

opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR!
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR!
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
```
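For intuition on why normalization is not required: the quickstart's log density `-x * x` is a Normal(0, 1/2) log density up to an additive constant, and a uniform shift of the log density does not change the posterior. A quick stdlib check (no bayeux needed; `normalizing_constant` is a hypothetical helper of mine, not a library function) confirms that exp(-x²) integrates to √π, so the normalized density would just be exp(-x²)/√π:

```python
import math

def normalizing_constant(a=-10.0, b=10.0, n=100_000):
    """Integrate exp(-x^2) over [a, b] with the trapezoid rule.
    The true value over the whole real line is sqrt(pi)."""
    h = (b - a) / n
    total = 0.5 * (math.exp(-a * a) + math.exp(-b * b))
    for i in range(1, n):
        x = a + i * h
        total += math.exp(-x * x)
    return total * h

# The numeric integral matches sqrt(pi) closely.
print(abs(normalizing_constant() - math.sqrt(math.pi)) < 1e-6)  # True
```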
Read more
- Defining models
- Inspecting models
- Testing and debugging
- Also see bayeux integration with numpyro, PyMC, and TFP!
This is not an officially supported Google product.