MultiVariate Gaussian Kernel Density Estimator

mvgkde

MultiVariate Gaussian Kernel Density Estimator in JAX.

This is a micro-package containing a single class, MultiVariateGaussianKDE, and a helper function, gaussian_kde, for estimating the probability density function of a multivariate dataset with a Gaussian kernel. It modifies the jax.scipy.stats.gaussian_kde class (itself based on the scipy.stats.gaussian_kde class) to allow full control over the covariance matrix of the kernel, including per-dimension bandwidths. See the Documentation below for more information.
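As a rough illustration of what a Gaussian KDE with a full kernel covariance computes, the sketch below averages a Gaussian kernel with covariance `kernel_cov` over all samples. It uses only `jax.numpy` and `jax.scipy.linalg`. The function name `kde_with_covariance` is hypothetical and is not part of mvgkde, whose internals may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def kde_with_covariance(dataset, points, kernel_cov):
    """Evaluate a Gaussian KDE with an explicit kernel covariance.

    dataset: (d, n) samples; points: (d, m) evaluation locations;
    kernel_cov: (d, d) symmetric positive-definite matrix.
    Hypothetical helper for illustration, not part of mvgkde.
    """
    d = dataset.shape[0]
    # Factor kernel_cov = L @ L.T so the quadratic form becomes a squared norm.
    chol = jnp.linalg.cholesky(kernel_cov)
    white_data = solve_triangular(chol, dataset, lower=True)  # (d, n)
    white_points = solve_triangular(chol, points, lower=True)  # (d, m)
    # Squared Mahalanobis distance between every point and every sample: (m, n)
    diffs = white_points[:, :, None] - white_data[:, None, :]
    sq_dist = jnp.sum(diffs**2, axis=0)
    # Log of the Gaussian normalisation constant sqrt((2 pi)^d det(kernel_cov)).
    log_norm = 0.5 * d * jnp.log(2 * jnp.pi) + jnp.sum(jnp.log(jnp.diag(chol)))
    # Equal-weight average of the kernel over the samples: (m,)
    return jnp.mean(jnp.exp(-0.5 * sq_dist - log_norm), axis=1)


# A diagonal covariance smooths each dimension by a different amount;
# off-diagonal terms would tilt the kernel.
samples = jnp.array([[0.0, 1.0, -1.0], [0.0, 0.5, -0.5]])
query = jnp.zeros((2, 1))
density = kde_with_covariance(samples, query, jnp.diag(jnp.array([0.15, 1.3])))
```

Whitening with the Cholesky factor avoids forming the inverse covariance explicitly, which is one common choice for numerical stability.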

Installation

pip install mvgkde

Documentation

For these examples we will use the following imports:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

from mvgkde import MultiVariateGaussianKDE, gaussian_kde  # This package

And we will generate a dataset to work with:

key = jr.key(0)
dataset = jr.normal(key, (2, 1000))
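Note that the dataset has shape (2, 1000), that is, one row per dimension and one column per sample, which is also the layout scipy.stats.gaussian_kde expects. If your samples are stored one per row instead, a transpose gives the same layout. The sketch below draws correlated samples for illustration; the mean, covariance values, and variable names are assumptions and do not come from the package.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(1)
mean = jnp.zeros(2)
cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])  # illustrative values only

# jr.multivariate_normal returns one sample per row: shape (1000, 2).
samples = jr.multivariate_normal(key, mean, cov, shape=(1000,))

# Transpose to the (n_dims, n_samples) layout used by `dataset` above.
correlated_dataset = samples.T
```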

Lastly we will define a plotting function:

# Create a grid of points
(xmin, ymin) = dataset.min(axis=1)
(xmax, ymax) = dataset.max(axis=1)
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])


def plot_kde(kde: MultiVariateGaussianKDE) -> plt.Figure:
    # Evaluate the KDE on the grid
    Z = np.reshape(kde(positions).T, X.shape)

    # Plot the results
    fig, ax = plt.subplots()
    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
    ax.plot(dataset[0], dataset[1], "k.", markersize=2)
    ax.set(
        title="2D Kernel Density Estimation using JAX",
        xlabel="X-axis",
        xlim=[xmin, xmax],
        ylabel="Y-axis",
        ylim=[ymin, ymax],
    )

    return fig

Here's an example that can also be done with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw="scott")

fig = plot_kde(kde)
plt.show()

Scott's Rule
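For reference, here is a small sketch of what Scott's rule computes in jax.scipy.stats.gaussian_kde: a scalar factor n ** (-1 / (d + 4)) whose square scales the sample covariance of the data. Whether mvgkde follows exactly the same steps for bw="scott" is an assumption here; the variable names are illustrative.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import gaussian_kde as jax_gaussian_kde

dataset = jr.normal(jr.key(0), (2, 1000))
d, n = dataset.shape

# Scott's factor: n ** (-1 / (d + 4)).
scott_factor = n ** (-1.0 / (d + 4))

# Kernel covariance: sample covariance of the data scaled by the squared factor.
kernel_cov = jnp.cov(dataset) * scott_factor**2

# Cross-check against the reference implementation in JAX.
reference = jax_gaussian_kde(dataset, bw_method="scott")
matches = bool(jnp.allclose(kernel_cov, reference.covariance, rtol=1e-4, atol=1e-6))
```

Because the factor is a single scalar, every dimension is smoothed in proportion to the spread of the data along it.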

Here's an example with a per-dimension bandwidth. This is not possible with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw=jnp.array([0.15, 1.3]))

fig = plot_kde(kde)
plt.show()

Per-Dimension Bandwidth

Lastly, here's an example with a 2D bandwidth matrix:

bw = jnp.array([[0.15, 3], [3, 1.3]])
kde = gaussian_kde(dataset, bw=bw)

fig = plot_kde(kde)
plt.show()

2D Bandwidth Matrix

The previous examples use the convenience function gaussian_kde, which simply calls the constructor method MultiVariateGaussianKDE.from_bandwidth. This method allows customizing the bandwidth factor applied to the data-driven covariance matrix, but does not allow specifying the covariance matrix directly. To do that, call the MultiVariateGaussianKDE constructor directly or use the from_covariance constructor method. To illustrate the difference between modifying the bandwidth and setting the full covariance matrix, consider the following example:

kde = MultiVariateGaussianKDE.from_covariance(
    dataset,
    jnp.array([[0.15, 0.1], [0.1, 1.3]]),
)

fig = plot_kde(kde)
plt.show()

Covariance Matrix
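One way to picture the difference, as a sketch under a stated assumption: suppose a bandwidth bw rescales the data covariance S as B @ S @ B.T, where B is bw * I for a scalar, diag(bw) for a vector, and bw itself for a matrix. For a scalar this reduces to the rule used by jax.scipy.stats.gaussian_kde (S * factor**2), but the vector and matrix cases are an assumed generalization, not confirmed behavior of mvgkde. The helper name `kernel_cov_from_bandwidth` is hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr


def kernel_cov_from_bandwidth(dataset, bw):
    """Assumed rule: kernel covariance = B @ S @ B.T (hypothetical helper)."""
    d = dataset.shape[0]
    data_cov = jnp.atleast_2d(jnp.cov(dataset))
    bw = jnp.asarray(bw, dtype=data_cov.dtype)
    if bw.ndim == 0:
        B = bw * jnp.eye(d)  # scalar factor
    elif bw.ndim == 1:
        B = jnp.diag(bw)  # per-dimension factors
    else:
        B = bw  # full bandwidth matrix
    return B @ data_cov @ B.T


dataset = jr.normal(jr.key(0), (2, 1000))

# Bandwidth route: the result still depends on the spread of the data.
from_bw = kernel_cov_from_bandwidth(dataset, jnp.array([0.15, 1.3]))

# Covariance route: the kernel covariance is exactly the matrix passed in.
direct = jnp.array([[0.15, 0.1], [0.1, 1.3]])
```

Under this assumption, the same numbers produce different kernels depending on the route: `from_bw` changes whenever the data changes, while `direct` is fixed regardless of the dataset.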

Acknowledgments

This package modifies code from JAX, which is licensed under the Apache License 2.0.

