Quantities in JAX
unxt
Unitful Quantities in JAX
Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
Yes, it supports auto-differentiation (grad, jacobian, hessian) and vectorization (vmap, etc.).
Installation
pip install unxt
Documentation
Coming soon. In the meantime, if you've used astropy.units, then unxt should be familiar!
Quick example
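import jax
import jax.numpy as jnp
from unxt import Quantity
# JAX defaults to 32-bit; enable 64-bit to match the float64 outputs shown below
jax.config.update("jax_enable_x64", True)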
x = Quantity(jnp.arange(1, 5, dtype=float), "kpc")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='kpc')
# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')
# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')
y = Quantity(jnp.arange(4, 8, dtype=float), "Gyr")
print(x / y)
# Quantity['speed'](Array([0.25 , 0.4 , 0.5 , 0.57142857], dtype=float64), unit='kpc / Gyr')
# Exponentiation
print(x**2)
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='kpc2')
# Unit Checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'Gyr' (time) and 'kpc' (length) are not convertible
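To make the unit check above concrete, here is a minimal pure-Python sketch of the idea: tag each value with a unit, look up the unit's dimension, and refuse to add values whose dimensions differ. `ToyQuantity` and the `UNITS` table are hypothetical names invented for this illustration; this is not how unxt is implemented, only a toy model of the behaviour shown in the example.

```python
from dataclasses import dataclass

# Hypothetical table: unit string -> (dimension, factor to a base unit).
# Illustrative values only; nothing here is taken from unxt.
UNITS = {
    "m": ("length", 1.0),
    "km": ("length", 1000.0),
    "s": ("time", 1.0),
    "min": ("time", 60.0),
}


@dataclass(frozen=True)
class ToyQuantity:
    """A scalar tagged with a unit string (toy model, not unxt.Quantity)."""

    value: float
    unit: str

    @property
    def dimension(self) -> str:
        return UNITS[self.unit][0]

    def __add__(self, other: "ToyQuantity") -> "ToyQuantity":
        # Refuse to add quantities whose dimensions differ.
        if self.dimension != other.dimension:
            raise ValueError(
                f"'{other.unit}' ({other.dimension}) and "
                f"'{self.unit}' ({self.dimension}) are not convertible"
            )
        # Convert the right operand into the left operand's unit, then add.
        factor = UNITS[other.unit][1] / UNITS[self.unit][1]
        return ToyQuantity(self.value + other.value * factor, self.unit)


total = ToyQuantity(2.0, "km") + ToyQuantity(500.0, "m")
print(total)  # ToyQuantity(value=2.5, unit='km')

try:
    ToyQuantity(2.0, "km") + ToyQuantity(3.0, "s")
except ValueError as e:
    print(e)  # 's' (time) and 'km' (length) are not convertible
```

In this sketch the result keeps the left operand's unit, which is a design choice of the toy model rather than a statement about unxt.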
unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, a thin quax wrapper around jax that avoids boilerplate wrappers.
from quaxed import grad, vmap
import quaxed.numpy as jnp
print(jnp.square(x))
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='kpc2')
print(jnp.power(x, 3))
# Quantity['volume'](Array([ 1., 8., 27., 64.], dtype=float64), unit='kpc3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='kpc2')
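The gradient above comes back as an area because differentiating a kpc3 output with respect to a kpc input leaves kpc2. The following plain-Python sketch checks both the numbers and the unit bookkeeping without JAX or unxt. It swaps automatic differentiation for a central finite difference, and the `derivative` and `cube` helpers are hypothetical names used only for this illustration.

```python
def derivative(f, x, h=1e-6):
    """Central finite-difference estimate of df/dx at x (not autodiff)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def cube(x):
    return x**3


values = [1.0, 2.0, 3.0, 4.0]  # same magnitudes as x above, in kpc
grads = [derivative(cube, v) for v in values]

# Unit bookkeeping: the output carries kpc**3 and the input carries kpc**1,
# so the derivative carries kpc**(3 - 1).
output_power, input_power = 3, 1
grad_unit = f"kpc{output_power - input_power}"

print([round(g, 3) for g in grads], grad_unit)
# [3.0, 12.0, 27.0, 48.0] kpc2
```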
Since Quantity is parametric, it can do runtime dimension checking!
LengthQuantity = Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')
try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
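One way to picture a parametric class with a runtime dimension check is Python's `__class_getitem__` hook, which lets `Cls["length"]` return a subclass that remembers the required dimension. The sketch below is an assumption-laden toy: `ToyParametricQuantity` and `DIMENSION_OF` are hypothetical, and unxt's real mechanism may differ.

```python
# Hypothetical unit -> dimension table, for this sketch only.
DIMENSION_OF = {"km": "length", "m": "length", "s": "time"}


class ToyParametricQuantity:
    """Toy stand-in for a parametric quantity class (not unxt.Quantity)."""

    _required_dimension = None  # None means "accept any dimension"
    _cache = {}

    def __class_getitem__(cls, dimension):
        # Cls["length"] returns a cached subclass that records which
        # dimension its instances must have.
        if dimension not in cls._cache:
            name = f"{cls.__name__}[{dimension!r}]"
            cls._cache[dimension] = type(
                name, (cls,), {"_required_dimension": dimension}
            )
        return cls._cache[dimension]

    def __init__(self, value, unit):
        actual = DIMENSION_OF[unit]
        required = self._required_dimension
        # The check runs at construction time, i.e. at runtime.
        if required is not None and actual != required:
            raise ValueError(
                f"Physical type mismatch: expected {required}, got {actual}."
            )
        self.value = value
        self.unit = unit


ToyLength = ToyParametricQuantity["length"]
q = ToyLength(2, "km")
print(type(q).__name__, q.value, q.unit)  # ToyParametricQuantity['length'] 2 km

try:
    ToyLength(2, "s")
except ValueError as e:
    print(e)  # Physical type mismatch: expected length, got time.
```

Caching the generated subclasses means `ToyParametricQuantity["length"]` always returns the same class object, so `isinstance` checks against it behave predictably.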
Citation
If you found this library to be useful in academic work, then please cite.
Development
We welcome contributions!