MultiVariate Gaussian Kernel Density Estimator

mvgkde

MultiVariate Gaussian Kernel Density Estimator in JAX.

This is a micro-package containing a single class, MultiVariateGaussianKDE (and the helper function gaussian_kde), for estimating the probability density function of a multivariate dataset with a Gaussian kernel. It modifies the jax.scipy.stats.gaussian_kde class (itself based on scipy.stats.gaussian_kde) but gives full control over the kernel's covariance matrix, including per-dimension bandwidths. See the Documentation below for more information.
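To make the underlying model concrete, here is a minimal from-scratch sketch in JAX (not this package's code) of a Gaussian KDE with an explicit kernel covariance H: the density at x is the average of the n Gaussians N(x; x_i, H) centred on the samples. The function name gaussian_kde_pdf is hypothetical. As a sanity check, the sketch reuses the kernel covariance that jax.scipy.stats.gaussian_kde derives from a scalar bandwidth, so the two densities should agree.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import gaussian_kde as jax_gaussian_kde  # used only as a reference


def gaussian_kde_pdf(points: jax.Array, dataset: jax.Array, cov: jax.Array) -> jax.Array:
    """Hypothetical helper: equally weighted Gaussian KDE with kernel covariance `cov`.

    points:  (d, m) evaluation points, one column per point
    dataset: (d, n) samples, one column per sample
    cov:     (d, d) symmetric positive-definite kernel covariance
    """
    d = dataset.shape[0]
    chol = jnp.linalg.cholesky(cov)  # cov = chol @ chol.T
    # Differences between every evaluation point and every sample: (d, m, n)
    diff = points[:, :, None] - dataset[:, None, :]
    # Whiten the differences by solving chol @ z = diff
    z = solve_triangular(chol, diff.reshape(d, -1), lower=True)
    sq_mahalanobis = jnp.sum(z**2, axis=0).reshape(diff.shape[1:])  # (m, n)
    # Log of the Gaussian normalizing constant: -(d/2) log(2 pi) - (1/2) log det(cov)
    log_norm = -0.5 * d * jnp.log(2 * jnp.pi) - jnp.sum(jnp.log(jnp.diag(chol)))
    # Average the n kernel densities at each evaluation point
    return jnp.mean(jnp.exp(log_norm - 0.5 * sq_mahalanobis), axis=1)


data = jax.random.normal(jax.random.key(0), (2, 300))
pts = data[:, :5]
ref = jax_gaussian_kde(data, bw_method=0.4)  # kernel covariance = 0.4**2 * data covariance
ours = gaussian_kde_pdf(pts, data, ref.covariance)
```

Because cov is an explicit argument, any symmetric positive-definite matrix can be passed, which is the kind of control this package exposes and a scalar bw_method cannot express.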

Installation

pip install mvgkde

Documentation

For these examples we will use the following imports:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

from mvgkde import MultiVariateGaussianKDE, gaussian_kde  # This package

And we will generate a dataset to work with:

key = jr.key(0)
dataset = jr.normal(key, (2, 1000))

Lastly, we will define a plotting function:

# Create a grid of points
(xmin, ymin) = dataset.min(axis=1)
(xmax, ymax) = dataset.max(axis=1)
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])


def plot_kde(kde: MultiVariateGaussianKDE) -> plt.Figure:
    # Evaluate the KDE on the grid
    Z = np.reshape(kde(positions).T, X.shape)

    # Plot the results
    fig, ax = plt.subplots()
    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
    ax.plot(dataset[0], dataset[1], "k.", markersize=2)
    ax.set(
        title="2D Kernel Density Estimation using JAX",
        xlabel="X-axis",
        xlim=[xmin, xmax],
        ylabel="Y-axis",
        ylim=[ymin, ymax],
    )

    return fig

Here's an example, using Scott's rule for the bandwidth, that is also possible with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw_method="scott")

fig = plot_kde(kde)
plt.show()

Scott's Rule
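For reference, scipy.stats.gaussian_kde and jax.scipy.stats.gaussian_kde define Scott's rule as the bandwidth factor n^(-1/(d+4)) and use factor² times the data covariance as the kernel covariance. The sketch below (scott_kernel_covariance is a hypothetical helper) reproduces that and checks it against JAX's class. That gaussian_kde(..., bw_method="scott") in this package follows the same convention is an assumption, based on it being derived from JAX's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde as jax_gaussian_kde


def scott_kernel_covariance(dataset: jax.Array) -> jax.Array:
    """Hypothetical helper: kernel covariance implied by Scott's rule for a (d, n) dataset."""
    d, n = dataset.shape
    factor = n ** (-1.0 / (d + 4))  # Scott's bandwidth factor
    data_cov = jnp.cov(dataset)  # rows are variables, unbiased (n - 1) normalization
    return factor**2 * data_cov


data = jax.random.normal(jax.random.key(0), (2, 1000))
ours = scott_kernel_covariance(data)
ref = jax_gaussian_kde(data, bw_method="scott").covariance
```

For a 1000-point, 2-D dataset like the one above, the factor is 1000^(-1/6) ≈ 0.316, so the kernel covariance is one tenth of the data covariance.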

Here's an example with a per-dimension bandwidth, which is not possible with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw_method=jnp.array([0.15, 1.3]))

fig = plot_kde(kde)
plt.show()

Per-Dimension Bandwidth
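How a per-dimension bandwidth is turned into a kernel covariance is not spelled out here. One natural generalization of the scalar rule (kernel covariance = factor² × data covariance) is diag(b) Σ diag(b), which scales each dimension's spread by its own factor and reduces to the scalar rule when all factors are equal. The sketch below assumes that convention; per_dimension_kernel_covariance is a hypothetical helper, and mvgkde's exact convention should be checked in its source.

```python
import jax
import jax.numpy as jnp


def per_dimension_kernel_covariance(dataset: jax.Array, bw: jax.Array) -> jax.Array:
    """HYPOTHETICAL helper: apply one bandwidth factor per dimension.

    Assumes kernel_cov = diag(bw) @ data_cov @ diag(bw), which reduces to the
    scalar rule bw**2 * data_cov when every entry of bw is equal. This is an
    illustrative convention, not necessarily what mvgkde does internally.
    """
    data_cov = jnp.cov(dataset)  # (d, d) data-driven covariance
    scale = jnp.diag(bw)  # (d, d) diagonal matrix of per-dimension factors
    return scale @ data_cov @ scale


data = jax.random.normal(jax.random.key(0), (2, 1000))
cov = per_dimension_kernel_covariance(data, jnp.array([0.15, 1.3]))
```

Under this assumption, with roughly unit-variance data as above, bw = [0.15, 1.3] gives kernels with a standard deviation of about 0.15 along x and about 1.3 along y.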

Here's an example with a full 2D bandwidth matrix:

bw = jnp.array([[0.15, 3], [3, 1.3]])
kde = gaussian_kde(dataset, bw_method=bw)

fig = plot_kde(kde)
plt.show()

2D Bandwidth Matrix

The previous examples use the convenience function gaussian_kde, which simply calls the constructor method MultiVariateGaussianKDE.from_bandwidth. This method lets you customize the bandwidth factor applied to the data-driven covariance matrix, but it does not let you specify the covariance matrix directly. To do that, call the MultiVariateGaussianKDE constructor directly or use the from_covariance constructor method. The following example illustrates the difference between modifying the bandwidth and setting the full covariance matrix:

kde = MultiVariateGaussianKDE.from_covariance(
    dataset,
    jnp.array([[0.15, 0.1], [0.1, 1.3]]),
)

fig = plot_kde(kde)
plt.show()

Covariance Matrix
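Since from_covariance takes the kernel covariance itself, the matrix must be a valid Gaussian covariance, i.e. symmetric positive definite. A quick check (a sketch; is_valid_covariance is a hypothetical helper, not part of mvgkde) shows that the matrix in the example above qualifies, whereas the 2D bandwidth matrix used earlier, [[0.15, 3], [3, 1.3]], would not, because its determinant is negative. This is one concrete way the two inputs differ: a bandwidth rescales the data-driven covariance rather than being a covariance itself.

```python
import jax.numpy as jnp


def is_valid_covariance(matrix) -> bool:
    """Hypothetical helper: True if `matrix` is square, symmetric, with positive eigenvalues."""
    m = jnp.asarray(matrix, dtype=jnp.float32)
    if m.ndim != 2 or m.shape[0] != m.shape[1]:
        return False
    symmetric = bool(jnp.allclose(m, m.T))
    # eigvalsh assumes a symmetric input, so only consult it after the symmetry check
    return symmetric and bool(jnp.all(jnp.linalg.eigvalsh(m) > 0))


print(is_valid_covariance([[0.15, 0.1], [0.1, 1.3]]))  # True: det = 0.195 - 0.01 > 0
print(is_valid_covariance([[0.15, 3.0], [3.0, 1.3]]))  # False: det = 0.195 - 9 < 0
```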

Acknowledgments

This package modifies code from JAX, which is licensed under the Apache License 2.0.

Download files

Source Distribution

mvgkde-0.1.2.tar.gz (334.7 kB)

Built Distribution

mvgkde-0.1.2-py3-none-any.whl (16.7 kB)

File details

Details for the file mvgkde-0.1.2.tar.gz.

File metadata

  • Download URL: mvgkde-0.1.2.tar.gz
  • Size: 334.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mvgkde-0.1.2.tar.gz
Algorithm Hash digest
SHA256 5852574bcf7f0cb331931b73ee085f21f63166aa78a1fe907a4bf5b00e1164bf
MD5 1c96e2139e36ea13a8f556323f18cf7f
BLAKE2b-256 e8ed064f687805821bf7c20dcabeeeb3b2f2dd97d6e1c4435592f270d040f9f4

File details

Details for the file mvgkde-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: mvgkde-0.1.2-py3-none-any.whl
  • Size: 16.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mvgkde-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 7826c7be59bb14316d12a7bfc6d8b6d2efe3da4253fa932e150835a4d257df54
MD5 98f882b24deba1fdb3dd4e3ee57bc380
BLAKE2b-256 befc7ffc21a352fbda6128da19f015834fbfe8d7fbfd56a150151327f9e4b017
