JAX bindings for the Flatiron Institute Nonuniform Fast Fourier Transform library

Project description

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.

Included features

This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.

Type 1 and 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward, reverse, and higher-order differentiation, as well as batching using vmap.
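As a quick reference (summarised from the FINUFFT docs, which remain the authoritative source for conventions), the 1D transforms act on M points x_j with strengths c_j and N output modes f_k. The exponent uses +i when iflag ≥ 0 and −i otherwise, and k runs over the N integers centred on zero (FINUFFT's default mode ordering):

```latex
\begin{aligned}
\text{Type 1:}\quad f_k &= \sum_{j=1}^{M} c_j \, e^{\pm i k x_j},
  \qquad k \in \left\{-\lfloor N/2 \rfloor, \dots, \lfloor (N-1)/2 \rfloor\right\},\\
\text{Type 2:}\quad c_j &= \sum_{k} f_k \, e^{\pm i k x_j},
  \qquad j = 1, \dots, M.
\end{aligned}
```

In 2D and 3D, the product k x_j becomes a dot product such as k_x x_j + k_y y_j, taken over a tensor-product grid of modes.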

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. GPU-enabled builds also require a working CUDA compiler (i.e. the CUDA Toolkit), CUDA >= 11.8, and a compatible cuDNN (older versions of CUDA may work but are untested). At runtime, you'll need numpy and jax.

First, clone the repo and cd into the repo root (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft

Then, you can use conda to set up a build environment (but you're welcome to use whatever workflow works for you!). For example, for a CPU build, you can use:

conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .

The CPATH export is needed so that the build can find the headers for libraries like FFTW installed through conda.

For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:

conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .

Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime.

In the above CMAKE_ARGS line, you'll need to select the CUDA architecture(s) you wish to compile for. To query your GPU's CUDA architecture (compute capability), you can run:

$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0

This corresponds to CMAKE_CUDA_ARCHITECTURES=70, i.e.:

export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"

Note that pip install runs CMake, so CMAKE_ARGS must be set before running it; it is not needed at runtime.
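The mapping from compute capability to CMAKE_CUDA_ARCHITECTURES amounts to dropping the dot ("7.0" becomes "70"). As a convenience, a small helper like the one below (hypothetical, not part of jax-finufft) could turn the nvidia-smi output shown above into a CMAKE_ARGS value, joining several GPUs with ";" as in the Flatiron environment script:

```python
def cmake_args_from_compute_caps(nvidia_smi_output: str) -> str:
    """Build a CMAKE_ARGS value from `nvidia-smi --query-gpu=compute_cap` output.

    Each non-empty line is a compute capability such as "7.0"; removing the dot
    gives the CMake architecture number ("70"). Duplicate GPUs are collapsed and
    multiple architectures are joined with ";".
    """
    archs = sorted(
        {
            line.strip().replace(".", "")
            for line in nvidia_smi_output.splitlines()
            if line.strip()
        },
        key=int,
    )
    if not archs:
        raise ValueError("no compute capabilities found in input")
    return f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(archs)} -DJAX_FINUFFT_USE_CUDA=ON"


print(cmake_args_from_compute_caps("7.0\n"))
# -DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON
```

When the result contains ";", keep it quoted in the shell export, as the Flatiron script does.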

At runtime, you may also need:

export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"

If CUDA_PATH isn't set, you'll need to replace it with the path to your CUDA installation in the above line, often something like /usr/local/cuda.

For Flatiron users, the following environment setup script can be used instead of conda:

Environment script
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
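For small problems it can help to have a slow, obviously-correct reference to compare against. The sketch below is a plain NumPy direct sum of the type 1 definition (O(N·M) cost); the function names are hypothetical and not part of jax-finufft:

```python
import numpy as np


def mode_indices(n):
    """Centred integer frequencies -floor(n/2), ..., floor((n-1)/2)."""
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def direct_type1_1d(n_modes, c, x, iflag):
    """Direct sum f_k = sum_j c_j exp(sign * i * k * x_j), sign = +1 if iflag >= 0."""
    sign = 1.0 if iflag >= 0 else -1.0
    k = mode_indices(n_modes)
    # (N, M) matrix of complex exponentials applied to the strengths c.
    return np.exp(sign * 1j * np.outer(k, x)) @ c


# Keep sizes small: the direct sum costs O(N * M) time and memory.
rng = np.random.default_rng(0)
M, N = 500, 64
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f_ref = direct_type1_1d(N, c, x, iflag=1)
print(f_ref.shape)  # (64,)
```

With jax-finufft installed, checking nufft1(N, c, x, eps=1e-6, iflag=1) against f_ref with np.allclose (tolerance scaled by eps and the size of the output) is a cheap sanity check. This assumes nufft1 returns modes in FINUFFT's default centred ordering; also note that JAX computes in single precision unless jax_enable_x64 is turned on, which limits the achievable agreement.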

Warning: As described in the FINUFFT documentation, the non-uniform points must lie within the range [-3π, 3π]. This is not checked, because JAX currently doesn't have a good interface for runtime value checking, so unexpected crashes may occur if this condition is not met.

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed relative to the finufft Python package.
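Because the exponentials in the type 1 and type 2 sums are 2π-periodic in each point coordinate when k is an integer, one cheap way to respect the [-3π, 3π] range requirement is to wrap the points into [-π, π] before calling nufft1 or nufft2. This is a suggestion, not something jax-finufft does for you; the NumPy sketch below shows the idea, and jnp.mod follows the same sign convention inside jitted JAX code:

```python
import numpy as np


def wrap_to_pi(x):
    """Map angles into [-pi, pi] (up to rounding) using 2*pi periodicity."""
    return np.mod(x + np.pi, 2 * np.pi) - np.pi


x = np.array([-10.0, -np.pi, 0.5, 4.0, 25.0])
xw = wrap_to_pi(x)
assert np.all(np.abs(xw) <= np.pi)

# For integer frequencies k, exp(i k x) is unchanged by the wrap.
k = np.arange(-8, 8)
assert np.allclose(np.exp(1j * np.outer(k, x)), np.exp(1j * np.outer(k, xw)))
```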

The syntax for 2- and 3-dimensional transforms is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D
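In 2D the exponential factorises as e^{±i k_x x_j} · e^{±i k_y y_j}, so a direct-sum reference stays short. The NumPy sketch below is hypothetical (not part of jax-finufft) and assumes the output axes are ordered (x-modes, y-modes) to match the (Nx, Ny) tuple, an assumption based on the call signature above:

```python
import numpy as np


def mode_indices(n):
    """Centred integer frequencies -floor(n/2), ..., floor((n-1)/2)."""
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def direct_type1_2d(n_modes, c, x, y, iflag):
    """Direct sum f[a, b] = sum_j c_j exp(sign * i * (kx_a x_j + ky_b y_j))."""
    nx, ny = n_modes
    sign = 1.0 if iflag >= 0 else -1.0
    ex = np.exp(sign * 1j * np.outer(mode_indices(nx), x))  # shape (nx, M)
    ey = np.exp(sign * 1j * np.outer(mode_indices(ny), y))  # shape (ny, M)
    # The 2D exponential factorises, so contract both factors over the point index j.
    return np.einsum("aj,bj,j->ab", ex, ey, c)


rng = np.random.default_rng(1)
M, Nx, Ny = 200, 12, 9
x, y = 2 * np.pi * rng.uniform(size=(2, M))
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f2 = direct_type1_2d((Nx, Ny), c, x, y, iflag=1)
print(f2.shape)  # (12, 9)
```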

The syntax for a type 2 transform (which also accepts the optional eps and iflag parameters) is:

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
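A type 2 transform goes the other way, evaluating a Fourier series with coefficients f_k at the non-uniform points. A NumPy direct-sum reference for the 1D case might look like this (hypothetical helper names, O(N·M) cost, centred mode ordering assumed for f):

```python
import numpy as np


def mode_indices(n):
    """Centred integer frequencies -floor(n/2), ..., floor((n-1)/2)."""
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def direct_type2_1d(f, x, iflag):
    """Direct sum c_j = sum_k f_k exp(sign * i * k * x_j), sign = +1 if iflag >= 0."""
    sign = 1.0 if iflag >= 0 else -1.0
    k = mode_indices(f.shape[0])
    # (M, N) matrix of exponentials applied to the mode coefficients f.
    return np.exp(sign * 1j * np.outer(x, k)) @ f


rng = np.random.default_rng(2)
N, M = 32, 100
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)
x = 2 * np.pi * rng.uniform(size=M)
c = direct_type2_1d(f, x, iflag=-1)
print(c.shape)  # (100,)
```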

All of these functions support batching using vmap, and forward and reverse mode differentiation.
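To illustrate the vmap and grad pattern without a compiled extension, the sketch below uses a hypothetical pure-JAX direct-sum function, direct_nufft1 (not part of jax-finufft, and far slower). Per the statement above, jax_finufft's nufft1(N, c, x) should be usable in its place in the same jax.vmap and jax.grad calls:

```python
import jax
import jax.numpy as jnp


def direct_nufft1(n_modes, c, x, iflag=1):
    """Hypothetical pure-JAX stand-in for a 1D type 1 transform (O(N*M))."""
    sign = 1.0 if iflag >= 0 else -1.0
    k = jnp.arange(-(n_modes // 2), (n_modes - 1) // 2 + 1)
    return jnp.exp(sign * 1j * jnp.outer(k, x)) @ c


N, B, M = 16, 4, 50
k_x, k_re, k_im = jax.random.split(jax.random.PRNGKey(0), 3)
x = 2 * jnp.pi * jax.random.uniform(k_x, (B, M))
c = jax.random.normal(k_re, (B, M)) + 1j * jax.random.normal(k_im, (B, M))

# Batch over the leading axis of both c and x.
f = jax.vmap(lambda c_b, x_b: direct_nufft1(N, c_b, x_b))(c, x)
print(f.shape)  # (4, 16)


# Reverse-mode gradient of a real-valued loss with respect to the point locations.
def loss(x_b):
    return jnp.sum(jnp.abs(direct_nufft1(N, c[0], x_b)) ** 2)


g = jax.grad(loss)(x[0])
print(g.shape)  # (50,)
```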

Advanced usage

The tuning parameters for the library can be set using the opts parameter to nufft1 and nufft2. For example, to explicitly set the CPU up-sampling factor that FINUFFT should use, you can update the example from above as follows:

from jax_finufft import options

opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)

The corresponding option for the GPU is gpu_upsampfac. In fact, all options for the GPU are prefixed with gpu_.

One complication here is that the vector-Jacobian product for a NUFFT requires evaluating a NUFFT of a different type. This means that you might want to separately tune the options for the forward and backward pass. This can be achieved using the options.NestedOpts interface. For example, to use a different up-sampling factor for the forward and backward passes, the code from above becomes:

import jax

opts = options.NestedOpts(
  forward=options.Opts(upsampfac=2.0),
  backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))

or, in this case equivalently:

opts = options.NestedOpts(
  type1=options.Opts(upsampfac=2.0),
  type2=options.Opts(upsampfac=1.25),
)
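Why does the backward pass need the other transform type? Writing the 1D type 1 transform as a matrix A with entries A[k, j] = e^{+i k x_j}, its adjoint A^H has entries e^{-i k x_j}, which is exactly a type 2 transform with the opposite sign. The vector-Jacobian product with respect to c therefore involves a type 2 transform (the gradient with respect to the points similarly involves a transform of the other type; how jax-finufft wires this up internally is not shown here). A quick numerical check of the adjoint identity using direct sums, with hypothetical helper names:

```python
import numpy as np


def mode_indices(n):
    """Centred integer frequencies -floor(n/2), ..., floor((n-1)/2)."""
    return np.arange(-(n // 2), (n - 1) // 2 + 1)


def type1_matrix(n_modes, x, iflag):
    """(N, M) matrix A with A[k, j] = exp(sign * i * k * x_j), sign = +1 if iflag >= 0."""
    sign = 1.0 if iflag >= 0 else -1.0
    return np.exp(sign * 1j * np.outer(mode_indices(n_modes), x))


rng = np.random.default_rng(3)
N, M = 24, 40
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)

A = type1_matrix(N, x, iflag=1)
type1_c = A @ c  # type 1, +i: f_k = sum_j c_j e^{+i k x_j}
type2_f = type1_matrix(N, x, iflag=-1).T @ f  # type 2, -i: c_j = sum_k f_k e^{-i k x_j}

# <f, A c> == <A^H f, c>: the type 2 (-i) transform is the adjoint of type 1 (+i).
lhs = np.vdot(f, type1_c)
rhs = np.vdot(type2_f, c)
print(np.isclose(lhs, rhs))  # True
```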

See the FINUFFT docs for descriptions of all the CPU tuning parameters. The corresponding GPU parameters are currently only listed in source code form in cufinufft_opts.h.

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021, 2022, 2023 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Download files

Source Distribution

  • jax_finufft-0.1.0.tar.gz (2.6 MB)

Built Distributions

  • jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.12+, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl (1.3 MB): CPython 3.12+, macOS 11.0+, ARM64
  • jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.12+, macOS 10.14+, x86-64
  • jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.11, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl (1.3 MB): CPython 3.11, macOS 11.0+, ARM64
  • jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.11, macOS 10.14+, x86-64
  • jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.10, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl (1.3 MB): CPython 3.10, macOS 11.0+, ARM64
  • jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.10, macOS 10.14+, x86-64
  • jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.9, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl (1.3 MB): CPython 3.9, macOS 11.0+, ARM64
  • jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.9, macOS 10.14+, x86-64

File details

Details for the file jax_finufft-0.1.0.tar.gz.

File metadata

  • Download URL: jax_finufft-0.1.0.tar.gz
  • Size: 2.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for jax_finufft-0.1.0.tar.gz
Algorithm Hash digest
SHA256 0c9173837fa0ae47b61074f8c05b246d9ca5b21bda6174beda8c27ea75c4f152
MD5 564167555b26b8c01788a52653612fbb
BLAKE2b-256 b7269aa275d78c4ae4abca4c8d095d2c1c1bf137dab8aaea07eab2b2f6e71ebb

File details

Details for the file jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 fd0cebca0ac3da173d30b5bf413fe7ba8fa5b0bf8483b5de7b1d2c55e0d640ce
MD5 e98274fdcf1fcc1f4f2de664a7184ad1
BLAKE2b-256 c957713d26c173c245d42e4fa5afc00fe1dd5feb27bff4d6d6a355e95dcfa755

File details

Details for the file jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp312-abi3-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 8faebf86576abddb46d282fc0e5de740a3780691e0e430547412f321c481645d
MD5 bc12923035af8032d5baf2aabd5a42cd
BLAKE2b-256 abc76bc1b5f70502bc31b3f4b91e0f350db68bdaee6e4f5e666661b57a4ea08f

File details

Details for the file jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp312-abi3-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 b1cb11ec47b264e9a25cd87970bdc8828d7886dd5ebd481f12b5f5e6a02d104b
MD5 69f526ab3048f97bd24a73edba484d82
BLAKE2b-256 4a5db78b4553b31a43351b4e20fa2669209882b2e25dfe9a0c4353790ed27fe5

File details

Details for the file jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 b09ad5ed078fe49ccbea7cd17043ea4739941cb0be7835a695c55d5129def919
MD5 30cba15526b52fc5d9818ba754bf5919
BLAKE2b-256 5e20f14375d8a9eb4662562d273803e2378a83497a74426a3ac852df8fa35f24

File details

Details for the file jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 e0f4688b6f831c47591179cf79988481bf1437da5cb4775c839317c3b5c84ee4
MD5 979e6c4ab1948c91bfd15aec6642d8fc
BLAKE2b-256 f52468611ac67f151f3cdc56f78ad044fd407d499ca548574475fe28688de614

File details

Details for the file jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp311-cp311-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 e5a66da56d3956c077bc06faa54f2f615bda6abad4298480ea8d8f7d70a4af7f
MD5 e3247199df1b16c2a758835bfc3b0fa4
BLAKE2b-256 6baf63574c18bec5039d4fd478cbb9feef2d7707de82320094ef76d9fc680620

File details

Details for the file jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 eb6c95c3223d6e4fec1a6559436b81e33463863a0c31e0fb7426152fe6c834a4
MD5 52caf206d76827acf3ca68e17e2e02a3
BLAKE2b-256 8a80ea22b5cba3f2e2ec57def22821ba2967d1035631b652882f7038bbb4ab51

File details

Details for the file jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 2bbe7c67020cc8bd2fc9ab70421fa6aa3a3ca2b7cdfba18d2bdf9452101fb896
MD5 71bd0ca1e2e71532daad901563a24ea0
BLAKE2b-256 1eb98b72df683f3d5031d38f440a001dd8d3eb59a12fbaa01fa652e5e625a190

File details

Details for the file jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp310-cp310-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 e2fb906afd0d95cc8729a6b7139a64c5ded28c800b6007ade61b97cdb3b36c40
MD5 9c357307ba0c6bf0ec01b9269bd216cb
BLAKE2b-256 ee4c72c18cd06804cb34869411515c1dd0ae9d8a84308ddb50f222ac5053e58a

File details

Details for the file jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 606cb916883e1ec2842a0147bb71b8e24a6618506f2536ec177df00720a738a7
MD5 f8f25325953e15ca59d53df8af3db722
BLAKE2b-256 50f4159d92173959ef1621ff79b9712a8140bc18321356250c3cebcecf66f39b

File details

Details for the file jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp39-cp39-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 5f354bcb22140f2c29014a0d270c27d374afe5264391904c56f1a6c562c10139
MD5 41d1d8d53ed96f1807a0114eb3e39dd6
BLAKE2b-256 46206704e153c2523c41a7e7e93d2557a457526b9a43169d3e20740b7f9227ee

File details

Details for the file jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0-cp39-cp39-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 c48382b45866fb078f6187810fc34632496fd8338e750adf7c0a7a66536b4118
MD5 03eae66bf0484c9814a39c81fd0bebd7
BLAKE2b-256 304d0bd8f3f262612bfa95b48feadfabd3730271dfba3e4597bb0ec6d8a2f927
