6 projects
optax
A gradient processing and optimization library in JAX.
chex
Chex: Testing made fun, in JAX!
dm-clrs
The CLRS Algorithmic Reasoning Benchmark.
dm-pix
PIX is an image processing library in JAX, for JAX.
distrax
Distrax: Probability distributions in JAX.
rlax
A library of reinforcement learning building blocks in JAX.