JAX AI Stack

JAX is a Python package for array-oriented computation and program transformation. Built around it is a growing ecosystem of packages for specialized numerical computing across a range of domains; an up-to-date list of such projects can be found at Awesome JAX.

Though JAX is often compared to neural network libraries like PyTorch, the core JAX package itself contains very little that is specific to neural network models. Instead, JAX encourages modularity: domain-specific libraries are developed separately from the core package, which helps drive innovation as researchers and other users explore what is possible.
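To make the modularity point concrete, here is a minimal sketch (not taken from any of the stack's libraries) of a one-layer linear model trained with nothing but `jax` and `jax.numpy`. The parameter layout, the toy data, and the learning rate are illustrative assumptions; libraries such as flax and optax provide fuller versions of these same building blocks.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative choice, not a recommended default

def init_params(key, in_dim, out_dim):
    # Parameters are an ordinary pytree: a dict of arrays.
    return {
        "w": 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }

def predict(params, x):
    # A "layer" is just a function of parameters and inputs.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error over the batch.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # jax.grad differentiates with respect to the first argument (the pytree).
    grads = jax.grad(loss_fn)(params, x, y)
    # Plain gradient descent, applied leaf by leaf with tree_map.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Toy regression data: y = x @ true_w + 0.3 (noise-free, for illustration).
x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_w + 0.3

params = init_params(jax.random.PRNGKey(1), 3, 1)
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final_loss = loss_fn(params, x, y)
print(f"loss: {float(initial_loss):.4f} -> {float(final_loss):.6f}")
```

Because the parameters are a plain pytree and the update is a pure function, swapping in a different model or optimizer only means replacing these functions.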

Within this larger, distributed ecosystem, there are a number of projects that Google researchers and engineers have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more. The JAX AI stack serves as a single point of entry for this suite of libraries, so you can install and begin using many of the same open source packages that Google developers use in their everyday work.

To get started with the JAX AI stack, you can check out Getting started with JAX. This documentation is still a work in progress; please check back for more documentation and tutorials in the coming weeks.

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins specific versions of the component projects, which the integration tests in this repository have verified to work correctly together. Packages include:
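To see which versions were actually installed, you can query package metadata with the standard library. This is a minimal sketch; the distribution names below are assumptions (in particular, orbax's checkpointing library is assumed to be published as `orbax-checkpoint`), so adjust them to match `pip list` in your environment.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names are assumptions about how the components are published.
COMPONENTS = ["jax-ai-stack", "jax", "flax", "ml_dtypes", "optax", "orbax-checkpoint"]

def installed_versions(names=COMPONENTS):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:20} {ver or 'not installed'}")
```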

  • JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
  • flax: build neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
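As a quick illustration of the transformations named in the JAX entry, the following sketch composes `jax.grad`, `jax.vmap`, and `jax.jit` on an arbitrary example function, then creates a NumPy array with the `bfloat16` dtype provided by ml_dtypes (which JAX itself depends on). The function `f` is only an example.

```python
import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np

def f(x):
    # A scalar function of a scalar input: f(x) = x**3 + 2x.
    return x ** 3 + 2.0 * x

df = jax.grad(f)                       # derivative: 3x**2 + 2
batched_df = jax.vmap(df)              # evaluate df over a batch of inputs
fast_batched_df = jax.jit(batched_df)  # compile the batched derivative with XLA

xs = jnp.array([0.0, 1.0, 2.0])
derivs = fast_batched_df(xs)  # 3x**2 + 2 at 0, 1, 2 is 2, 5, 14
print(derivs)

# ml_dtypes adds machine-learning dtypes such as bfloat16 to NumPy.
half = np.array([1.0, 2.5], dtype=ml_dtypes.bfloat16)
print(half.dtype)
```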

Optional packages

Additionally, there are optional packages you can install with pip extras. The following command:

pip install jax-ai-stack[grain]

will install a compatible version of the grain data loader (currently Linux-only).

Similarly, the following command:

pip install jax-ai-stack[tfds]

will install a compatible version of tensorflow and tensorflow-datasets.
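Extras can also be combined in one command using standard pip syntax, for example `pip install "jax-ai-stack[grain,tfds]"`; quoting the requirement keeps shells such as zsh from interpreting the brackets. Because grain is Linux-only, a setup script might request it conditionally. The sketch below assumes pip is invoked for the running interpreter via `sys.executable -m pip`; the helper names `stack_requirement` and `install` are hypothetical.

```python
import subprocess
import sys

def stack_requirement(platform_name=sys.platform, with_tfds=True):
    """Build a jax-ai-stack requirement string with platform-appropriate extras."""
    extras = []
    # grain is currently Linux-only, so only request it on Linux.
    if platform_name.startswith("linux"):
        extras.append("grain")
    if with_tfds:
        extras.append("tfds")
    return f"jax-ai-stack[{','.join(extras)}]" if extras else "jax-ai-stack"

def install(requirement):
    # Run pip for the current interpreter; passing a list avoids shell globbing.
    subprocess.check_call([sys.executable, "-m", "pip", "install", requirement])
```

Calling `install(stack_requirement())` would then request `grain` only on Linux.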


Download files

Source Distribution

jax_ai_stack-2024.10.1.tar.gz (8.2 kB)

Built Distribution

jax_ai_stack-2024.10.1-py3-none-any.whl (11.0 kB, Python 3)

File details

Details for the file jax_ai_stack-2024.10.1.tar.gz.

File metadata

  • Download URL: jax_ai_stack-2024.10.1.tar.gz
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.6

File hashes

Hashes for jax_ai_stack-2024.10.1.tar.gz
  • SHA256: 3b27bc272fcc7c2e33d449289af204bf4c7a30caf585bf7601c0693cf727bd74
  • MD5: 8eb3c0b548c6a99e0c4c485f29d294a4
  • BLAKE2b-256: aedf651b4074d9af3b5f143347e8cba00b8fca9aa1a4e10e2a7dc309ea6a3e35
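To check a downloaded file against the published digest, you can recompute its SHA256 locally. This is a minimal sketch using only `hashlib`; the local file name is an assumption about where you saved the source distribution.

```python
import hashlib
from pathlib import Path

# Published SHA256 digest for jax_ai_stack-2024.10.1.tar.gz (listed above).
EXPECTED_SDIST_SHA256 = "3b27bc272fcc7c2e33d449289af204bf4c7a30caf585bf7601c0693cf727bd74"

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected_hex):
    """True if the file's SHA256 matches the expected digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()

if __name__ == "__main__":
    sdist = Path("jax_ai_stack-2024.10.1.tar.gz")  # assumed download location
    if sdist.exists():
        print("OK" if verify(sdist, EXPECTED_SDIST_SHA256) else "MISMATCH")
```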

File details

Details for the file jax_ai_stack-2024.10.1-py3-none-any.whl.

File hashes

Hashes for jax_ai_stack-2024.10.1-py3-none-any.whl
  • SHA256: 81bfa0df88f448d77d4fb48103edb0ba0dca1e71241495f27240e5f90150f608
  • MD5: 28dda19e3e48491b46fa9e3c4ec918a4
  • BLAKE2b-256: c4e243b373d19d07389938740859c01d72a83024847bf23ca248c3f99c8a59f6
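pip can also enforce these digests at install time through its hash-checking mode: each requirement is pinned with `==` and annotated with one or more `--hash=sha256:...` options, and `pip install --require-hashes -r requirements.txt` rejects any download that does not match. In this mode pip expects every requirement, including transitive dependencies such as jax and flax, to be pinned and hashed as well. The sketch below formats such an entry for jax-ai-stack from the two SHA256 digests listed for the sdist and the wheel; the helper name `pinned_requirement` is hypothetical.

```python
# SHA256 digests published for jax-ai-stack 2024.10.1.
DIGESTS = {
    "jax_ai_stack-2024.10.1.tar.gz":
        "3b27bc272fcc7c2e33d449289af204bf4c7a30caf585bf7601c0693cf727bd74",
    "jax_ai_stack-2024.10.1-py3-none-any.whl":
        "81bfa0df88f448d77d4fb48103edb0ba0dca1e71241495f27240e5f90150f608",
}

def pinned_requirement(name, version, sha256_digests):
    """Format a requirements.txt entry for pip's hash-checking mode."""
    lines = [f"{name}=={version}"]
    lines += [f"    --hash=sha256:{digest}" for digest in sha256_digests]
    # Backslash continuations let pip read the entry as one logical line.
    return " \\\n".join(lines)

if __name__ == "__main__":
    print(pinned_requirement("jax-ai-stack", "2024.10.1", DIGESTS.values()))
```

Listing both digests lets pip accept either the wheel or the source distribution, whichever it selects for your platform.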
