MultiVariate Gaussian Kernel Density Estimator

mvgkde

MultiVariate Gaussian Kernel Density Estimator in JAX.

This is a micro-package containing a single class, MultiVariateGaussianKDE, and a helper function, gaussian_kde, for estimating the probability density function of a multivariate dataset with a Gaussian kernel. It modifies the jax.scipy.stats.gaussian_kde class (itself based on scipy.stats.gaussian_kde) to give full control over the kernel's covariance matrix, including per-dimension bandwidths. See the Documentation section below for more information.
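To make the role of the kernel covariance concrete, here is a minimal sketch in plain JAX of what a multivariate Gaussian KDE evaluates: the average of Gaussian densities with a shared covariance, one centered on each sample. The function name gaussian_kde_pdf and the example values are hypothetical and are not part of mvgkde; the package's own implementation may differ.

```python
import jax.numpy as jnp
import jax.random as jr


def gaussian_kde_pdf(points, dataset, cov):
    """Hypothetical helper: average of Gaussian kernels with covariance `cov`.

    points:  (d, m) evaluation locations
    dataset: (d, n) samples
    cov:     (d, d) symmetric positive-definite kernel covariance
    """
    d = dataset.shape[0]
    # Differences between every evaluation point and every sample: (m, n, d)
    diff = points.T[:, None, :] - dataset.T[None, :, :]
    # Squared Mahalanobis distance under the kernel covariance: (m, n)
    inv_cov = jnp.linalg.inv(cov)
    maha = jnp.einsum("mni,ij,mnj->mn", diff, inv_cov, diff)
    # Normalization constant of a d-dimensional Gaussian
    norm = jnp.sqrt((2 * jnp.pi) ** d * jnp.linalg.det(cov))
    # Average the kernels over the n samples
    return jnp.exp(-0.5 * maha).mean(axis=1) / norm


# Usage: 200 samples in 2D, evaluated at three points with a hand-picked covariance
dataset = jr.normal(jr.key(0), (2, 200))
points = jnp.array([[0.0, 0.5, -1.0], [0.0, -0.5, 1.0]])
cov = jnp.array([[0.15, 0.1], [0.1, 1.3]])
density = gaussian_kde_pdf(points, dataset, cov)  # shape (3,)
```

Everything that distinguishes one Gaussian KDE from another is in the cov argument, which is why control over that matrix matters.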

Installation

pip install mvgkde

Documentation

For these examples we will use the following imports:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

from mvgkde import MultiVariateGaussianKDE, gaussian_kde  # This package

And we will generate a dataset to work with:

key = jr.key(0)
dataset = jr.normal(key, (2, 1000))

Lastly we will define a plotting function:

# Create a grid of points
(xmin, ymin) = dataset.min(axis=1)
(xmax, ymax) = dataset.max(axis=1)
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])


def plot_kde(kde: MultiVariateGaussianKDE) -> plt.Figure:
    # Evaluate the KDE on the grid
    Z = np.reshape(kde(positions).T, X.shape)

    # Plot the results
    fig, ax = plt.subplots()
    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
    ax.plot(dataset[0], dataset[1], "k.", markersize=2)
    ax.set(
        title="2D Kernel Density Estimation using JAX",
        xlabel="X-axis",
        xlim=[xmin, xmax],
        ylabel="Y-axis",
        ylim=[ymin, ymax],
    )

    return fig

Here's an example that can also be done with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw="scott")

fig = plot_kde(kde)
plt.show()

Figure: Scott's Rule
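For reference, the following sketch reproduces by hand the kernel covariance that jax.scipy.stats.gaussian_kde derives under Scott's rule: the sample covariance scaled by the square of the factor n^(-1/(d+4)). It uses only JAX itself. That mvgkde's bw="scott" follows the same convention is an assumption, based on the package being a modification of that class.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import gaussian_kde as jax_gaussian_kde

dataset = jr.normal(jr.key(0), (2, 1000))
d, n = dataset.shape

# Scott's factor for n equally weighted samples in d dimensions
scott_factor = n ** (-1.0 / (d + 4))

# Kernel covariance: sample covariance scaled by the factor squared
scott_cov = jnp.cov(dataset) * scott_factor**2

# Reference: the covariance computed by JAX's own estimator
reference = jax_gaussian_kde(dataset, bw_method="scott")
matches = jnp.allclose(scott_cov, reference.covariance, rtol=1e-3)
```

A single scalar factor rescales the whole covariance uniformly, so every dimension is smoothed in proportion to the spread of the data along it.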

Here's an example with a per-dimension bandwidth, which is not possible with jax.scipy.stats.gaussian_kde:

kde = gaussian_kde(dataset, bw=jnp.array([0.15, 1.3]))

fig = plot_kde(kde)
plt.show()

Figure: Per-Dimension Bandwidth

Lastly, here's an example with a 2D bandwidth matrix:

bw = jnp.array([[0.15, 3], [3, 1.3]])
kde = gaussian_kde(dataset, bw=bw)

fig = plot_kde(kde)
plt.show()

Figure: 2D Bandwidth Matrix

The previous examples use the convenience function gaussian_kde, which simply calls the constructor method MultiVariateGaussianKDE.from_bandwidth. That method lets you customize the bandwidth factor applied to the data-driven covariance matrix, but it does not let you specify the covariance matrix directly. To do that, call the MultiVariateGaussianKDE constructor directly or use the from_covariance constructor method. To illustrate the difference between modifying the bandwidth and setting the full covariance matrix, consider the following example:

kde = MultiVariateGaussianKDE.from_covariance(
    dataset,
    jnp.array([[0.15, 0.1], [0.1, 1.3]]),
)

fig = plot_kde(kde)
plt.show()

Figure: Covariance Matrix
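The distinction between the two routes can be sketched in plain JAX without calling mvgkde at all. The scalar convention used here (kernel covariance = h**2 times the sample covariance) is the one from jax.scipy.stats.gaussian_kde. How mvgkde combines vector or matrix bandwidths with the data covariance is not covered, and the variable names are illustrative only.

```python
import jax.numpy as jnp
import jax.random as jr

dataset = jr.normal(jr.key(0), (2, 1000))

# (a) Bandwidth route: the kernel covariance is derived from the data.
# ASSUMPTION: a scalar factor h scales the sample covariance by h**2.
h = 0.5
cov_from_bandwidth = h**2 * jnp.cov(dataset)

# (b) Covariance route: the kernel covariance is supplied outright,
# so it does not depend on the data at all.
cov_direct = jnp.array([[0.15, 0.1], [0.1, 1.3]])

# A directly supplied covariance must be symmetric positive definite.
is_valid = bool(jnp.all(jnp.linalg.eigvalsh(cov_direct) > 0))

# Rescaling the data by 2 changes (a) by a factor of 4, while (b) stays fixed.
cov_bw_rescaled = h**2 * jnp.cov(2.0 * dataset)
```

In other words, a bandwidth adapts the kernel to the spread of the data, whereas a fixed covariance keeps the same kernel shape whatever the data look like.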

Acknowledgments

This package modifies code from JAX, which is licensed under the Apache License 2.0.
