
Simulation-based inference benchmark



Simulation-Based Inference Benchmark

This repository contains a simulation-based inference benchmark framework, sbibm, which we describe in the associated manuscript "Benchmarking Simulation-based Inference". A short summary of the paper and interactive results can be found on the project website: https://sbi-benchmark.github.io

The benchmark framework includes tasks, reference posteriors, metrics, plotting, and integrations with SBI toolboxes. The framework is designed to be highly extensible and easily used in new research projects as we show below.

In order to emphasize that sbibm can be used independently of any particular analysis pipeline, we split the code for reproducing the experiments of the manuscript into a separate repository hosted at github.com/sbi-benchmark/results/. Besides the pipeline to reproduce the manuscript's experiments, that repository hosts the full results, including dataframes for quick comparisons.

If you have questions or comments, please do not hesitate to contact us or open an issue. We invite contributions, e.g., of new tasks, novel metrics, or wrappers for other SBI toolboxes.

Installation

Assuming you have a working Python environment, simply install sbibm via pip:

$ pip install sbibm

ODE-based models (currently the SIR and Lotka-Volterra models) use Julia via diffeqtorch. If you are planning to use these tasks, please additionally follow the installation instructions of diffeqtorch. If you are not planning to simulate these tasks for now, you can skip this step.

Quickstart

A quick demonstration of sbibm (see further below for more in-depth explanations):

import sbibm

task = sbibm.get_task("two_moons")  # See sbibm.get_available_tasks() for all tasks
prior = task.get_prior()
simulator = task.get_simulator()
observation = task.get_observation(num_observation=1)  # 10 per task

# These objects can then be used for custom inference algorithms, e.g.
# we might want to generate simulations by sampling from prior:
thetas = prior(num_samples=10_000)
xs = simulator(thetas)

# Alternatively, we can import existing algorithms, e.g.:
from sbibm.algorithms import rej_abc  # See help(rej_abc) for keywords
posterior_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000)

# Once we have samples from an approximate posterior, compare them to the reference:
from sbibm.metrics import c2st
reference_samples = task.get_reference_posterior_samples(num_observation=1)
c2st_accuracy = c2st(reference_samples, posterior_samples)

# Visualise both posteriors:
from sbibm.visualisation import fig_posterior
fig = fig_posterior(task_name="two_moons", observation=1, samples=[posterior_samples])  
# Note: Use fig.show() or fig.save() to show or save the figure

# Get results from other algorithms for comparison:
from sbibm.visualisation import fig_metric
results_df = sbibm.get_results(dataset="main_paper.csv")
fig = fig_metric(results_df.query("task == 'two_moons'"), metric="C2ST")

Tasks

You can see the list of available tasks by calling sbibm.get_available_tasks(). If we want to use, say, the slcp task, we can load it using sbibm.get_task, as in:

import sbibm
task = sbibm.get_task("slcp")

Next, we retrieve the prior and the simulator:

prior = task.get_prior()
simulator = task.get_simulator()

Calling prior() returns a single draw from the prior distribution; num_samples can be passed as an optional argument. The following draws 100 parameter sets from the prior and generates one simulation for each of them:

thetas = prior(num_samples=100)
xs = simulator(thetas)

xs is a torch.Tensor with shape (100, 8), since for SLCP the data is eight-dimensional. If required, converting to and from torch.Tensor is straightforward: use .numpy() to obtain a numpy array (e.g., xs.numpy()), and torch.from_numpy() to turn a numpy array back into a tensor.
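To make the batched calling convention concrete without installing sbibm, here is a minimal NumPy sketch. The functions prior and simulator below are hypothetical stand-ins, not sbibm's SLCP task: the uniform box on [-3, 3] and the toy simulator are assumptions chosen only so that the shapes match the 5-dimensional parameters and 8-dimensional data of SLCP.

```python
import numpy as np

# Hypothetical stand-ins (not sbibm code) that mimic the batched calling
# convention: parameters and data are stacked row-wise, one row per sample.
DIM_PARAMETERS = 5  # same parameter dimensionality as SLCP
DIM_DATA = 8        # same data dimensionality as SLCP

rng = np.random.default_rng(0)


def prior(num_samples: int = 1) -> np.ndarray:
    """Draw `num_samples` parameter vectors from a uniform box on [-3, 3] (assumed)."""
    return rng.uniform(-3.0, 3.0, size=(num_samples, DIM_PARAMETERS))


def simulator(thetas: np.ndarray) -> np.ndarray:
    """Map each row of `thetas` to one simulated data vector (toy model)."""
    num_samples = thetas.shape[0]
    # Toy model: repeat the first four parameters twice, then add unit Gaussian noise.
    means = np.tile(thetas[:, :4], 2)
    return means + rng.normal(size=(num_samples, DIM_DATA))


thetas = prior(num_samples=100)
xs = simulator(thetas)
print(thetas.shape, xs.shape)  # (100, 5) (100, 8)
```

Keeping the sample dimension first means a whole batch of simulations is produced by a single call, which is the pattern the sbibm tasks follow.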

Some algorithms need to evaluate the prior density. For this, task.get_prior_dist() returns the prior as a torch.distributions.Distribution instance, which exposes log_prob and sample methods. The parameters of the prior can be retrieved as a dictionary using task.get_prior_params().
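As an illustration of what evaluating a prior density involves, the sketch below computes the log-density of a box-uniform prior with NumPy. The dictionary keys "low" and "high" and the bounds are assumptions for illustration only; the dictionary returned by task.get_prior_params() may use different keys, and many tasks do not use a uniform prior at all.

```python
import numpy as np

# Hypothetical prior parameters (assumed keys and bounds, not taken from sbibm).
prior_params = {"low": -3.0 * np.ones(5), "high": 3.0 * np.ones(5)}


def box_uniform_log_prob(thetas: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Log-density of independent uniforms on [low, high], evaluated row-wise."""
    thetas = np.atleast_2d(thetas)
    inside = np.all((thetas >= low) & (thetas <= high), axis=1)
    log_volume = np.sum(np.log(high - low))
    # Inside the box the density is 1 / volume; outside it is zero (log-density -inf).
    return np.where(inside, -log_volume, -np.inf)


thetas = np.array([[0.0] * 5, [4.0, 0.0, 0.0, 0.0, 0.0]])
log_probs = box_uniform_log_prob(thetas, **prior_params)
print(log_probs)  # first row: -5 * log(6), second row (outside the box): -inf
```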

For each task, the benchmark contains 10 observations and the corresponding reference posterior samples. To fetch the first observation and its reference posterior samples:

observation = task.get_observation(num_observation=1)
reference_samples = task.get_reference_posterior_samples(num_observation=1)

Every task has a few informative attributes, including:

task.dim_data               # dimensionality of the data, here: 8
task.dim_parameters         # dimensionality of the parameters, here: 5
task.num_observations       # number of different observations x_o available, here: 10
task.name                   # name: slcp
task.name_display           # display name: SLCP

Finally, to look at the source code of the task, see sbibm/tasks/slcp/task.py. If you want to implement a new task, we recommend modelling it after the existing ones. You will see that each task has a private _setup method that was used to generate its reference posterior samples.

Algorithms

As mentioned in the introduction, sbibm wraps a number of third-party packages to run various algorithms. We found it easiest to give each algorithm the same interface: in general, each algorithm specifies a run function that takes the task and hyperparameters as arguments and returns the requested number of posterior samples. That way, one can import the run function of an algorithm, tune it on any given task, and compute metrics on the returned samples. Wrappers for external toolboxes implementing algorithms are in the subfolder sbibm/algorithms. Currently, integrations with sbi, pyabc, and pyabcranger are provided, as well as an experimental integration with elfi.
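The shared interface can be sketched as follows. Everything here is hypothetical: ToyGaussianTask only mimics the getters used in this README, and run implements plain rejection ABC (simulate from the prior and keep the num_samples parameters whose simulations land closest to the observation). sbibm's own wrappers, such as rej_abc, accept more hyperparameters and return additional values.

```python
import numpy as np


class ToyGaussianTask:
    """Hypothetical task object (not part of sbibm) exposing README-style getters."""

    def __init__(self, dim: int = 2, seed: int = 0):
        self.dim = dim
        self.rng = np.random.default_rng(seed)
        # One fixed observation, indexed from 1 as in sbibm.
        self._observations = {1: np.full((1, dim), 0.5)}

    def get_prior(self):
        # Standard normal prior over the parameters.
        return lambda num_samples=1: self.rng.normal(0.0, 1.0, size=(num_samples, self.dim))

    def get_simulator(self):
        # Data are the parameters plus small Gaussian noise.
        return lambda thetas: thetas + 0.1 * self.rng.normal(size=thetas.shape)

    def get_observation(self, num_observation: int) -> np.ndarray:
        return self._observations[num_observation]


def run(task, num_samples: int, num_simulations: int, num_observation: int = 1) -> np.ndarray:
    """Rejection ABC behind a shared signature: task and hyperparameters in,
    `num_samples` approximate posterior samples out."""
    prior, simulator = task.get_prior(), task.get_simulator()
    x_o = task.get_observation(num_observation=num_observation)
    thetas = prior(num_samples=num_simulations)
    xs = simulator(thetas)
    # Keep the parameters whose simulations are closest to the observation.
    distances = np.linalg.norm(xs - x_o, axis=1)
    keep = np.argsort(distances)[:num_samples]
    return thetas[keep]


task = ToyGaussianTask()
posterior_samples = run(task, num_samples=1_000, num_simulations=100_000)
print(posterior_samples.shape, posterior_samples.mean(axis=0))
```

Because every algorithm exposes the same kind of signature, comparing algorithms on a task amounts to swapping which run function is called before computing metrics on its output.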

Metrics

In order to compare algorithms on the benchmark, several different metrics can be computed. Each task comes with reference samples for each observation. Depending on the task, these are obtained either from an analytic solution for the posterior or from a customized likelihood-based approach.
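For tasks with conjugate structure, reference samples can be drawn exactly. The sketch below assumes a Gaussian prior θ ~ N(0, prior_var · I) and a Gaussian likelihood x | θ ~ N(θ, noise_var · I), with both variances set to 0.1 by assumption. Precisions add, so the posterior is N(prior_var / (prior_var + noise_var) · x_o, prior_var · noise_var / (prior_var + noise_var) · I). This illustrates the analytic route only and is not sbibm's reference-posterior code.

```python
import numpy as np


def gaussian_reference_posterior(x_o, num_samples, prior_var=0.1, noise_var=0.1, seed=0):
    """Exact posterior samples for theta ~ N(0, prior_var * I) and
    x | theta ~ N(theta, noise_var * I), given a single observation x_o."""
    x_o = np.asarray(x_o, dtype=float)
    # Conjugacy: posterior precision = prior precision + likelihood precision.
    post_var = 1.0 / (1.0 / prior_var + 1.0 / noise_var)
    # Posterior mean = posterior variance * (likelihood precision * data).
    post_mean = post_var * x_o / noise_var
    rng = np.random.default_rng(seed)
    samples = post_mean + np.sqrt(post_var) * rng.normal(size=(num_samples, x_o.size))
    return samples, post_mean, post_var


samples, mean, var = gaussian_reference_posterior(x_o=[0.4, -0.2], num_samples=10_000)
print(mean, var)  # with both variances equal to 0.1: mean = x_o / 2, var = 0.05
```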

Most metrics are computed by comparing algorithm samples to reference samples, for which several two-sample tests are available (see sbibm/metrics). These tests follow a simple interface: you only need to pass samples from the reference posterior and from the algorithm.
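One common kernel-based two-sample statistic is the maximum mean discrepancy (MMD). The following NumPy sketch computes the unbiased estimate of the squared MMD with a Gaussian kernel. The fixed bandwidth of 1.0 is an assumption, and the MMD variant in sbibm/metrics may use a different kernel or bandwidth choice.

```python
import numpy as np


def rbf_mmd2(X: np.ndarray, Y: np.ndarray, bandwidth: float = 1.0) -> float:
    """Unbiased estimate of the squared MMD between samples X and Y, using the
    Gaussian kernel k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2))."""

    def kernel(A: np.ndarray, B: np.ndarray) -> np.ndarray:
        sq_dists = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
        return np.exp(-sq_dists / (2.0 * bandwidth**2))

    m, n = len(X), len(Y)
    K_xx, K_yy, K_xy = kernel(X, X), kernel(Y, Y), kernel(X, Y)
    # Exclude the diagonal terms k(x_i, x_i) to get unbiased within-sample averages.
    term_xx = (K_xx.sum() - np.trace(K_xx)) / (m * (m - 1))
    term_yy = (K_yy.sum() - np.trace(K_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * K_xy.mean())


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
mmd_same = rbf_mmd2(X, rng.normal(size=(500, 2)))             # near zero
mmd_shift = rbf_mmd2(X, rng.normal(loc=1.0, size=(500, 2)))   # clearly positive
print(mmd_same, mmd_shift)
```

Since the estimate is unbiased, it can come out slightly negative when both sample sets come from the same distribution.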

For example, in order to compute C2ST:

import sbibm
from sbibm.metrics.c2st import c2st
from sbibm.algorithms import rej_abc

task = sbibm.get_task("slcp")  # any task from sbibm.get_available_tasks()
reference_samples = task.get_reference_posterior_samples(num_observation=1)
# rej_abc returns a tuple; its first element holds the posterior samples
algorithm_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_simulations=100_000, num_observation=1)
c2st_accuracy = c2st(reference_samples, algorithm_samples)

For more info, see help(c2st).
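The idea behind C2ST is to train a classifier to tell reference samples apart from algorithm samples and to report its held-out accuracy: about 0.5 means the two sets are indistinguishable, while values near 1.0 mean they are easily separated. As far as we know, sbibm's c2st uses a neural-network classifier (help(c2st) documents the details). The sketch below swaps in a k-nearest-neighbour classifier with k-fold cross-validation so that it needs only NumPy; knn_c2st is a hypothetical name, and k=5 and n_folds=5 are arbitrary choices.

```python
import numpy as np


def knn_c2st(X: np.ndarray, Y: np.ndarray, k: int = 5, n_folds: int = 5, seed: int = 0) -> float:
    """Classifier two-sample test: cross-validated accuracy of a k-NN classifier
    trained to separate samples of X (label 0) from samples of Y (label 1)."""
    data = np.concatenate([X, Y])
    labels = np.concatenate([np.zeros(len(X)), np.ones(len(Y))])
    # z-score with shared statistics so that no dimension dominates the distances.
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    order = np.random.default_rng(seed).permutation(len(data))
    folds = np.array_split(order, n_folds)
    correct = 0
    for test_idx in folds:
        train_idx = np.setdiff1d(order, test_idx)
        test_points = data[test_idx][:, None, :]
        train_points = data[train_idx][None, :, :]
        sq_dists = ((test_points - train_points) ** 2).sum(axis=-1)
        nearest = np.argsort(sq_dists, axis=1)[:, :k]
        # Majority vote among the k nearest training points (odd k avoids ties).
        predictions = labels[train_idx][nearest].mean(axis=1) > 0.5
        correct += np.sum(predictions == labels[test_idx])
    return correct / len(data)


rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))
acc_same = knn_c2st(X, rng.normal(size=(500, 2)))
acc_shifted = knn_c2st(X, rng.normal(loc=3.0, size=(500, 2)))
print(acc_same, acc_shifted)  # acc_same should be near 0.5, acc_shifted near 1.0
```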

Figures

sbibm includes code for plotting results. For instance, to plot metrics for a specific task:

import sbibm
from sbibm.visualisation import fig_metric

results_df = sbibm.get_results(dataset="main_paper.csv")
results_subset = results_df.query("task == 'two_moons'")
fig = fig_metric(results_subset, metric="C2ST")  # Use fig.show() or fig.save() to show or save the figure

It can also be used to plot posteriors, e.g., to compare the results of an inference algorithm against reference samples:

from sbibm.visualisation import fig_posterior
fig = fig_posterior(task_name="two_moons", observation=1, samples=[algorithm_samples])

Results and Experiments

We host results and the code for reproducing the experiments of the manuscript in a separate repository at github.com/sbi-benchmark/results. This includes the pipeline to reproduce the manuscript's experiments as well as dataframes for new comparisons.

Citation

The manuscript is available through PMLR:

@InProceedings{lueckmann2021benchmarking,
  title     = {Benchmarking Simulation-Based Inference},
  author    = {Lueckmann, Jan-Matthis and Boelts, Jan and Greenberg, David and Goncalves, Pedro and Macke, Jakob},
  booktitle = {Proceedings of The 24th International Conference on Artificial Intelligence and Statistics},
  pages     = {343--351},
  year      = {2021},
  editor    = {Banerjee, Arindam and Fukumizu, Kenji},
  volume    = {130},
  series    = {Proceedings of Machine Learning Research},
  month     = {13--15 Apr},
  publisher = {PMLR}
}

License

MIT


Download files


Source Distribution

sbibm-1.0.7.tar.gz (18.5 MB)

Uploaded Source

Built Distribution

sbibm-1.0.7-py2.py3-none-any.whl (18.6 MB)

Uploaded Python 2 Python 3

File details

Details for the file sbibm-1.0.7.tar.gz.

File metadata

  • Download URL: sbibm-1.0.7.tar.gz
  • Upload date:
  • Size: 18.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for sbibm-1.0.7.tar.gz
Algorithm Hash digest
SHA256 06ffca69b64ea4eeeb753b9b63f107793386f40fbe73090ddb8d0393de45ea88
MD5 d253717829fdfa656ffb0a6b3cbc6a25
BLAKE2b-256 55148b95ad407f0414e46f44c7d4f7968ecd47a557e541ac9096f6262f17f514


File details

Details for the file sbibm-1.0.7-py2.py3-none-any.whl.

File metadata

  • Download URL: sbibm-1.0.7-py2.py3-none-any.whl
  • Upload date:
  • Size: 18.6 MB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for sbibm-1.0.7-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 871d2bf89501c01f86cfef58423d06f52fb626d9ae2c046d5b456c13c9f8dc93
MD5 0f598399b0f694957952757c21c47890
BLAKE2b-256 9fa72f8be9f5ada471cf7f04c4d1e29f3a3d68c7dc2b92dc3f9c9da9f372fbe4

