Python wrapper for nuts-rs -- a NUTS sampler written in Rust.

nutpie: A fast sampler for Bayesian posteriors

Installation

nutpie can be installed using conda or mamba from conda-forge with

mamba install -c conda-forge nutpie pymc

To install it from source, install a Rust compiler (e.g. using rustup) and run

maturin develop --release

If you want to use the nightly SIMD implementation for some of the math functions, switch to Rust nightly and then install with the simd_support feature from the nutpie directory:

rustup override set nightly
maturin develop --release --features=simd_support
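
Whichever route you use, it can help to confirm that the package is visible to the Python interpreter you intend to sample from. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and not part of nutpie.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package="nutpie"):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        # The package is not installed in the active environment.
        return None


if __name__ == "__main__":
    found = installed_version()
    if found is None:
        print("nutpie is not installed in this environment")
    else:
        print(f"nutpie {found} is installed")
```

If this reports that nutpie is missing after a maturin build, the build was probably run in a different environment from the one running the check.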

Usage

First, we need to create a model, for example using pymc:

import pymc as pm
import numpy as np
import nutpie
import pandas as pd
import seaborn as sns

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as pymc_model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

We then compile this model and sample from the posterior:

compiled_model = nutpie.compile_pymc_model(pymc_model)
trace_pymc = nutpie.sample(compiled_model, chains=10)

trace_pymc now contains an arviz InferenceData object, including sampling statistics and the posterior of the variables defined above.

For more details, see the example notebook pytensor_logp.

nutpie can also sample from stan models. However, it currently needs a patched version of httpstan to do so. The required version can be found here. Make sure to follow the development installation instructions for httpstan.

Advantages

nutpie uses nuts-rs, a library written in Rust that implements NUTS as in pymc and stan, but with a slightly different mass matrix tuning method. It often produces a higher effective sample size per gradient evaluation, and tends to converge faster and with fewer gradient evaluations.
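
To make "effective sample size per gradient evaluation" concrete, here is a small sketch of how such a figure could be computed. It assumes one gradient evaluation per leapfrog step and that per-draw step counts are available from the sampler statistics; the function name `ess_per_gradient` and the numbers in the demo are made up for illustration and are not taken from nutpie or from any benchmark.

```python
from typing import Sequence


def ess_per_gradient(ess: float, steps_per_draw: Sequence[int]) -> float:
    """Effective sample size divided by the total number of gradient evaluations.

    Assumes one gradient evaluation per leapfrog step, so the cost of a run
    is the sum of the per-draw step counts.
    """
    total_gradients = sum(steps_per_draw)
    if total_gradients <= 0:
        raise ValueError("need at least one gradient evaluation")
    return ess / total_gradients


if __name__ == "__main__":
    # Made-up numbers for two hypothetical samplers run on the same model.
    sampler_a = ess_per_gradient(ess=800.0, steps_per_draw=[7] * 1000)
    sampler_b = ess_per_gradient(ess=900.0, steps_per_draw=[15] * 1000)
    print(f"A: {sampler_a:.4f}  B: {sampler_b:.4f}")
```

Whether the step counts should include tuning draws is a choice; including them also captures how quickly a sampler converges.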

In the benchmarks I ran, it appears to be the fastest CPU-based sampler I could find, outperforming cmdstan and numpyro.

Unfortunately, performance on pymc models is currently somewhat limited by an issue in numba, which will hopefully be fixed soon. Without the patch mentioned in the issue, the model above samples in about 2s on my machine; with the patch it finishes in about 700ms.
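
To get comparable wall-clock numbers on your own machine, a simple timing wrapper is enough. This is only a sketch: the helper `timed` is hypothetical, and the demo times a stand-in workload so that the snippet runs without a model.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    # Stand-in workload; with the model above one would instead call
    # timed(nutpie.sample, compiled_model, chains=10).
    _, seconds = timed(sum, range(1_000_000))
    print(f"took {seconds * 1000:.1f} ms")
```

Timing nutpie.sample alone leaves out the time spent in nutpie.compile_pymc_model. The figures above do not say whether compilation is included, so measure both if you want a full picture.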
