
A JAX + Pint units mashup.

Project description

JAX + Units

Built with JAX and Pint!

This module provides an interface between JAX and Pint that allows JAX to support operations on quantities with units. Units are propagated at trace time, so jitted functions should incur no runtime cost. This library is experimental, so expect some sharp edges.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> import jpu
>>>
>>> u = jpu.UnitRegistry()
>>>
>>> @jax.jit
... def add_two_lengths(a, b):
...     return a + b
...
>>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)
<Quantity([3.045 3.012 3.039], 'meter')>
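To make "propagation at trace time" concrete, here is a minimal sketch in plain JAX of one way such a mechanism can work. The `Quantity` class below is hypothetical and far simpler than jpu's actual implementation (which builds on Pint). It keeps the unit string in the pytree's static auxiliary data, so `jax.jit` traces and compiles only the numeric magnitude, while the unit conversion factor is computed in Python during tracing.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


# Hypothetical, minimal stand-in for a unit-aware array; NOT jpu's actual class.
@tree_util.register_pytree_node_class
class Quantity:
    # Toy conversion table to metres (an assumption for this sketch only).
    _TO_METRE = {"m": 1.0, "cm": 0.01}

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def tree_flatten(self):
        # Children are traced by JAX; the unit string is static aux data.
        return (self.magnitude,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)

    def to(self, unit):
        # The factor is a plain Python float, computed while tracing.
        factor = self._TO_METRE[self.unit] / self._TO_METRE[unit]
        return Quantity(self.magnitude * factor, unit)

    def __add__(self, other):
        # Unit bookkeeping happens in Python; only the magnitudes are traced.
        return Quantity(self.magnitude + other.to(self.unit).magnitude, self.unit)


@jax.jit
def add_two_lengths(a, b):
    return a + b


result = add_two_lengths(Quantity(3.0, "m"), Quantity(jnp.array([4.5, 1.2, 3.9]), "cm"))
print(result.unit, result.magnitude)

# The traced program contains only arithmetic on magnitudes; the unit strings
# never appear as runtime values.
print(jax.make_jaxpr(add_two_lengths)(Quantity(3.0, "m"), Quantity(jnp.ones(3), "cm")))
```

Because the unit lives in the static part of the pytree, calling the jitted function with different units triggers a retrace rather than any runtime unit handling.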

Installation

To install, use pip:

python -m pip install jpu

The only dependencies are jax and pint; pip will install them if they are not already in your environment. See the JAX docs for more information about installing JAX on different systems.

Usage

Here is a slightly more complete example:

>>> import jax
>>> import numpy as np
>>> from jpu import UnitRegistry, numpy as jnpu
>>>
>>> u = UnitRegistry()
>>>
>>> @jax.jit
... def projectile_motion(v_init, theta, time, g=u.standard_gravity):
...     """Compute the motion of a projectile with support for units"""
...     x = v_init * time * jnpu.cos(theta)
...     y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time)
...     return x.to(u.m), y.to(u.m)
...
>>> x, y = projectile_motion(
...     5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s
... )
>>> x[:3]
<Quantity([0.         0.01417234 0.02834467], 'meter')>
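As a sanity check on the printed values, the same numbers can be reproduced by converting the units by hand with plain NumPy: 5 km/h is 5000/3600 ≈ 1.389 m/s, the time step is 1/49 s, and cos 60° = 0.5, so x[1] ≈ 1.389 × 0.0204 × 0.5 ≈ 0.01417 m. The value of standard gravity below is Pint's definition of `standard_gravity` (9.80665 m/s²).

```python
import numpy as np

# Hand-converted inputs, reproducing the projectile example without units.
v_init = 5.0 * 1000.0 / 3600.0   # 5 km/h in m/s
theta = np.deg2rad(60.0)         # 60 degrees in radians
time = np.linspace(0, 1, 50)     # seconds
g = 9.80665                      # standard gravity in m/s^2

x = v_init * time * np.cos(theta)                               # metres
y = v_init * time * np.sin(theta) - 0.5 * g * np.square(time)   # metres

print(x[:3])
```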

Technical details & limitations

The most significant limitation of this library is that users must call jpu.numpy functions, rather than the jax.numpy interface, when operating on "quantities" with units. This is because JAX does not (yet?) provide a general interface for dispatching ufuncs on custom array classes. I have experimented with the undocumented __jax_array__ interface, but it is not flexible enough, and it is not currently compatible with Pytree objects.
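The sketch below illustrates this limitation with a hypothetical, deliberately minimal `Quantity` class (it is not jpu's or Pint's implementation). `jax.numpy` functions reject an object they do not recognize, so a parallel namespace of wrappers, which is the role `jpu.numpy` plays, has to unpack the magnitude and handle the units before calling into `jax.numpy`. The wrapper `cos` and the `_TO_RADIANS` table are illustrative assumptions.

```python
import jax.numpy as jnp
import numpy as np


class Quantity:
    """Hypothetical, minimal unit-carrying array (not jpu's or Pint's class)."""

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit


angle = Quantity(jnp.array([0.0, 60.0, 90.0]), "deg")

# jax.numpy has no hook for dispatching on this custom class, so it rejects it.
try:
    jnp.cos(angle)
except TypeError as err:
    print("jnp.cos failed:", err)

# A unit-aware wrapper has to unpack the magnitude, handle the units, and then
# call the corresponding jax.numpy function.
_TO_RADIANS = {"rad": 1.0, "deg": np.pi / 180.0}


def cos(q):
    """Hypothetical unit-aware cosine: converts the angle to radians first."""
    return jnp.cos(q.magnitude * _TO_RADIANS[q.unit])


print(cos(angle))
```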

Download files


Source Distribution

jpu-0.0.1rc1.tar.gz (14.1 kB)

Built Distribution

jpu-0.0.1rc1-py3-none-any.whl (8.2 kB)

File details

Details for the file jpu-0.0.1rc1.tar.gz.

File metadata

  • Download URL: jpu-0.0.1rc1.tar.gz
  • Upload date:
  • Size: 14.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.13

File hashes

Hashes for jpu-0.0.1rc1.tar.gz
Algorithm Hash digest
SHA256 26ddda223df8e3727e8d4ac4ae2a129707add1fec2ce683cde93c67a9143b994
MD5 4407f3ca27a2a434dd0bf5deaf5eb4a9
BLAKE2b-256 aef84c17a93a165ebbd6afd1e033378dbe4f43ece88560f48687e200e6c76f97


File details

Details for the file jpu-0.0.1rc1-py3-none-any.whl.

File metadata

  • Download URL: jpu-0.0.1rc1-py3-none-any.whl
  • Upload date:
  • Size: 8.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.13

File hashes

Hashes for jpu-0.0.1rc1-py3-none-any.whl
Algorithm Hash digest
SHA256 d67324a05077fc0b082b2deeca255ffbb04d957254b806c9a406785064acfbd1
MD5 8bde1eb74a5c41f7c22177f0853e5193
BLAKE2b-256 2b74f1cd5f022a2c82c2f0623fbbb19b7058c6a13e1f81a55dc4888e3f7b602f

