
JAX AI Stack

JAX is a Python package for array-oriented computation and program transformation. Built around it is a growing ecosystem of packages for specialized numerical computing across a range of domains; an up-to-date list of such projects can be found at Awesome JAX.

Though JAX is often compared to neural network libraries like PyTorch, the JAX core package itself contains very little that is specific to neural network models. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package: this helps drive innovation as researchers and other users explore what is possible.

Within this larger, distributed ecosystem, there are a number of projects that Google researchers and engineers have found useful for implementing and deploying the models behind generative AI tools such as Imagen and Gemini. The JAX AI stack serves as a single point of entry for this suite of libraries, so you can install and begin using many of the same open source packages that Google developers use in their everyday work.

To get started with the JAX AI stack, see Getting started with JAX. This is still a work in progress; please check back for more documentation and tutorials in the coming weeks.

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins particular versions of the component projects, which the integration tests in this repository verify work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations such as jit, vmap, and grad.
  • flax: build neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
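The core transformations compose with one another. The following is a minimal sketch rather than an official example: the toy linear model, the `loss` function, and the data are illustrative assumptions. It layers `jax.grad`, `jax.vmap`, and `jax.jit` to compute per-example gradients, then uses `ml_dtypes` (which the stack also installs) to create a `bfloat16` NumPy array:

```python
import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np


def loss(w, x, y):
    """Squared error of a toy linear model for a single example (illustrative)."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2


# grad: differentiate loss with respect to its first argument, w.
loss_grad = jax.grad(loss)

# vmap: map the per-example gradient over a batch (axis 0 of x and y),
# sharing the same weights w across every example (in_axes=None for w).
batched_grad = jax.vmap(loss_grad, in_axes=(None, 0, 0))

# jit: trace the batched gradient and compile it with XLA on first call.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = fast_batched_grad(w, xs, ys)
print(grads.shape)  # (3, 2)

# ml_dtypes: NumPy-compatible low-precision dtypes such as bfloat16.
bf16 = np.array([1.0, 2.5, 3.0], dtype=ml_dtypes.bfloat16)
print(bf16.dtype)
```

Because `vmap` is given `in_axes=(None, 0, 0)`, the weights `w` are shared across the batch while `x` and `y` are sliced along their first axis, so each row of `grads` is the gradient for one example.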

Optional packages

Additionally, there are optional packages you can install with pip extras. The following command:

pip install jax-ai-stack[grain]

will install a compatible version of the grain data loader (currently Linux-only).
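Since the grain extra is currently Linux-only, a setup script may want to request it conditionally. A minimal sketch, assuming a hypothetical helper named `requirement_spec` (not part of jax-ai-stack) that builds the pip requirement string from the extras described in this section:

```python
import sys


def requirement_spec(want_grain: bool = True, want_tfds: bool = False,
                     platform: str = sys.platform) -> str:
    """Build a pip requirement string for jax-ai-stack with optional extras.

    Hypothetical helper: the grain extra is skipped unless the platform is
    Linux, because the grain data loader is currently Linux-only.
    """
    extras = []
    if want_grain and platform.startswith("linux"):
        extras.append("grain")
    if want_tfds:
        extras.append("tfds")
    if not extras:
        return "jax-ai-stack"
    # pip accepts several extras inside one bracket, separated by commas.
    return "jax-ai-stack[" + ",".join(extras) + "]"


# Print the command instead of running it; pass the list to
# subprocess.check_call to actually perform the installation.
cmd = [sys.executable, "-m", "pip", "install", requirement_spec(want_tfds=True)]
print(" ".join(cmd))
```

In shells such as zsh, square brackets are glob characters, so quote the requirement when typing it directly, for example `pip install "jax-ai-stack[grain]"`.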

Similarly, the following command:

pip install jax-ai-stack[tfds]

will install compatible versions of tensorflow and tensorflow-datasets.
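Because the stack pins specific component versions, it can help to confirm what actually ended up in an environment. A minimal sketch using only the standard library's `importlib.metadata`; the distribution names are assumptions based on the packages listed above (in particular, orbax is assumed to be published on PyPI as `orbax-checkpoint`, and the grain extra as `grain`):

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the stack, its components, and extras.
COMPONENTS = [
    "jax-ai-stack", "jax", "flax", "ml_dtypes", "optax",
    "orbax-checkpoint", "grain", "tensorflow", "tensorflow-datasets",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


report = installed_versions(COMPONENTS)
for name, ver in report.items():
    print(f"{name:22} {ver or 'not installed'}")
```

Components that are absent, such as optional extras that were not requested, are reported as `not installed` rather than raising an error.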
