
JAX AI Stack

JAX is a Python package for array-oriented computation and program transformation. Built around it is a growing ecosystem of packages for specialized numerical computing across a range of domains; an up-to-date list of such projects can be found at Awesome JAX.

Though JAX is often compared to neural network libraries like PyTorch, the JAX core package itself contains very little that is specific to neural network models. Instead, JAX favors a modular design in which domain-specific libraries are developed separately from the core package; this helps drive innovation as researchers and other users explore what is possible.

Within this larger, distributed ecosystem, there are a number of projects that Google researchers and engineers have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more. The JAX AI stack serves as a single point of entry for this suite of libraries, so you can install and begin using many of the same open source packages that Google developers use in their everyday work.

To get started with the JAX AI stack, see Getting started with JAX. This is still a work in progress; please check back for more documentation and tutorials in the coming weeks.

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins specific versions of the component projects, which the integration tests in this repository have verified to work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations such as jit, vmap, and grad.
  • flax: a library for building neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
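The core transformations listed above compose freely. As a minimal sketch (assuming JAX is installed, for example via the stack), the following applies grad, jit, and vmap to a hypothetical least-squares loss; the `loss` function and the toy arrays are illustrative only and are not part of any stack package.

```python
import jax
import jax.numpy as jnp


# Hypothetical example loss: mean squared error of a linear model x @ w.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# grad: differentiate the loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA.
fast_grad_loss = jax.jit(grad_loss)

# vmap: compute one gradient per example by mapping over the leading axis
# of x and y, while sharing the same w across examples (in_axes=None).
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Toy data: 3 examples with 2 features each.
w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([0.0, 0.0, 0.0])

full_grad = fast_grad_loss(w, x, y)        # shape (2,)
batch_grads = per_example_grads(w, x, y)   # shape (3, 2)
```

Because this loss is a mean over examples, averaging the per-example gradients from vmap recovers the full-batch gradient.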

Optional packages

Additionally, there are optional packages you can install with pip extras. The following command:

pip install jax-ai-stack[grain]

will install a compatible version of the grain data loader (currently Linux-only).

Similarly, the following command:

pip install jax-ai-stack[tfds]

will install a compatible version of tensorflow and tensorflow-datasets.
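To see which component versions were actually installed, a small sketch using the standard library's importlib.metadata can help. The distribution names in `DISTRIBUTIONS` are assumptions: jax-ai-stack, jax, flax, ml-dtypes, and optax follow the package list above, while orbax-checkpoint, grain, tensorflow, and tensorflow-datasets are guesses at the PyPI names pulled in for orbax and the optional extras, and may need adjusting.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names; adjust if your environment differs.
DISTRIBUTIONS = [
    "jax-ai-stack",
    "jax",
    "flax",
    "ml-dtypes",
    "optax",
    "orbax-checkpoint",     # assumption: distribution providing orbax checkpointing
    "grain",                # assumption: installed by the [grain] extra
    "tensorflow",           # installed by the [tfds] extra
    "tensorflow-datasets",  # installed by the [tfds] extra
]


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(DISTRIBUTIONS).items():
        print(f"{name:22} {ver or 'not installed'}")
```

Missing optional extras simply show up as "not installed" rather than raising an error.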



Download files


Source Distribution

jax_ai_stack-2024.10.1b1.tar.gz (8.2 kB)

Built Distribution

jax_ai_stack-2024.10.1b1-py3-none-any.whl (11.0 kB)

File details

Details for the file jax_ai_stack-2024.10.1b1.tar.gz.

File metadata

  • Download URL: jax_ai_stack-2024.10.1b1.tar.gz
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.6

File hashes

Hashes for jax_ai_stack-2024.10.1b1.tar.gz

  • SHA256: 631c67bd070d7e72ee39eb60e884f1d3fee0c55e3be5fdc70a60f823f35c6e94
  • MD5: 699c1f7a2c0ddfda2b5eb2fe708dc3b6
  • BLAKE2b-256: 7a8fdea6401ac827ce001ee96b5456cb922f599c0376eb730d43fbf137adfa74


File details

Details for the file jax_ai_stack-2024.10.1b1-py3-none-any.whl.

File hashes

Hashes for jax_ai_stack-2024.10.1b1-py3-none-any.whl

  • SHA256: dbbee8847a81b4925fe42c1e30db8b0a7c6bdf0913f40c2a2cffb5bc27eec5cd
  • MD5: 95d2f859d2723bc6023c81638017393d
  • BLAKE2b-256: b9503d89db8d9408edffc0d854ecb5d2178b855bce42834b05eab6a3562fd7bc
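The published digests can be used to check a downloaded file before installing it. The sketch below uses only Python's hashlib to stream a file through SHA256 and compare the result with the wheel's SHA256 digest listed in this section; the helper names and the command-line usage are hypothetical.

```python
import hashlib
import sys

# SHA256 digest published for jax_ai_stack-2024.10.1b1-py3-none-any.whl.
EXPECTED_WHEEL_SHA256 = "dbbee8847a81b4925fe42c1e30db8b0a7c6bdf0913f40c2a2cffb5bc27eec5cd"


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the lowercase hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def matches_published_digest(path, expected_hex):
    """Return True if the file's SHA256 equals expected_hex (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    # Hypothetical usage: python check_wheel.py path/to/downloaded.whl
    if len(sys.argv) > 1:
        ok = matches_published_digest(sys.argv[1], EXPECTED_WHEEL_SHA256)
        print("digest OK" if ok else "digest MISMATCH")
```

pip can also enforce this itself in hash-checking mode (`pip install --require-hashes -r requirements.txt`), where each requirement carries a `--hash=sha256:...` option; in that mode pip expects every dependency to be pinned with a hash as well.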

