JAX + Units
This module provides an interface between JAX and Pint to allow JAX to support operations with units. The propagation of units happens at trace time, so jitted functions should see no runtime overhead from the unit handling. This library is experimental, so expect some sharp edges.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> import jpu
>>>
>>> u = jpu.UnitRegistry()
>>>
>>> @jax.jit
... def add_two_lengths(a, b):
...     return a + b
...
>>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)
<Quantity([3.045 3.012 3.039], 'meter')>
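To make the trace-time claim concrete, here is a toy stand-in written in plain JAX. It is not jpu's code (jpu delegates the unit logic to Pint); the ToyQuantity class and the _TO_METERS table are hypothetical. It assumes one way this can work: the unit is stored as static pytree auxiliary data, so the conversion runs in Python while JAX traces the function and only the magnitudes reach the compiled program. Like Pint, it expresses a sum in the units of the left operand, which is why the result above is in meters (3 m + 4.5 cm = 3 m + 0.045 m = 3.045 m).

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical toy unit table: scale factor from each unit to meters.
_TO_METERS = {"m": 1.0, "cm": 0.01, "km": 1000.0}


@tree_util.register_pytree_node_class
class ToyQuantity:
    """Toy quantity: the magnitude is a traced leaf, the unit is static."""

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def tree_flatten(self):
        # Only the magnitude becomes a leaf that JAX traces; the unit string
        # is auxiliary data and must be hashable.
        return (self.magnitude,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)

    def __add__(self, other):
        # Runs in Python at trace time: the factor is a plain float constant,
        # so the compiled program only sees a multiply and an add.
        factor = _TO_METERS[other.unit] / _TO_METERS[self.unit]
        return ToyQuantity(self.magnitude + other.magnitude * factor, self.unit)


@jax.jit
def add_two_lengths(a, b):
    return a + b


a = ToyQuantity(jnp.asarray(3.0), "m")
b = ToyQuantity(jnp.array([4.5, 1.2, 3.9]), "cm")

result = add_two_lengths(a, b)
print(result.unit, result.magnitude)  # result is in the left operand's unit

# The traced program takes only the two magnitudes as inputs; the unit
# strings were consumed in Python during tracing.
jaxpr = jax.make_jaxpr(add_two_lengths)(a, b)
print(jaxpr)
```

Because the unit is part of the pytree structure rather than a traced value, calling the jitted function with different units triggers a retrace, while repeated calls with the same units reuse the compiled program.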
Installation
To install, use pip:
python -m pip install jpu
The only dependencies are jax and pint, and these will also be installed if they are not already in your environment. Take a look at the JAX docs for more information about installing JAX on different systems.
Usage
Here is a slightly more complete example:
>>> import jax
>>> import numpy as np
>>> from jpu import UnitRegistry, numpy as jnpu
>>>
>>> u = UnitRegistry()
>>>
>>> @jax.jit
... def projectile_motion(v_init, theta, time, g=u.standard_gravity):
...     """Compute the motion of a projectile with support for units"""
...     x = v_init * time * jnpu.cos(theta)
...     y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time)
...     return x.to(u.m), y.to(u.m)
...
>>> x, y = projectile_motion(
...     5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s
... )
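For comparison, here is the same physics in plain JAX with the unit conversions done by hand, which is the bookkeeping jpu automates. The formulas are the ones in projectile_motion above, x = v0 t cos θ and y = v0 t sin θ − ½ g t². This sketch assumes every input is already in SI units and uses the conventional value of standard gravity, 9.80665 m/s², which is how Pint defines standard_gravity; the function name projectile_motion_si and the constant G_STANDARD are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Conventional standard gravity in m/s^2 (assumed SI input convention).
G_STANDARD = 9.80665


@jax.jit
def projectile_motion_si(v_init, theta, time, g=G_STANDARD):
    """Projectile motion with all arguments in SI units.

    v_init in m/s, theta in radians, time in s; returns (x, y) in meters.
    """
    x = v_init * time * jnp.cos(theta)
    y = v_init * time * jnp.sin(theta) - 0.5 * g * jnp.square(time)
    return x, y


# Manual conversions that the unit-aware version performs for you.
v_init = 5.0 * 1000.0 / 3600.0  # 5 km/h -> m/s
theta = np.deg2rad(60.0)        # 60 deg -> rad
time = np.linspace(0, 1, 50)    # already in seconds

x, y = projectile_motion_si(v_init, theta, time)
print(x[-1], y[-1])
```

At t = 1 s this gives x ≈ 0.694 m and y ≈ −3.70 m. Forgetting either conversion (for example passing 60 instead of its value in radians) would silently produce wrong numbers here, whereas the unit-aware version handles it for you.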
Technical details & limitations
The most significant limitation of this library is that users must call jpu.numpy functions, rather than the jax.numpy interface, when working with "quantities" that carry units. This is because JAX does not (yet?) provide a general interface for dispatching ufuncs on custom array classes. I have played around with the undocumented __jax_array__ interface, but it's not flexible enough, and it isn't currently compatible with Pytree objects.
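The sketch below illustrates the dispatch problem with a hypothetical ToyQuantity pytree (not jpu's class, and using plain strings as units for simplicity). jax.numpy functions only accept arrays and scalars, so they raise a TypeError on the custom class. The practical workaround is a separate namespace of wrappers that unpack the magnitude, handle the units, and call jax.numpy themselves, which is the role jpu.numpy plays; the cos and square wrappers here are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyQuantity:
    """Hypothetical unit-carrying array: traced magnitude, static unit."""

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def tree_flatten(self):
        return (self.magnitude,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


angle = ToyQuantity(jnp.array([0.0, jnp.pi / 3]), "radian")

# jax.numpy has no hook for dispatching on this class, so it rejects it.
try:
    jnp.cos(angle)
    jnp_rejected = False
except TypeError:
    jnp_rejected = True
print("jax.numpy rejected ToyQuantity:", jnp_rejected)


# Wrapper-namespace workaround: each function handles the units itself.
def cos(q):
    # The unit check runs in Python, at trace time when jitted.
    if q.unit != "radian":
        raise ValueError(f"cos expects radians, got {q.unit!r}")
    return jnp.cos(q.magnitude)  # dimensionless result: plain array


def square(q):
    # Unit-transforming operation: the unit string is squared alongside.
    return ToyQuantity(jnp.square(q.magnitude), f"({q.unit})**2")


print(jax.jit(cos)(angle))
print(square(angle).unit)
```

The cost of this design is that every supported function needs its own wrapper, which is why coverage grows one function at a time.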
So far, only a subset of the numpy/jax.numpy interface is implemented. Pull requests adding broader support (including submodules) would be welcome!
File details
Details for the file jpu-0.0.4.tar.gz.
File metadata
- Download URL: jpu-0.0.4.tar.gz
- Size: 19.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 03314c23504bec25bf95142d3684da5d1e962e2ead7067cd96819f5856bb6d57
MD5 | 28ef83e5e67571bc5baf18bfeada07ce
BLAKE2b-256 | 7214e28417860c57092f62ff4bd56e5d3f9284ce9488f2f8893fa0db0da7ec50
File details
Details for the file jpu-0.0.4-py3-none-any.whl.
File metadata
- Download URL: jpu-0.0.4-py3-none-any.whl
- Size: 15.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 805669c2544130863f90b3031ac62d9d1eaa5081dc349c055799cc6e638750a9
MD5 | eb3c06b8ea3c2fcf4f99f78ed069e379
BLAKE2b-256 | c4488cf91fc33e11b340fc21e0bcadee1c7cfb4113a275bc972dde1a2a28eb76