JAX bindings for the Flatiron Institute Nonuniform Fast Fourier Transform library

Project description

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to expose the FINUFFT library directly to JAX's XLA backend, and it also implements differentiation rules for the transforms.

Included features

This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.

Type 1 and type 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward, reverse, and higher-order differentiation, as well as batching using vmap.
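For reference, FINUFFT defines the 1D type 1 transform as f[k] = sum_j c[j] exp(±i k x[j]) over integer modes -N/2 <= k < N/2, and the type 2 transform as c[j] = sum_k f[k] exp(±i k x[j]), with the sign set by iflag. The following is a minimal brute-force sketch of those definitions in NumPy, assuming FINUFFT's default mode ordering (most negative frequency first). It costs O(NM), is only meant for checking small problems, and is not how jax-finufft computes anything; the names mode_indices, direct_type1, and direct_type2 are hypothetical.

```python
import numpy as np


def mode_indices(n):
    # Integer frequencies -floor(n/2), ..., ceil(n/2) - 1 (assumed FINUFFT
    # default ordering, most negative frequency first).
    return np.arange(-(n // 2), (n + 1) // 2)


def direct_type1(n_modes, c, x, iflag):
    # f[k] = sum_j c[j] * exp(sign * 1j * k * x[j]), sign = +1 if iflag >= 0.
    sign = 1.0 if iflag >= 0 else -1.0
    k = mode_indices(n_modes)
    phase = np.exp(sign * 1j * np.outer(k, x))  # shape (n_modes, M)
    return phase @ c


def direct_type2(f, x, iflag):
    # c[j] = sum_k f[k] * exp(sign * 1j * k * x[j]), sign = +1 if iflag >= 0.
    sign = 1.0 if iflag >= 0 else -1.0
    k = mode_indices(f.shape[0])
    phase = np.exp(sign * 1j * np.outer(x, k))  # shape (M, n_modes)
    return phase @ f


# A small problem, mirroring the shapes used in the usage example.
rng = np.random.default_rng(0)
x = 2 * np.pi * rng.uniform(size=50)
c = rng.standard_normal(50) + 1j * rng.standard_normal(50)
f = direct_type1(16, c, x, iflag=1)
```

On a small problem, nufft1(N, c, x, iflag=1) would be expected to agree with direct_type1(N, c, x, iflag=1) to roughly the requested eps, provided the mode-ordering assumption above matches the package.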

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. GPU-enabled builds also require a working CUDA compiler (i.e. the CUDA Toolkit), CUDA >= 11.8, and a compatible cuDNN (older versions of CUDA may work but are untested). At runtime, you'll need numpy and jax.

First, clone the repo and cd into the repo root (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft

Then, you can use conda to set up a build environment (but you're welcome to use whatever workflow works for you!). For example, for a CPU build, you can use:

conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .

The CPATH export is needed so that the build can find the headers for libraries like FFTW installed through conda.

For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:

conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .

Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime.

In the above CMAKE_ARGS line, you'll need to select the CUDA architecture(s) you wish to compile for. To query your GPU's CUDA architecture (compute capability), you can run:

$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0

This corresponds to CMAKE_CUDA_ARCHITECTURES=70, i.e.:

export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
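If you have several GPUs (or want to build for more than one machine), the conversion from compute capability to CMake architecture number can be scripted. A minimal sketch, assuming the nvidia-smi output format shown above; the helper names are hypothetical and not part of jax-finufft.

```python
import subprocess


def cuda_architectures(compute_caps):
    # Turn compute capabilities such as "7.0" or "8.6" into CMake
    # architecture numbers ("70", "86"), dropping blanks and duplicates.
    archs = set()
    for cap in compute_caps:
        cap = cap.strip()
        if not cap:
            continue
        major, minor = cap.split(".")
        archs.add(int(major) * 10 + int(minor))
    return ";".join(str(a) for a in sorted(archs))


def query_compute_caps():
    # Runs the same nvidia-smi query as above; needs an NVIDIA driver.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return out.splitlines()


def cmake_args(archs):
    # Build the CMAKE_ARGS value used for a GPU-enabled build.
    return f"-DCMAKE_CUDA_ARCHITECTURES={archs} -DJAX_FINUFFT_USE_CUDA=ON"
```

On a GPU machine, cmake_args(cuda_architectures(query_compute_caps())) yields a string for CMAKE_ARGS. Multiple architectures come out semicolon-separated, as in the Flatiron environment script, so quote the value when exporting it in a shell.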

Note that pip runs CMake during the installation, so CMAKE_ARGS has to be set before running pip install; it is not needed at runtime.

At runtime, you may also need:

export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"

If CUDA_PATH isn't set, you'll need to replace it with the path to your CUDA installation in the above line, often something like /usr/local/cuda.

For Flatiron users, the following environment setup script can be used instead of conda:

Environment script
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"

Usage

This library provides two high-level functions, nufft1 and nufft2 (one for each of the two "types" of transform), and these should generally be all you need to interact with. If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Warning: As described in the FINUFFT documentation, the non-uniform points must lie within the range [-3π, 3π]. This condition is not checked, because JAX currently doesn't have a good interface for runtime value checking, and unexpected crashes may occur if it is not met.
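One way to stay on the safe side is to wrap the points before calling nufft1 or nufft2, and to check them on the host, outside of any jit-compiled function. Because exp(i k x) is 2π-periodic for integer k, shifting a point by a multiple of 2π leaves both transform types unchanged. A minimal NumPy sketch, assuming your points are available as a concrete array; wrap_points and points_in_range are hypothetical helpers, and jax.numpy.mod behaves the same way if you need to wrap inside a traced function.

```python
import numpy as np


def wrap_points(x):
    # Map arbitrary real coordinates into [-pi, pi). This does not change
    # the transforms (integer modes make exp(1j * k * x) 2*pi-periodic)
    # and keeps the points well inside [-3*pi, 3*pi].
    return np.mod(x + np.pi, 2 * np.pi) - np.pi


def points_in_range(x, bound=3 * np.pi):
    # Host-side check of the FINUFFT range requirement.
    return bool(np.all((x >= -bound) & (x <= bound)))
```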

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed relative to the finufft Python package.

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D
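In more than one dimension the exponential factorizes over coordinates, so a brute-force reference for small 2D problems is a single contraction. A minimal NumPy sketch, assuming that output axis 0 pairs with x and axis 1 with y, and that each axis uses the same mode ordering as in 1D; both are assumptions to verify against the package, and direct_type1_2d is a hypothetical name.

```python
import numpy as np


def direct_type1_2d(n_modes, c, x, y, iflag):
    # f[k1, k2] = sum_j c[j] * exp(sign * 1j * (k1 * x[j] + k2 * y[j])).
    # Assumes axis 0 <-> x, axis 1 <-> y, frequencies -floor(N/2)..ceil(N/2)-1.
    nx, ny = n_modes
    sign = 1.0 if iflag >= 0 else -1.0
    kx = np.arange(-(nx // 2), (nx + 1) // 2)
    ky = np.arange(-(ny // 2), (ny + 1) // 2)
    ex = np.exp(sign * 1j * np.outer(kx, x))  # shape (nx, M)
    ey = np.exp(sign * 1j * np.outer(ky, y))  # shape (ny, M)
    # Contract the factorized kernel against the strengths over points j.
    return np.einsum("aj,bj,j->ab", ex, ey, c)
```

A 3D version adds a third factor in the same way, for example with the subscripts "aj,bj,cj,j->abc".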

The syntax for a type 2 transform (which also accepts the optional eps and iflag parameters) is:

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D

All of these functions support batching using vmap, as well as forward- and reverse-mode differentiation.
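To illustrate the pattern, here is a sketch that applies jax.vmap and jax.grad to a brute-force jax.numpy stand-in with the same (N, c, x) argument order, so that it runs without jax-finufft installed. The stand-in direct_nufft1 is hypothetical; the assumption is that nufft1 can be dropped in for it, since the package states that its transforms support both batching and differentiation.

```python
import jax
import jax.numpy as jnp


def direct_nufft1(n_modes, c, x, iflag=1):
    # Hypothetical O(N * M) stand-in for a 1D nufft1(N, c, x); small problems only.
    sign = 1.0 if iflag >= 0 else -1.0
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(sign * 1j * jnp.outer(k, x)) @ c


N = 8
key_x, key_c = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (20,), minval=-jnp.pi, maxval=jnp.pi)
cs = jax.random.normal(key_c, (3, 20)) + 0j  # a batch of 3 strength vectors

# Batch over the leading axis of the strengths, sharing the same points.
batched = jax.vmap(lambda c: direct_nufft1(N, c, x))(cs)


# Differentiate a real-valued loss with respect to the point locations.
def loss(x):
    return jnp.sum(jnp.abs(direct_nufft1(N, cs[0], x)) ** 2)


dloss_dx = jax.grad(loss)(x)
```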

Advanced usage

The tuning parameters for the library can be set using the opts parameter to nufft1 and nufft2. For example, to explicitly set the CPU up-sampling factor that FINUFFT should use, you can update the example from above as follows:

from jax_finufft import options

opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)

The corresponding option for the GPU is gpu_upsampfac. In fact, all options for the GPU are prefixed with gpu_.

One complication here is that the vector-Jacobian product for a NUFFT requires evaluating a NUFFT of a different type. This means that you might want to separately tune the options for the forward and backward pass. This can be achieved using the options.NestedOpts interface. For example, to use a different up-sampling factor for the forward and backward passes, the code from above becomes:

import jax

opts = options.NestedOpts(
  forward=options.Opts(upsampfac=2.0),
  backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))

or, in this case equivalently:

opts = options.NestedOpts(
  type1=options.Opts(upsampfac=2.0),
  type2=options.Opts(upsampfac=1.25),
)
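Why does a type 1 transform need a type 2 transform in its backward pass? Writing the 1D type 1 transform as a matrix A with entries exp(i s k x[j]), its adjoint (conjugate transpose) has entries exp(-i s k x[j]), which is exactly a type 2 transform with the opposite sign. The following NumPy check builds A explicitly for a small problem and confirms this numerically. It illustrates the linear algebra only and makes no claim about how jax-finufft wires up its vector-Jacobian products internally.

```python
import numpy as np

rng = np.random.default_rng(2)
M, N, s = 30, 10, 1  # points, modes, sign (iflag) of the forward type 1
x = rng.uniform(-np.pi, np.pi, M)
k = np.arange(-(N // 2), (N + 1) // 2)

# Dense type 1 matrix: f = A @ c with A[k, j] = exp(1j * s * k * x[j]).
A = np.exp(1j * s * np.outer(k, x))

# A cotangent on the output modes, as would flow back in reverse mode.
g = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# The adjoint applied through the dense matrix...
adjoint = A.conj().T @ g
# ...equals a direct type 2 sum with the opposite sign:
# c[j] = sum_k g[k] * exp(-1j * s * k * x[j]).
type2 = np.exp(-1j * s * np.outer(x, k)) @ g
```

This is also why the forward/backward and type1/type2 spellings of NestedOpts coincide for nufft1.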

See the FINUFFT docs for descriptions of all the CPU tuning parameters. The corresponding GPU parameters are currently only listed in source code form in cufinufft_opts.h.

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021, 2022, 2023 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Download files


Source Distribution

  • jax_finufft-0.1.0rc2.tar.gz (2.6 MB)

Built Distributions

  • jax_finufft-0.1.0rc2-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.12+, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0rc2-cp312-abi3-macosx_11_0_arm64.whl (1.3 MB): CPython 3.12+, macOS 11.0+, ARM64
  • jax_finufft-0.1.0rc2-cp312-abi3-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.12+, macOS 10.14+, x86-64
  • jax_finufft-0.1.0rc2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.11, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0rc2-cp311-cp311-macosx_11_0_arm64.whl (1.3 MB): CPython 3.11, macOS 11.0+, ARM64
  • jax_finufft-0.1.0rc2-cp311-cp311-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.11, macOS 10.14+, x86-64
  • jax_finufft-0.1.0rc2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.10, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0rc2-cp310-cp310-macosx_11_0_arm64.whl (1.3 MB): CPython 3.10, macOS 11.0+, ARM64
  • jax_finufft-0.1.0rc2-cp310-cp310-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.10, macOS 10.14+, x86-64
  • jax_finufft-0.1.0rc2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB): CPython 3.9, manylinux (glibc 2.17+), x86-64
  • jax_finufft-0.1.0rc2-cp39-cp39-macosx_11_0_arm64.whl (1.3 MB): CPython 3.9, macOS 11.0+, ARM64
  • jax_finufft-0.1.0rc2-cp39-cp39-macosx_10_14_x86_64.whl (2.9 MB): CPython 3.9, macOS 10.14+, x86-64

File details

Details for the file jax_finufft-0.1.0rc2.tar.gz.

File metadata

  • Download URL: jax_finufft-0.1.0rc2.tar.gz
  • Upload date:
  • Size: 2.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for jax_finufft-0.1.0rc2.tar.gz
Algorithm Hash digest
SHA256 f04c61f51d9c312d5cafb59482e5c2d67cac5b0963182ea41608b487db6ea49c
MD5 4a62b99e81e22f87997a0ddc55d4ae69
BLAKE2b-256 67689561da3428c7964cf2e00fa19884939339375e8e53f5deac7fec56afb2bd

File details

Details for the file jax_finufft-0.1.0rc2-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 66cf11e6f36727b0ee9a1f457243f3b10b0fe4888029cb29148e26e90b37ec76
MD5 cc5856de33d883ec41efbed4eb538a28
BLAKE2b-256 1ca328211ea386f76af0a0828e19b3f4032266f895568ce3ba9f3a693f342356

File details

Details for the file jax_finufft-0.1.0rc2-cp312-abi3-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp312-abi3-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 5111362c5be681891eeea118056a262c543aee06513cac7cb1a50b3a9780143b
MD5 3c051522bd6df5096b5e0e619e3c5357
BLAKE2b-256 c1ce9c84325f8765b2c9b0b101f5090488d501ee70046709b1d7b510d3dc4eb5

File details

Details for the file jax_finufft-0.1.0rc2-cp312-abi3-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp312-abi3-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 60cadb91c31e45215e5923014bc0deedab25428f973619e1834070b9c2929511
MD5 b0630f398bcc04704068f23cfc2058aa
BLAKE2b-256 2c2074f14159b38428676996a55a7be9bcec9550265c3c333b74a4e09acd57f5

File details

Details for the file jax_finufft-0.1.0rc2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 96d2d87b9b402e34abb51dfba8ddb25bfecb3f5a31a2f520204c61708f1f855e
MD5 603709a8fa620ab5452e727bd6ff3de0
BLAKE2b-256 f9760b821717c219aa5dd31648661197d0a1ecf685483b81fc268a01cb74f30a

File details

Details for the file jax_finufft-0.1.0rc2-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b8ad7d7619b474a3a664528128b29416e2e59ddd11f420b28f0142531b7fd5c8
MD5 575e97c399c0b3b26cac5bad3677b464
BLAKE2b-256 4df4109ae2a60edeca644dc4e0d22b89248152e311c081fb2ee12fedc87593bd

File details

Details for the file jax_finufft-0.1.0rc2-cp311-cp311-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp311-cp311-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 8d9c26da9685df760eda78492167542fb0c0926fcf28b55c03257d4d52f283d4
MD5 00662976210c8b2b5b0c4b6b1bfd0fe5
BLAKE2b-256 93ca870c504548b9b35bc7acc0a36017415092f0044a65ebc27bb7d1bca24553

File details

Details for the file jax_finufft-0.1.0rc2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 547e237973432ac25c32177e94389ede9963065f06a9a846838e389c548c1016
MD5 55bea3a5cb284af189ca130a2774cddb
BLAKE2b-256 29246a0e9d5e543a017b9f63d7062fefd48159af3d01dd35a1fa2867a24820b2

File details

Details for the file jax_finufft-0.1.0rc2-cp310-cp310-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 577adbbb1cb474ea9445d0734ce570b4786d2ad3d62711e9c8acf4e3e80bc083
MD5 655899c7b9eb2d30d20523eb72b0a5c8
BLAKE2b-256 305dbe41a176b2ef1d06bc96e6db599b274f108fe4fe0a01d600290ff61b0ee2

File details

Details for the file jax_finufft-0.1.0rc2-cp310-cp310-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp310-cp310-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 7645ff1f04ebc0c8cd258a3199ab737a90ade987f395d08a55aa4bcfdb2e3320
MD5 1824780c1f523aaecc74f695409cb077
BLAKE2b-256 5d95621aaead958896baec5f83339c49085c46bcccc4808248a8a573ba418458

File details

Details for the file jax_finufft-0.1.0rc2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 007da54232ccddd5e586b5f3ca84fadd0c454eaa3e89f3ae3e5d45b480e0ccd5
MD5 8ad77b7bc0149f830fccb2c83791e4b8
BLAKE2b-256 e2e41b6d1c59cb4089be90049d35c995fca9433facb29e00117208c31baf7c71

File details

Details for the file jax_finufft-0.1.0rc2-cp39-cp39-macosx_11_0_arm64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp39-cp39-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 3339e6476883703639fe0688c6d9b9fb934b2a3c99513e2344ddfe2c7ea4ef51
MD5 59336184499547fa124e10182ad740de
BLAKE2b-256 67b709044651c28843d6b877c7f8f3d9a25e6b524eaad4d352be38b0e0bf7276

File details

Details for the file jax_finufft-0.1.0rc2-cp39-cp39-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_finufft-0.1.0rc2-cp39-cp39-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 559132fa4ecaac218474009c5cab64ff2dc262c423253b2fb5b2d03492c90f70
MD5 9649ac7ec0c4f3533630eb53c367edeb
BLAKE2b-256 fca1c4ba21327474a9cebf8143032c0e749118aed50d3204a5ca902f71dd987a
