Quantities in JAX

unxt

Unitful Quantities in JAX

unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.

Yes, it supports automatic differentiation (grad, jacobian, hessian) and vectorization (vmap, etc.).

Installation

pip install unxt

Documentation

Coming soon. In the meantime, if you've used astropy.units, then unxt should be familiar!

Quick example

# The float64 outputs shown below assume JAX's 64-bit mode is enabled.
import jax.numpy as jnp
from unxt import Quantity

x = Quantity(jnp.arange(1, 5, dtype=float), "kpc")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='kpc')

# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')

# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')

y = Quantity(jnp.arange(4, 8, dtype=float), "Gyr")

print(x / y)
# Quantity['speed'](Array([0.25      , 0.4       , 0.5       , 0.57142857], dtype=float64), unit='kpc / Gyr')

# Exponentiation
print(x**2)
# Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='kpc2')

# Unit Checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'Gyr' (time) and 'kpc' (length) are not convertible
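
The unit checking shown above can be pictured with a small, dependency-free sketch. This is a hypothetical toy model, not unxt's implementation: the DimQuantity class and the LENGTH and TIME constants are invented for illustration, and a dimension is assumed to be just a pair of integer exponents (length, time).

```python
from dataclasses import dataclass

# Hypothetical toy model (not unxt's implementation): a dimension is a
# pair of integer exponents, (length, time).
LENGTH = (1, 0)
TIME = (0, 1)


@dataclass(frozen=True)
class DimQuantity:
    value: float
    dim: tuple

    def __add__(self, other: "DimQuantity") -> "DimQuantity":
        # Addition is only defined when both operands share a dimension.
        if self.dim != other.dim:
            raise ValueError(f"incompatible dimensions: {self.dim} and {other.dim}")
        return DimQuantity(self.value + other.value, self.dim)

    def __truediv__(self, other: "DimQuantity") -> "DimQuantity":
        # Division subtracts exponents, so length / time becomes a speed.
        dim = tuple(a - b for a, b in zip(self.dim, other.dim))
        return DimQuantity(self.value / other.value, dim)

    def __pow__(self, n: int) -> "DimQuantity":
        # Exponentiation scales every exponent, so length**2 is an area.
        dim = tuple(a * n for a in self.dim)
        return DimQuantity(self.value**n, dim)
```

In this sketch, dividing a LENGTH quantity by a TIME quantity yields the exponents (1, -1), while adding the two raises a ValueError, which mirrors the behaviour of x / y and x + y in the example above.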

unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, which pre-wraps JAX functions with quax so that you do not have to write the wrappers yourself.

from quaxed import grad, vmap
import quaxed.numpy as jnp

print(jnp.square(x))
# Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='kpc2')

print(jnp.power(x, 3))
# Quantity['volume'](Array([ 1.,  8., 27., 64.], dtype=float64), unit='kpc3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='kpc2')

Since Quantity is parametric, it can do runtime dimension checking!

LengthQuantity = Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
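
One way to picture a parametric class with runtime dimension checking is the plain-Python sketch below. It is a hypothetical toy, not how unxt implements Quantity, and unxt's actual mechanism may differ: the ParamQuantity class and the UNIT_DIMENSIONS table are invented for illustration.

```python
# Hypothetical toy model (not unxt's implementation).
# Maps a few unit strings to the name of their physical type.
UNIT_DIMENSIONS = {"km": "length", "kpc": "length", "s": "time", "Gyr": "time"}


class ParamQuantity:
    # None means "accept any dimension".
    expected_dimension = None

    def __class_getitem__(cls, dimension: str):
        # Subscripting the class builds a subclass that remembers the
        # requested dimension.
        name = f"{cls.__name__}[{dimension!r}]"
        return type(name, (cls,), {"expected_dimension": dimension})

    def __init__(self, value, unit: str):
        actual = UNIT_DIMENSIONS[unit]
        expected = self.expected_dimension
        # The check runs at construction time, i.e. at runtime.
        if expected is not None and actual != expected:
            raise ValueError(
                f"Physical type mismatch: expected {expected}, got {actual}."
            )
        self.value = value
        self.unit = unit


ToyLength = ParamQuantity["length"]
```

Here ToyLength(2, "km") constructs normally, while ToyLength(2, "s") raises a ValueError. The unparametrized ParamQuantity accepts any unit listed in the table.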

Citation

If you found this library useful in academic work, please cite it.

Development

We welcome contributions!
