MultiVariate Gaussian Kernel Density Estimator
mvgkde
MultiVariate Gaussian Kernel Density Estimator in JAX.
This is a micro-package containing a single class, MultiVariateGaussianKDE (and
the helper function gaussian_kde), to estimate the probability density function
of a multivariate dataset using a Gaussian kernel. This package modifies the
jax.scipy.stats.gaussian_kde class (which is based on the scipy.stats.gaussian_kde
class) but allows full control over the covariance matrix of the kernel,
including per-dimension bandwidths. See the Documentation below for more
information.
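The underlying estimator is simple to state: with n samples x_i and a kernel covariance matrix H, the density estimate is f(x) = (1/n) Σ_i N(x; x_i, H). The sketch below evaluates exactly that for an arbitrary symmetric positive-definite H in plain JAX. It is meant to illustrate the math only; kde_with_covariance is a hypothetical helper, not mvgkde's API or its internal implementation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import logsumexp


def kde_with_covariance(dataset, points, cov):
    """Hypothetical helper: Gaussian KDE whose kernel has covariance `cov`.

    dataset: (d, n) samples, points: (d, m) evaluation points,
    cov: (d, d) symmetric positive-definite kernel covariance.
    """
    dataset = jnp.atleast_2d(dataset)
    points = jnp.atleast_2d(points)
    d, n = dataset.shape
    m = points.shape[1]
    chol = jnp.linalg.cholesky(jnp.atleast_2d(cov))
    # Differences between every evaluation point and every sample: (d, m * n)
    diff = (points[:, :, None] - dataset[:, None, :]).reshape(d, m * n)
    # Whitening: z = L^{-1} (x - x_i), so |z|^2 is the squared Mahalanobis distance
    z = solve_triangular(chol, diff, lower=True)
    maha = jnp.sum(z**2, axis=0).reshape(m, n)
    # Log of the Gaussian normalisation (2*pi)^(d/2) * |cov|^(1/2)
    log_norm = 0.5 * d * jnp.log(2 * jnp.pi) + jnp.sum(jnp.log(jnp.diag(chol)))
    # Average the n kernels, working in log space for numerical stability
    log_density = logsumexp(-0.5 * maha - log_norm, axis=1) - jnp.log(n)
    return jnp.exp(log_density)


dataset = jnp.array([[0.0, 1.0, 2.0], [0.0, 1.0, 0.5]])
cov = jnp.array([[0.15, 0.1], [0.1, 1.3]])
print(kde_with_covariance(dataset, jnp.zeros((2, 1)), cov))
```

Working in log space with a Cholesky factor avoids forming an explicit inverse and keeps far-away kernels from underflowing before they are summed.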
Installation
pip install mvgkde
Documentation
For these examples we will use the following imports:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
from mvgkde import MultiVariateGaussianKDE, gaussian_kde # This package
And we will generate a dataset to work with:
key = jr.key(0)
dataset = jr.normal(key, (2, 1000))
Lastly, we will define an evaluation grid and a plotting function:
# Create a grid of points
(xmin, ymin) = dataset.min(axis=1)
(xmax, ymax) = dataset.max(axis=1)
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])
def plot_kde(kde: MultiVariateGaussianKDE) -> plt.Figure:
    # Evaluate the KDE on the grid
    Z = np.reshape(kde(positions).T, X.shape)
    # Plot the results
    fig, ax = plt.subplots()
    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
    ax.plot(dataset[0], dataset[1], "k.", markersize=2)
    ax.set(
        title="2D Kernel Density Estimation using JAX",
        xlabel="X-axis",
        xlim=[xmin, xmax],
        ylabel="Y-axis",
        ylim=[ymin, ymax],
    )
    return fig
Here's an example that can also be done with jax.scipy.stats.gaussian_kde:
kde = gaussian_kde(dataset, bw_method="scott")
fig = plot_kde(kde)
plt.show()
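For reference, bw_method="scott" in scipy.stats.gaussian_kde (and its jax.scipy port) applies Scott's rule, factor = n^(-1/(d+4)) for n samples in d dimensions, and sets the kernel covariance to the data covariance multiplied by factor². A sketch of that computation follows, assuming unweighted samples (with weights, scipy uses the effective sample size in place of n) and assuming mvgkde keeps this behavior from the class it modifies; scott_kernel_covariance is a hypothetical helper.

```python
import jax.numpy as jnp
import jax.random as jr


def scott_kernel_covariance(dataset):
    """Hypothetical helper: Scott's-rule factor and kernel covariance (unweighted)."""
    d, n = dataset.shape
    factor = n ** (-1.0 / (d + 4))
    # Unbiased sample covariance; rows of `dataset` are dimensions
    data_cov = jnp.atleast_2d(jnp.cov(dataset))
    return factor, data_cov * factor**2


dataset = jr.normal(jr.PRNGKey(0), (2, 1000))
factor, kernel_cov = scott_kernel_covariance(dataset)
print(factor)  # 1000 ** (-1/6), about 0.316
```

A single scalar factor shrinks every direction of the data covariance equally, which is why the scalar rules cannot express different smoothing per dimension.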
Here's an example with a per-dimension bandwidth, which is not possible with
jax.scipy.stats.gaussian_kde:
kde = gaussian_kde(dataset, bw_method=jnp.array([0.15, 1.3]))
fig = plot_kde(kde)
plt.show()
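This README does not spell out exactly how a vector passed as bw_method (such as jnp.array([0.15, 1.3]) above) is combined with the data covariance Σ. One natural reading, which is an assumption here and not taken from the mvgkde source, scales dimension i by its own factor h_i, giving H = D Σ D with D = diag(h). Elementwise this is H_ij = h_i h_j Σ_ij, so a vector with all entries equal to h recovers the usual scalar h² Σ. The helper name is hypothetical.

```python
import jax.numpy as jnp


def per_dimension_kernel_covariance(dataset, bw):
    """Hypothetical helper: kernel covariance diag(bw) @ cov(data) @ diag(bw).

    ASSUMPTION: one plausible meaning of a per-dimension bandwidth,
    not necessarily the formula mvgkde uses.
    """
    data_cov = jnp.atleast_2d(jnp.cov(dataset))
    scale = jnp.diag(jnp.asarray(bw, dtype=data_cov.dtype))
    return scale @ data_cov @ scale


dataset = jnp.array([[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 2.0, 1.5]])
bw = jnp.array([0.15, 1.3])
H = per_dimension_kernel_covariance(dataset, bw)
# Same thing written elementwise: H_ij = bw_i * bw_j * cov_ij
assert jnp.allclose(H, jnp.cov(dataset) * jnp.outer(bw, bw))
```

Because D is diagonal with positive entries, D Σ D stays symmetric positive definite whenever Σ is, so this form never produces an invalid kernel.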
Lastly, here's an example with a 2D bandwidth matrix:
bw = jnp.array([[0.15, 3], [3, 1.3]])
kde = gaussian_kde(dataset, bw_method=bw)
fig = plot_kde(kde)
plt.show()
The previous examples use the convenience function gaussian_kde, which simply
calls the constructor method MultiVariateGaussianKDE.from_bandwidth. This method
allows customizing the bandwidth factor applied to the data-driven covariance
matrix, but does not allow specifying the covariance matrix directly. To do that,
call the MultiVariateGaussianKDE constructor directly, or use the from_covariance
constructor method. To illustrate the difference between modifying the bandwidth
and setting the full covariance matrix, consider the following example:
kde = MultiVariateGaussianKDE.from_covariance(
dataset,
jnp.array([[0.15, 0.1], [0.1, 1.3]]),
)
fig = plot_kde(kde)
plt.show()
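Since from_covariance uses the supplied matrix as the kernel covariance rather than deriving it from the data, that matrix has to be a valid Gaussian covariance: symmetric with strictly positive eigenvalues. A small pre-flight check, as a sketch; is_valid_kernel_covariance is a hypothetical helper, not part of mvgkde.

```python
import jax.numpy as jnp


def is_valid_kernel_covariance(cov, tol=0.0):
    """Hypothetical helper: True if `cov` is symmetric positive definite."""
    cov = jnp.asarray(cov)
    symmetric = bool(jnp.allclose(cov, cov.T))
    # eigvalsh assumes a symmetric input, so only call it after the symmetry check
    return symmetric and bool(jnp.all(jnp.linalg.eigvalsh(cov) > tol))


print(is_valid_kernel_covariance([[0.15, 0.1], [0.1, 1.3]]))  # True
print(is_valid_kernel_covariance([[1.0, 2.0], [2.0, 1.0]]))  # False: eigenvalues 3 and -1
```

The covariance used in the example above passes: its trace (1.45) and determinant (0.15 × 1.3 − 0.1² = 0.185) are both positive, so both eigenvalues are positive.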
Acknowledgments
This package modifies code from JAX, which is licensed under the Apache License 2.0.
File details
Details for the file mvgkde-0.2.0.tar.gz.
File metadata
- Download URL: mvgkde-0.2.0.tar.gz
- Size: 336.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3cf6400661cd7be6d175d29c5760721aa89edbfd27769ea69c0f08bb64ab83b1
MD5 | 75b6b12c66edc2089a8bb70fc07fb10b
BLAKE2b-256 | 33f54d3bc66ef6b5358a6693d052cc9a440a3ca064cfae8a044acf9af63a1632
Provenance
The following attestation bundles were made for mvgkde-0.2.0.tar.gz:
Publisher: cd.yml on nstarman/mvgkde
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: mvgkde-0.2.0.tar.gz
- Subject digest: 3cf6400661cd7be6d175d29c5760721aa89edbfd27769ea69c0f08bb64ab83b1
- Sigstore transparency entry: 148435919
File details
Details for the file mvgkde-0.2.0-py3-none-any.whl.
File metadata
- Download URL: mvgkde-0.2.0-py3-none-any.whl
- Size: 17.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | cc16d22785d61d0654d322d6132436a507cd90c79379178c78a683efad231dc1
MD5 | 2319f6c4cd1e2de12bfa1801847e90f5
BLAKE2b-256 | d97f305dc18b35ed5f0e1a13acfb30efe9b9e997d21371778ce8e946430b04e1
Provenance
The following attestation bundles were made for mvgkde-0.2.0-py3-none-any.whl:
Publisher: cd.yml on nstarman/mvgkde
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: mvgkde-0.2.0-py3-none-any.whl
- Subject digest: cc16d22785d61d0654d322d6132436a507cd90c79379178c78a683efad231dc1
- Sigstore transparency entry: 148435920