Stitching together probabilistic models and inference.
Bayeux
Stitching together models and samplers
bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn't even have to be normalized), along with a single point (specified as a pytree) where that log density is finite. Then let bayeux do the rest!
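As a minimal sketch (the model, the made-up data, and the parameter names `loc` and `log_scale` are all hypothetical, for illustration only), here is a log density written directly in JAX over a dict-shaped pytree of parameters, together with a test point of the same structure at which the density is finite. A pair like this is what would be handed to `bx.Model(log_density=..., test_point=...)`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical observed data (illustration only).
y = jnp.array([0.3, -1.2, 0.8, 1.5, -0.4])

def log_density(params):
    """Log posterior over a dict-shaped pytree of parameters."""
    loc, log_scale = params["loc"], params["log_scale"]
    # Priors: standard normal on loc and on the unconstrained log_scale.
    log_prior = norm.logpdf(loc) + norm.logpdf(log_scale)
    # Likelihood: y_i ~ Normal(loc, exp(log_scale)).
    log_lik = jnp.sum(norm.logpdf(y, loc=loc, scale=jnp.exp(log_scale)))
    return log_prior + log_lik

# A single point, with the same pytree structure, where the density is finite.
test_point = {"loc": 0.0, "log_scale": 0.0}

print(jnp.isfinite(log_density(test_point)))
# Gradients come back with the same pytree structure as the parameters.
print(jax.grad(log_density)(test_point))
```

Working with `log_scale` rather than the scale itself keeps every real value inside the support, so the density stays finite at the test point and away from it.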
Installation
```
pip install bayeux-ml
```
Quickstart
We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like numpyro, PyMC, TFP, distrax, oryx, coix, or directly in JAX.
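```python
import bayeux as bx
import jax

# Log density of a normal distribution, known only up to an additive constant.
normal_density = bx.Model(
    log_density=lambda x: -x*x,
    test_point=1.)

seed = jax.random.key(0)

# Optimize with Adam from optax...
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR! draw MCMC samples with NUTS from numpyro...
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR! fit a factored variational surrogate posterior with TFP.
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
```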
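The quickstart's `-x*x` is the log density of a normal distribution with mean 0 and variance 1/2, minus its normalizing constant: log N(x | 0, σ²) = -x²/(2σ²) - ½ log(2πσ²), which for σ² = 1/2 is -x² - ½ log π. A quick check in plain JAX (independent of bayeux, and only a sketch) shows why the constant can be dropped: the two functions differ by a fixed offset, so their gradients, which optimizers and gradient-based samplers such as NUTS work with, are identical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def unnormalized(x):
    # The quickstart's log density, with no normalizing constant.
    return -x * x

def normalized(x):
    # Exact log pdf of Normal(mean=0, variance=1/2), i.e. scale = sqrt(1/2).
    return norm.logpdf(x, loc=0.0, scale=jnp.sqrt(0.5))

xs = jnp.linspace(-3.0, 3.0, 7)

# The two differ only by a constant offset, approximately -0.5 * log(pi) ...
offset = normalized(xs) - unnormalized(xs)
# ... so their gradients agree at every point.
grads_u = jax.vmap(jax.grad(unnormalized))(xs)
grads_n = jax.vmap(jax.grad(normalized))(xs)

print(offset)
print(bool(jnp.allclose(grads_u, grads_n, atol=1e-5)))
```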
Read more
- Defining models
- Inspecting models
- Testing and debugging
- Also see bayeux integration with numpyro, PyMC, and TFP!
This is not an officially supported Google product.
Hashes for bayeux_ml-0.1.10-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 1b1bcd83c85487e12d5551a9181cc3f4478fce3bfc5033dfc8f3c751353fde50
MD5 | 411a7ce56268400799d7ae869c656dee
BLAKE2b-256 | 259046d152df54cde409c007fa0225cbef9fbfb2f3b05ff8a9b013431349faf4