JAX bindings for the Flatiron Institute Nonuniform Fast Fourier Transform library
JAX bindings to FINUFFT
This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, and it also implements differentiation rules for the transforms.
Included features
This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.
Type 1 and 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward, reverse, and higher-order differentiation, as well as batching using vmap.
Installation
For now, only a source build is supported.
For building, you should only need a recent version of Python (>3.6) and FFTW. GPU-enabled builds also require a working CUDA compiler (i.e., the CUDA Toolkit), CUDA >= 11.8, and a compatible cuDNN (older versions of CUDA may work but are untested). At runtime, you'll need numpy and jax.
First, clone the repo and cd into the repo root (don't forget the --recursive flag, because FINUFFT is included as a submodule):
git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
Then, you can use conda to set up a build environment (but you're welcome to use whatever workflow works for you!). For example, for a CPU build, you can use:
conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .
The CPATH export is needed so that the build can find the headers for libraries like FFTW installed through conda.
For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:
conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .
Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime.
In the above CMAKE_ARGS line, you'll need to select the CUDA architecture(s) you wish to compile for. To query your GPU's CUDA architecture (compute capability), you can run:
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0
This corresponds to CMAKE_CUDA_ARCHITECTURES=70, i.e.:
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
Note that pip runs CMake during the installation, so CMAKE_ARGS has to be set before running pip install; it is not needed at runtime.
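If you want to build for every GPU in a machine, the compute capabilities reported by nvidia-smi can be turned into a CMAKE_CUDA_ARCHITECTURES value by dropping the dot and joining multiple architectures with semicolons, as in the Flatiron script further down. Here's a minimal sketch in Python; the helper name cmake_cuda_architectures is hypothetical and not part of jax-finufft, and it assumes capabilities in the "major.minor" form shown above.

```python
def cmake_cuda_architectures(compute_caps):
    """Turn compute capabilities like "7.0" into a CMAKE_CUDA_ARCHITECTURES value.

    Hypothetical helper: "7.0" -> "70"; several architectures are joined with ";".
    """
    archs = []
    for cap in compute_caps:
        major, minor = cap.strip().split(".")
        arch = f"{int(major)}{int(minor)}"
        if arch not in archs:  # keep first-seen order, drop duplicates (identical GPUs)
            archs.append(arch)
    return ";".join(archs)


# e.g. one line per GPU from nvidia-smi on a machine with two V100s and one A100
caps = ["7.0", "7.0", "8.0"]
arch = cmake_cuda_architectures(caps)
print(f'export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES={arch} -DJAX_FINUFFT_USE_CUDA=ON"')
# prints: export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70;80 -DJAX_FINUFFT_USE_CUDA=ON"
```

Compiling for more architectures makes the build slower and the binary larger, so listing only the GPUs you actually run on is usually enough.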
At runtime, you may also need:
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
If CUDA_PATH isn't set, you'll need to replace it in the above line with the path to your CUDA installation, often something like /usr/local/cuda.
For Flatiron users, the following environment setup script can be used instead of conda:
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl
export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
Usage
This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!
For example, here's how you can do a 1-dimensional type 1 transform:
import numpy as np
from jax_finufft import nufft1
M = 100000
N = 200000
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
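For small problem sizes, it can help to check nufft1 against a direct evaluation of the type 1 sum from the FINUFFT docs, f_k = sum_j c_j exp(i * iflag * k * x_j). The NumPy sketch below costs O(NM) time and memory. The helper direct_type1 is hypothetical, and it assumes FINUFFT's default mode ordering, with k running from -(N//2) up to (N-1)//2; treat a comparison against nufft1 as meaningful only if that ordering assumption holds for your options.

```python
import numpy as np


def direct_type1(n_modes, c, x, *, iflag):
    """Direct O(N*M) evaluation of a 1D type 1 NUFFT (hypothetical helper).

    Assumes FINUFFT's default mode ordering: k = -(N//2), ..., (N-1)//2.
    """
    sign = 1 if iflag >= 0 else -1  # FINUFFT only uses the sign of iflag
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    # (N, M) matrix of exp(+/- i k x_j); multiplying by c sums over the points.
    phase = np.exp(sign * 1j * np.outer(k, x))
    return phase @ c


rng = np.random.default_rng(0)
M, N = 50, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f_ref = direct_type1(N, c, x, iflag=1)
print(f_ref.shape)  # (16,)
```

With jax_finufft installed, the relative error np.linalg.norm(f - f_ref) / np.linalg.norm(f_ref) between f = nufft1(N, c, x, iflag=1) and f_ref should then be roughly on the order of the requested eps.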
Warning: As described in the FINUFFT documentation, the non-uniform points must lie within the range [-3pi, 3pi], but this is not checked, because JAX currently doesn't have a good interface for runtime value checking. Unexpected crashes may occur if this condition is not met.
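Since the range isn't checked for you, one option is to validate or fold the points yourself before calling the transform. Because exp(i k x) is 2pi-periodic in x for integer k, shifting points by multiples of 2pi leaves the transform unchanged, so points can be wrapped into [-pi, pi). A minimal NumPy sketch; the helpers in_finufft_range and wrap_points are hypothetical. A Python-level check like in_finufft_range can't be applied to traced values inside jax.jit, but the wrapping itself only uses elementwise arithmetic that jax.numpy also provides.

```python
import numpy as np


def in_finufft_range(x):
    """True if all points lie in [-3*pi, 3*pi] (hypothetical helper)."""
    x = np.asarray(x)
    return bool(np.all((x >= -3 * np.pi) & (x <= 3 * np.pi)))


def wrap_points(x):
    """Fold points into [-pi, pi) (up to floating-point rounding at the endpoint).

    Exact for transforms with integer frequencies, since exp(i k x) has period 2*pi.
    """
    return np.mod(np.asarray(x) + np.pi, 2 * np.pi) - np.pi


x = np.array([-10.0, 0.5, 7.0, 25.0])
print(in_finufft_range(x))               # False: -10 and 25 are out of range
print(in_finufft_range(wrap_points(x)))  # True
```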
Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.
The syntax for a 2- or 3-dimensional transform is:
f = nufft1((Nx, Ny), c, x, y) # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z) # 3D
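In more than one dimension, the output has one axis per mode count, and the exponential factorizes across dimensions: f[k1, k2] = sum_j c_j exp(i * iflag * (k1 * x_j + k2 * y_j)). A small direct NumPy sketch of the 2D case follows; direct_type1_2d and mode_range are hypothetical helpers, and the code assumes FINUFFT's default mode ordering along each axis and that the first output axis pairs with x.

```python
import numpy as np


def mode_range(n):
    # Assumed default ordering along one axis: -(n//2), ..., (n-1)//2.
    return np.arange(-(n // 2), (n + 1) // 2)


def direct_type1_2d(shape, c, x, y, *, iflag):
    """Direct evaluation of a 2D type 1 NUFFT (hypothetical helper).

    f[k1, k2] = sum_j c_j * exp(+/- i * (k1 * x_j + k2 * y_j))
    """
    sign = 1 if iflag >= 0 else -1
    k1, k2 = mode_range(shape[0]), mode_range(shape[1])
    ex = np.exp(sign * 1j * np.outer(k1, x))  # (Nx, M)
    ey = np.exp(sign * 1j * np.outer(k2, y))  # (Ny, M)
    # The 2D exponential is a product of 1D factors; contract over the points j.
    return np.einsum("aj,bj,j->ab", ex, ey, c)


rng = np.random.default_rng(1)
M = 30
x, y = rng.uniform(-np.pi, np.pi, size=(2, M))
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = direct_type1_2d((8, 6), c, x, y, iflag=1)
print(f.shape)  # (8, 6)
```

The 3D case adds a third factor and an extra einsum index in the same way.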
The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):
c = nufft2(f, x) # 1D
c = nufft2(f, x, y) # 2D
c = nufft2(f, x, y, z) # 3D
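A type 2 transform evaluates a Fourier series with coefficients f at the non-uniform points, so its output has one entry per point: c_j = sum_k f_k exp(i * iflag * k * x_j). Here's a direct 1D NumPy sketch of that sum, using a hypothetical helper direct_type2 and assuming FINUFFT's default mode ordering (k from -(N//2) up to (N-1)//2). Setting a single mode to 1 makes the expected output easy to see.

```python
import numpy as np


def direct_type2(f, x, *, iflag):
    """Direct O(N*M) evaluation of a 1D type 2 NUFFT (hypothetical helper).

    c_j = sum_k f_k * exp(+/- i * k * x_j), with k = -(N//2), ..., (N-1)//2.
    """
    n = f.shape[0]
    sign = 1 if iflag >= 0 else -1
    k = np.arange(-(n // 2), (n + 1) // 2)
    return np.exp(sign * 1j * np.outer(x, k)) @ f  # (M, N) @ (N,) -> (M,)


# A single nonzero mode k0 = 3 evaluates to exp(-3i * x) when iflag is negative.
n = 10
f = np.zeros(n, dtype=complex)
f[3 + n // 2] = 1.0  # position of k = 3 in the assumed ordering
x = np.linspace(-np.pi, np.pi, 7)
c = direct_type2(f, x, iflag=-1)
print(np.allclose(c, np.exp(-3j * x)))  # True
```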
All of these functions support batching using vmap, as well as forward- and reverse-mode differentiation.
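Batching with vmap maps a transform over a leading axis. In JAX this might look like jax.vmap(lambda c: nufft1(N, c, x))(cs) for a stack of coefficient vectors cs sharing the points x. The NumPy sketch below only illustrates what such a batch computes for type 1 with shared points; it says nothing about how jax_finufft implements batching internally, and the helper direct_type1 is hypothetical (default FINUFFT mode ordering assumed).

```python
import numpy as np


def direct_type1(n_modes, c, x, *, iflag):
    """Direct 1D type 1 sum (hypothetical helper), assumed default ordering."""
    sign = 1 if iflag >= 0 else -1
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(sign * 1j * np.outer(k, x)) @ c


rng = np.random.default_rng(2)
B, M, N = 4, 25, 12
x = rng.uniform(-np.pi, np.pi, size=M)  # points shared by every batch item
cs = rng.standard_normal((B, M)) + 0j   # one coefficient vector per row

# Mapping over the leading axis of cs: one transform per row...
looped = np.stack([direct_type1(N, c, x, iflag=1) for c in cs])

# ...which, for shared points, is a single (B, M) @ (M, N) contraction.
k = np.arange(-(N // 2), (N + 1) // 2)
batched = cs @ np.exp(1j * np.outer(k, x)).T
print(looped.shape, np.allclose(looped, batched))  # (4, 12) True
```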
Advanced usage
The tuning parameters for the library can be set using the opts parameter to nufft1 and nufft2. For example, to explicitly set the CPU up-sampling factor that FINUFFT should use, you can update the example from above as follows:
from jax_finufft import options
opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)
The corresponding option for the GPU is gpu_upsampfac. In fact, all options for the GPU are prefixed with gpu_.
One complication here is that the vector-Jacobian product for a NUFFT requires evaluating a NUFFT of a different type. This means that you might want to separately tune the options for the forward and backward passes. This can be achieved using the options.NestedOpts interface. For example, to use a different up-sampling factor for the forward and backward passes, the code from above becomes:
import jax
opts = options.NestedOpts(
forward=options.Opts(upsampfac=2.0),
backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))
or, in this case equivalently:
opts = options.NestedOpts(
type1=options.Opts(upsampfac=2.0),
type2=options.Opts(upsampfac=1.25),
)
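Why the backward pass involves the other transform type can be seen from the dense matrix A of the 1D type 1 map c -> f, with A[k, j] = exp(i * iflag * k * x_j). Its transpose (which is what a reverse-mode VJP applies to the cotangent for a map that is linear in c, in JAX's convention) is a type 2 matrix with the same sign, and its conjugate transpose is a type 2 matrix with the opposite sign. A small NumPy check of this linear algebra follows; the helpers are hypothetical, FINUFFT's default mode ordering is assumed, it covers only the gradient with respect to c (not x), and it is not jax_finufft's internal implementation.

```python
import numpy as np


def type1_matrix(n_modes, x, *, iflag):
    """Dense (N, M) matrix of the 1D type 1 map c -> f (hypothetical helper)."""
    sign = 1 if iflag >= 0 else -1
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(sign * 1j * np.outer(k, x))


def direct_type2(f, x, *, iflag):
    """Direct 1D type 2 sum c_j = sum_k f_k exp(+/- i k x_j) (hypothetical helper)."""
    n = f.shape[0]
    sign = 1 if iflag >= 0 else -1
    k = np.arange(-(n // 2), (n + 1) // 2)
    return np.exp(sign * 1j * np.outer(x, k)) @ f


rng = np.random.default_rng(3)
M, N = 20, 9
x = rng.uniform(-np.pi, np.pi, size=M)
A = type1_matrix(N, x, iflag=1)
v = rng.standard_normal(N) + 1j * rng.standard_normal(N)  # a cotangent on f

# The transpose of type 1 is a type 2 with the same sign...
print(np.allclose(A.T @ v, direct_type2(v, x, iflag=1)))          # True
# ...and the adjoint (conjugate transpose) is a type 2 with the opposite sign.
print(np.allclose(A.conj().T @ v, direct_type2(v, x, iflag=-1)))  # True
```

This is also why the type1/type2 keys of NestedOpts can coincide with forward/backward for a pure type 1 computation like the one above.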
See the FINUFFT docs for descriptions of all the CPU tuning parameters. The corresponding GPU parameters are currently only listed in source code form in cufinufft_opts.h.
Similar libraries
- finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
- mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.
License & attribution
This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:
Copyright 2021, 2022, 2023 The Simons Foundation, Inc.
If you use this software, please cite the primary references listed on the FINUFFT docs.
Download files
File details
Details for the file jax_finufft-0.1.0.tar.gz.
File metadata
- Download URL: jax_finufft-0.1.0.tar.gz
- Upload date:
- Size: 2.6 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0c9173837fa0ae47b61074f8c05b246d9ca5b21bda6174beda8c27ea75c4f152
MD5 | 564167555b26b8c01788a52653612fbb
BLAKE2b-256 | b7269aa275d78c4ae4abca4c8d095d2c1c1bf137dab8aaea07eab2b2f6e71ebb
File details
Details for the file jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- Upload date:
- Size: 2.2 MB
- Tags: CPython 3.12+, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | fd0cebca0ac3da173d30b5bf413fe7ba8fa5b0bf8483b5de7b1d2c55e0d640ce
MD5 | e98274fdcf1fcc1f4f2de664a7184ad1
BLAKE2b-256 | c957713d26c173c245d42e4fa5afc00fe1dd5feb27bff4d6d6a355e95dcfa755
File details
Details for the file jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl
- Upload date:
- Size: 1.3 MB
- Tags: CPython 3.12+, macOS 11.0+ ARM64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8faebf86576abddb46d282fc0e5de740a3780691e0e430547412f321c481645d
MD5 | bc12923035af8032d5baf2aabd5a42cd
BLAKE2b-256 | abc76bc1b5f70502bc31b3f4b91e0f350db68bdaee6e4f5e666661b57a4ea08f
File details
Details for the file jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl
- Upload date:
- Size: 2.9 MB
- Tags: CPython 3.12+, macOS 10.14+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | b1cb11ec47b264e9a25cd87970bdc8828d7886dd5ebd481f12b5f5e6a02d104b
MD5 | 69f526ab3048f97bd24a73edba484d82
BLAKE2b-256 | 4a5db78b4553b31a43351b4e20fa2669209882b2e25dfe9a0c4353790ed27fe5
File details
Details for the file jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- Upload date:
- Size: 2.2 MB
- Tags: CPython 3.11, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | b09ad5ed078fe49ccbea7cd17043ea4739941cb0be7835a695c55d5129def919
MD5 | 30cba15526b52fc5d9818ba754bf5919
BLAKE2b-256 | 5e20f14375d8a9eb4662562d273803e2378a83497a74426a3ac852df8fa35f24
File details
Details for the file jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl
- Upload date:
- Size: 1.3 MB
- Tags: CPython 3.11, macOS 11.0+ ARM64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | e0f4688b6f831c47591179cf79988481bf1437da5cb4775c839317c3b5c84ee4
MD5 | 979e6c4ab1948c91bfd15aec6642d8fc
BLAKE2b-256 | f52468611ac67f151f3cdc56f78ad044fd407d499ca548574475fe28688de614
File details
Details for the file jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl
- Upload date:
- Size: 2.9 MB
- Tags: CPython 3.11, macOS 10.14+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | e5a66da56d3956c077bc06faa54f2f615bda6abad4298480ea8d8f7d70a4af7f
MD5 | e3247199df1b16c2a758835bfc3b0fa4
BLAKE2b-256 | 6baf63574c18bec5039d4fd478cbb9feef2d7707de82320094ef76d9fc680620
File details
Details for the file jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- Upload date:
- Size: 2.2 MB
- Tags: CPython 3.10, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | eb6c95c3223d6e4fec1a6559436b81e33463863a0c31e0fb7426152fe6c834a4
MD5 | 52caf206d76827acf3ca68e17e2e02a3
BLAKE2b-256 | 8a80ea22b5cba3f2e2ec57def22821ba2967d1035631b652882f7038bbb4ab51
File details
Details for the file jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl
- Upload date:
- Size: 1.3 MB
- Tags: CPython 3.10, macOS 11.0+ ARM64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2bbe7c67020cc8bd2fc9ab70421fa6aa3a3ca2b7cdfba18d2bdf9452101fb896
MD5 | 71bd0ca1e2e71532daad901563a24ea0
BLAKE2b-256 | 1eb98b72df683f3d5031d38f440a001dd8d3eb59a12fbaa01fa652e5e625a190
File details
Details for the file jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl
- Upload date:
- Size: 2.9 MB
- Tags: CPython 3.10, macOS 10.14+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | e2fb906afd0d95cc8729a6b7139a64c5ded28c800b6007ade61b97cdb3b36c40
MD5 | 9c357307ba0c6bf0ec01b9269bd216cb
BLAKE2b-256 | ee4c72c18cd06804cb34869411515c1dd0ae9d8a84308ddb50f222ac5053e58a
File details
Details for the file jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- Upload date:
- Size: 2.2 MB
- Tags: CPython 3.9, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 606cb916883e1ec2842a0147bb71b8e24a6618506f2536ec177df00720a738a7
MD5 | f8f25325953e15ca59d53df8af3db722
BLAKE2b-256 | 50f4159d92173959ef1621ff79b9712a8140bc18321356250c3cebcecf66f39b
File details
Details for the file jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl
- Upload date:
- Size: 1.3 MB
- Tags: CPython 3.9, macOS 11.0+ ARM64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5f354bcb22140f2c29014a0d270c27d374afe5264391904c56f1a6c562c10139
MD5 | 41d1d8d53ed96f1807a0114eb3e39dd6
BLAKE2b-256 | 46206704e153c2523c41a7e7e93d2557a457526b9a43169d3e20740b7f9227ee
File details
Details for the file jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl.
File metadata
- Download URL: jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl
- Upload date:
- Size: 2.9 MB
- Tags: CPython 3.9, macOS 10.14+ x86-64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | c48382b45866fb078f6187810fc34632496fd8338e750adf7c0a7a66536b4118
MD5 | 03eae66bf0484c9814a39c81fd0bebd7
BLAKE2b-256 | 304d0bd8f3f262612bfa95b48feadfabd3730271dfba3e4597bb0ec6d8a2f927