
MultiVariate Gaussian Kernel Density Estimator


mvgkde

MultiVariate Gaussian Kernel Density Estimator in JAX.

This is a micro-package containing a single class, MultiVariateGaussianKDE (plus the helper function gaussian_kde), which estimates the probability density function of a multivariate dataset using a Gaussian kernel. It modifies the jax.scipy.stats.gaussian_kde class (itself based on scipy.stats.gaussian_kde) to allow full control over the covariance matrix of the kernel, including per-dimension bandwidths. See the Documentation below for more information.
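For reference, a Gaussian KDE with kernel covariance H estimates the density at x as f(x) = (1/n) Σᵢ N(x; xᵢ, H), where the xᵢ are the n samples. The NumPy sketch below evaluates that formula for an arbitrary full covariance H. It only illustrates the math and is not the package's implementation; the function name gaussian_kde_pdf is hypothetical. Data use the same (d, n) layout as the examples in the Documentation.

```python
import numpy as np


def gaussian_kde_pdf(points, dataset, kernel_cov):
    """Evaluate a Gaussian KDE with an arbitrary (full) kernel covariance.

    points:     (d, m) array of evaluation points, one column per point
    dataset:    (d, n) array of samples, one column per sample
    kernel_cov: (d, d) symmetric positive-definite kernel covariance H
    """
    points = np.asarray(points, dtype=float)
    dataset = np.asarray(dataset, dtype=float)
    d = dataset.shape[0]
    # Cholesky factor L with L @ L.T == H (raises LinAlgError if H is not positive definite)
    L = np.linalg.cholesky(kernel_cov)
    # Pairwise differences between evaluation points and samples, shape (d, m, n)
    diff = points[:, :, None] - dataset[:, None, :]
    # Whiten: solving L z = diff makes sum(z**2) the squared Mahalanobis distance
    z = np.linalg.solve(L, diff.reshape(d, -1)).reshape(diff.shape)
    maha2 = np.sum(z**2, axis=0)  # shape (m, n)
    # Log normalization of a d-dim Gaussian: -(d/2) log(2 pi) - (1/2) log|H|
    log_norm = -0.5 * d * np.log(2 * np.pi) - np.sum(np.log(np.diag(L)))
    # Average the n kernel contributions at each evaluation point
    return np.exp(log_norm - 0.5 * maha2).mean(axis=1)


# Usage: a single sample at the origin with identity covariance
print(gaussian_kde_pdf(np.zeros((2, 1)), np.zeros((2, 1)), np.eye(2)))  # ≈ [0.159]
```

Working through the Cholesky factor keeps the evaluation stable and fails loudly when a supplied covariance is not positive definite.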

Installation


pip install mvgkde

Documentation


For these examples we will use the following imports:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

from mvgkde import MultiVariateGaussianKDE, gaussian_kde  # This package

And we will generate a dataset to work with:

key = jr.key(0)
dataset = jr.normal(key, (2, 1000))

Lastly, we will define a plotting function that evaluates a KDE on a grid and overlays the data points:

# Create a grid of points
(xmin, ymin) = dataset.min(axis=1)
(xmax, ymax) = dataset.max(axis=1)
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])


def plot_kde(kde: MultiVariateGaussianKDE) -> plt.Figure:
    # Evaluate the KDE on the grid
    Z = np.reshape(kde(positions).T, X.shape)

    # Plot the results
    fig, ax = plt.subplots()
    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
    ax.plot(dataset[0], dataset[1], "k.", markersize=2)
    ax.set(
        title="2D Kernel Density Estimation using JAX",
        xlabel="X-axis",
        xlim=[xmin, xmax],
        ylabel="Y-axis",
        ylim=[ymin, ymax],
    )

    return fig

Here's an example that can be done with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw_method="scott")

fig = plot_kde(kde)
plt.show()

[Figure: Scott's rule]
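Scott's rule, as implemented in scipy.stats.gaussian_kde (on which the JAX class is based), picks a scalar bandwidth factor n^(-1/(d+4)) and uses the data covariance multiplied by factor**2 as the kernel covariance. A NumPy sketch of that computation for the 2D, 1000-sample setup above (NumPy's generator is used, so the sample differs from the JAX one):

```python
import numpy as np


def scott_factor(n_samples, n_dims):
    # Scott's rule of thumb: n ** (-1 / (d + 4))
    return n_samples ** (-1.0 / (n_dims + 4))


rng = np.random.default_rng(0)
dataset = rng.normal(size=(2, 1000))  # (d, n) layout, as in the examples
d, n = dataset.shape

factor = scott_factor(n, d)  # ≈ 0.316 for n = 1000, d = 2
data_cov = np.cov(dataset)  # rows are variables, columns are samples
kernel_cov = data_cov * factor**2  # a scalar bandwidth scales every entry equally
```

Because the factor is a single scalar, every direction of the kernel is shrunk by the same amount relative to the data spread.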

Here's an example with a per-dimension bandwidth, which is not possible with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw_method=jnp.array([0.15, 1.3]))

fig = plot_kde(kde)
plt.show()

[Figure: Per-dimension bandwidth]
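How mvgkde maps a per-dimension bandwidth vector onto the kernel covariance is not spelled out here. One natural reading, which is only an assumption in this sketch, is that each factor bᵢ scales the kernel's standard deviation along axis i, giving H = D Σ D with D = diag(b) and Σ the data covariance. Equal factors h then recover the scalar result Σ h². The helper name per_dimension_kernel_cov is hypothetical.

```python
import numpy as np


def per_dimension_kernel_cov(data_cov, factors):
    # ASSUMPTION (hypothetical helper): factor b_i scales the kernel standard
    # deviation along axis i, so H = D @ data_cov @ D with D = diag(b).
    # Equal factors h recover the scalar-bandwidth result data_cov * h**2.
    D = np.diag(np.asarray(factors, dtype=float))
    return D @ np.asarray(data_cov, dtype=float) @ D


data_cov = np.array([[1.0, 0.2], [0.2, 2.0]])  # illustrative data covariance
H = per_dimension_kernel_cov(data_cov, [0.15, 1.3])
# Kernel standard deviation along each axis: b_i * sqrt(data_cov[i, i])
print(np.sqrt(np.diag(H)))
```

Under this reading, the factors from the example make the kernel narrow along the first axis and wide along the second, while D Σ D stays symmetric positive definite whenever Σ is and all bᵢ are nonzero.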

Lastly, here's an example with a 2D bandwidth matrix:

bw = jnp.array([[0.15, 3], [3, 1.3]])
kde = gaussian_kde(dataset, bw_method=bw)

fig = plot_kde(kde)
plt.show()

[Figure: 2D bandwidth matrix]
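A full bandwidth matrix can be read in the same spirit under a further assumption, not confirmed as what mvgkde does: multiply each entry of the data covariance by its own factor (an elementwise, or Hadamard, product). With factors = outer(b, b) this reduces to the per-dimension case D Σ D. The elementwise product of a covariance with an arbitrary symmetric matrix need not be positive definite, so the sketch also checks the result with a Cholesky factorization. Both helper names are hypothetical.

```python
import numpy as np


def matrix_bandwidth_kernel_cov(data_cov, factors):
    # ASSUMPTION (hypothetical helper): scale each covariance entry by its own
    # factor (Hadamard product). factors = outer(b, b) gives D @ data_cov @ D.
    return np.asarray(data_cov, dtype=float) * np.asarray(factors, dtype=float)


def is_valid_kernel_cov(H):
    # A Gaussian kernel covariance must be symmetric positive definite
    H = np.asarray(H, dtype=float)
    if not np.allclose(H, H.T):
        return False
    try:
        np.linalg.cholesky(H)
    except np.linalg.LinAlgError:
        return False
    return True


bw = np.array([[0.15, 3.0], [3.0, 1.3]])  # bandwidth matrix from the example
data_cov = np.array([[1.0, 0.02], [0.02, 1.0]])  # illustrative, nearly isotropic
H = matrix_bandwidth_kernel_cov(data_cov, bw)
print(is_valid_kernel_cov(H))  # True
print(is_valid_kernel_cov(bw))  # False
```

The bandwidth matrix from the example is not itself positive definite (its determinant is 0.15 × 1.3 − 3² < 0), so it could not serve as a kernel covariance on its own; in this sketch it only makes sense as a set of multipliers.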

The previous examples use the convenience function gaussian_kde, which simply calls the alternative constructor MultiVariateGaussianKDE.from_bandwidth. That method lets you customize the bandwidth factor applied to the data-driven covariance matrix, but does not let you specify the covariance matrix directly. To do that, call the MultiVariateGaussianKDE constructor directly or use the from_covariance alternative constructor. To illustrate the difference between modifying the bandwidth and setting the full covariance matrix, consider the following example:

kde = MultiVariateGaussianKDE.from_covariance(
    dataset,
    jnp.array([[0.15, 0.1], [0.1, 1.3]]),
)

fig = plot_kde(kde)
plt.show()

[Figure: Covariance matrix]
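The practical difference between the two routes: with a bandwidth, the kernel covariance is derived from the data covariance and so moves with it; with from_covariance, the matrix you pass is used directly. The NumPy sketch below illustrates this by rescaling the data by 10, which scales a Scott's-rule kernel covariance by 100 but has no effect on a fixed covariance. It models both routes with plain arrays and does not call mvgkde.

```python
import numpy as np

rng = np.random.default_rng(0)
dataset = rng.normal(size=(2, 1000))
d, n = dataset.shape
factor = n ** (-1.0 / (d + 4))  # Scott's rule, as in the earlier examples


def bandwidth_kernel_cov(data, factor):
    # Bandwidth route: kernel covariance is derived from the data covariance
    return np.cov(data) * factor**2


# Covariance route: this matrix is used as given, independent of the data
fixed_cov = np.array([[0.15, 0.1], [0.1, 1.3]])

# Rescaling the data by 10 scales the bandwidth-derived covariance by 10**2 ...
ratio = bandwidth_kernel_cov(10 * dataset, factor) / bandwidth_kernel_cov(dataset, factor)
print(ratio)  # every entry ≈ 100
# ... whereas fixed_cov stays exactly as written for any dataset.
```

A directly specified covariance therefore needs to be chosen in the data's own units, which is what makes it useful when the data-driven shape is not the one you want.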

Acknowledgments

This package modifies code from JAX, which is licensed under the Apache License 2.0.



Download files


Source Distribution

mvgkde-0.2.0.tar.gz (336.0 kB)

Built Distribution

mvgkde-0.2.0-py3-none-any.whl (17.4 kB)

File details

Details for the file mvgkde-0.2.0.tar.gz.

File metadata

  • Download URL: mvgkde-0.2.0.tar.gz
  • Size: 336.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mvgkde-0.2.0.tar.gz
  • SHA256: 3cf6400661cd7be6d175d29c5760721aa89edbfd27769ea69c0f08bb64ab83b1
  • MD5: 75b6b12c66edc2089a8bb70fc07fb10b
  • BLAKE2b-256: 33f54d3bc66ef6b5358a6693d052cc9a440a3ca064cfae8a044acf9af63a1632
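These digests can be checked against a downloaded file. A minimal standard-library sketch using hashlib; the file location in the usage line is an assumption about where the archive was saved:

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in chunks so large archives need not fit in memory
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED_SDIST_SHA256 = "3cf6400661cd7be6d175d29c5760721aa89edbfd27769ea69c0f08bb64ab83b1"

# Usage, assuming the sdist was saved to the current directory:
# assert sha256_of("mvgkde-0.2.0.tar.gz") == EXPECTED_SDIST_SHA256
```

pip can also enforce digests at install time through its hash-checking mode (--require-hashes).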


Provenance

The following attestation bundles were made for mvgkde-0.2.0.tar.gz:

Publisher: cd.yml on nstarman/mvgkde


File details

Details for the file mvgkde-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: mvgkde-0.2.0-py3-none-any.whl
  • Size: 17.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mvgkde-0.2.0-py3-none-any.whl
  • SHA256: cc16d22785d61d0654d322d6132436a507cd90c79379178c78a683efad231dc1
  • MD5: 2319f6c4cd1e2de12bfa1801847e90f5
  • BLAKE2b-256: d97f305dc18b35ed5f0e1a13acfb30efe9b9e997d21371778ce8e946430b04e1


Provenance

The following attestation bundles were made for mvgkde-0.2.0-py3-none-any.whl:

Publisher: cd.yml on nstarman/mvgkde

