
Dataclasses that behave like numpy arrays (with indexing, slicing, vectorization).


Dataclass Array


DataclassArray are dataclasses that behave like numpy arrays (they can be batched, reshaped, sliced, ...) and are compatible with Jax, TensorFlow, and numpy (torch support is planned).

This reduces boilerplate and improves readability. See the motivating examples section below.

To view an example of dataclass arrays used in practice, see visu3d.

Documentation

Definition

To create a dca.DataclassArray, take a frozen dataclass and:

  • Inherit from dca.DataclassArray
  • Annotate the fields with dataclass_array.typing to specify the inner shape and dtype of the array (see below for static or nested dataclass fields). The array types are aliases of the ones in etils.array_types.
import dataclasses

import dataclass_array as dca
from dataclass_array.typing import FloatArray


@dataclasses.dataclass(frozen=True)
class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']
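
To make the `*batch_shape 3` annotation more concrete, here is a minimal numpy-only sketch of how a shared batch shape could be derived by stripping each field's inner shape. It is an illustration under stated assumptions, not the library's implementation: `infer_batch_shape` and its arguments are hypothetical names, and the actual validation in `dataclass_array` (dtype casting, broadcasting, ...) may differ.

```python
import numpy as np


def infer_batch_shape(fields, inner_shapes):
  """Returns the batch shape shared by all fields (hypothetical helper)."""
  batch_shapes = set()
  for name, array in fields.items():
    inner = tuple(inner_shapes[name])
    split = array.ndim - len(inner)
    # The trailing dimensions must match the declared inner shape.
    if split < 0 or array.shape[split:] != inner:
      raise ValueError(
          f'{name}: expected trailing shape {inner}, got {array.shape}'
      )
    # Whatever is left in front is the batch shape of this field.
    batch_shapes.add(array.shape[:split])
  if len(batch_shapes) != 1:
    raise ValueError(f'Inconsistent batch shapes: {sorted(batch_shapes)}')
  return batch_shapes.pop()


fields = {'pos': np.zeros((4, 2, 3)), 'dir': np.ones((4, 2, 3))}
batch_shape = infer_batch_shape(fields, {'pos': (3,), 'dir': (3,)})
```

In this sketch both fields have shape `(4, 2, 3)` and inner shape `(3,)`, so the helper reports the common leading dimensions as the batch shape.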

Usage

Afterwards, the dataclass can be used as a numpy array:

import jax
import jax.numpy as jnp

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))


ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields still available

# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[jnp.linalg.norm(ray.dir, axis=-1) > 1e-7]  # Boolean mask over the batch

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatible with `jax.tree_util`, `jax.vmap`, ...
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
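
For intuition about what indexing and reshaping a whole dataclass means, here is a toy numpy-only stand-in that applies the same operation to every field while leaving the inner `(3,)` axis untouched. `PlainRay` and `_map` are hypothetical names used only for this sketch; the real `dca.DataclassArray` is more general (nested fields, einops, multiple backends).

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class PlainRay:
  """Toy stand-in for `Ray`; every field has inner shape `(3,)`."""
  pos: np.ndarray
  dir: np.ndarray

  @property
  def shape(self):
    # Batch shape = field shape without the trailing inner dimension.
    return self.pos.shape[:-1]

  def _map(self, fn):
    # Apply `fn` to every field and build a new (immutable) instance.
    fields = dataclasses.fields(self)
    return dataclasses.replace(
        self, **{f.name: fn(getattr(self, f.name)) for f in fields}
    )

  def __getitem__(self, index):
    if not isinstance(index, tuple):
      index = (index,)
    # Append a full slice so that `...` cannot reach the inner axis.
    return self._map(lambda x: x[index + (slice(None),)])

  def reshape(self, shape):
    # Only the batch dimensions change; the inner `(3,)` axis is kept.
    return self._map(lambda x: x.reshape(tuple(shape) + (3,)))


rays = PlainRay(pos=np.zeros((2, 4, 3)), dir=np.ones((2, 4, 3)))
sliced = rays[..., 1:2]    # Slice the last batch axis
flat = rays.reshape((8,))  # Merge the two batch axes
first = rays[0]            # Select the first row of the batch
```

The point of the sketch is that `shape` refers to the batch dimensions only, so `rays[..., 1:2]` slices the last batch axis rather than the x/y/z components.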

A DataclassArray has 2 types of fields:

  • Array fields: Fields batched like numpy arrays, supporting reshape, slicing, ... They can be xnp.ndarray or nested dca.DataclassArray.
  • Static fields: All other (non-array) fields. They are not modified by reshaping, slicing, ... Static fields are also ignored by jax.tree_util.tree_map.
@dataclasses.dataclass(frozen=True)
class MyArray(dca.DataclassArray):
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None, None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static field (everything not defined as above)
  static0: float
  static1: np.array
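
The array/static distinction can be pictured with the following numpy-only sketch, in which a transformation touches array fields and passes static fields through unchanged. `Labeled` and `map_array_fields` are hypothetical names. As a simplifying assumption, this sketch detects array fields with a runtime `isinstance` check, whereas the library decides from the field annotations as described above.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class Labeled:
  values: np.ndarray  # Treated as an array field in this sketch
  name: str           # Treated as a static field


def map_array_fields(obj, fn):
  """Applies `fn` to numpy-array fields only (hypothetical helper)."""
  updates = {
      f.name: fn(getattr(obj, f.name))
      for f in dataclasses.fields(obj)
      if isinstance(getattr(obj, f.name), np.ndarray)
  }
  # Fields not listed in `updates` (the static ones) are copied as-is.
  return dataclasses.replace(obj, **updates)


x = Labeled(values=np.zeros((2, 3)), name='demo')
y = map_array_fields(x, lambda a: a + 1)
```

Here `y.values` is incremented while `y.name` is carried over untouched, which mirrors how static fields are left alone by reshaping and tree mapping.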

Vectorization

@dca.vectorize_method allows your dataclass methods to automatically support batching:

  1. Implement method as if self.shape == ()
  2. Decorate the method with dca.vectorize_method
@dataclasses.dataclass(frozen=True)
class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution: tuple[int, int]  # Static field: (h, w)

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method`, the shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the ray as if there was only a single camera
    return Ray(pos=..., dir=...)

Afterward, we can generate rays for multiple cameras batched together:

cams = Camera(K=K, resolution=(h, w))  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
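
Conceptually, batching a method written for a single element amounts to flattening the batch dimensions, calling the method once per element, and restoring the batch shape on the outputs. The numpy-only sketch below shows that idea with an explicit Python loop for clarity. `vectorize_over_batch` and `inner_ndim` are hypothetical names, and this is not how `dca.vectorize_method` is implemented.

```python
import numpy as np


def vectorize_over_batch(fn, array, inner_ndim):
  """Calls `fn` on each batch element and restores the batch shape.

  Hypothetical helper: `fn` receives one element whose shape is the inner
  shape, exactly as if the batch shape were `()`.
  """
  split = array.ndim - inner_ndim
  batch_shape, inner_shape = array.shape[:split], array.shape[split:]
  flat = array.reshape((-1,) + inner_shape)  # (*batch, ...) -> (n, ...)
  outputs = np.stack([fn(x) for x in flat])  # (n, *out_shape)
  return outputs.reshape(batch_shape + outputs.shape[1:])


# A (2, 5) batch of 4x4 identity matrices.
K = np.tile(np.eye(4), (2, 5, 1, 1))
traces = vectorize_over_batch(np.trace, K, inner_ndim=2)
```

`np.trace` only ever sees a single `(4, 4)` matrix, yet the result keeps the `(2, 5)` batch shape of the input.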

@dca.vectorize_method is similar to jax.vmap but:

  • It only works on dca.DataclassArray methods
  • Instead of vectorizing a single axis, @dca.vectorize_method vectorizes over *self.shape (not just self.shape[0]). This is as if vmap were applied to self.flatten()
  • When there are multiple arguments, axes of dimension 1 are broadcast.

For example, with __matmul__(self, x: T) -> T:

() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
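
The shape rules listed above can be reproduced with numpy's standard broadcasting on the leading dimensions. This is a sketch under the assumption that the argument's first `len(self.shape)` dimensions are matched against `self.shape` and the remaining `*x` dimensions are passed through. `broadcast_output_shape` is a hypothetical helper, not a library function.

```python
import numpy as np


def broadcast_output_shape(self_shape, arg_shape):
  """Returns the output shape of `self @ arg` under the rules above."""
  self_shape, arg_shape = tuple(self_shape), tuple(arg_shape)
  n = len(self_shape)
  if len(arg_shape) < n:
    raise ValueError(f'{arg_shape} has fewer than {n} batch dimensions')
  # Dimensions of size 1 are broadcast; other mismatches raise ValueError.
  batch = np.broadcast_shapes(self_shape, arg_shape[:n])
  return batch + arg_shape[n:]


out0 = broadcast_output_shape((), (7,))
out1 = broadcast_output_shape((5,), (1, 7))
out2 = broadcast_output_shape((1, 4, 6), (5, 1, 1, 7))
```

With `*x == (7,)`, the three calls correspond to the `() @ (*x,)`, `(b,) @ (1, *x)` and `(1, h, w) @ (b, 1, 1, *x)` rows, while incompatible batch sizes such as `(2,)` and `(3, 7)` raise a `ValueError`.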

To test on Colab, see the visu3d dataclass Colab tutorial.

Motivating examples

dca.DataclassArray improves readability by simplifying common patterns:

  • Reshaping all fields of a dataclass:

    Before (rays is a simple dataclass):

    num_rays = math.prod(rays.origins.shape[:-1])
    rays = jax.tree_util.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
    

    After (rays is DataclassArray):

    rays = rays.flatten()  # (b, h, w) -> (b*h*w,)
    
  • Rendering a video:

    Before (cams: list[Camera]):

    img = cams[0].render(scene)
    imgs = np.stack([cam.render(scene) for cam in cams[::2]])
    imgs = np.stack([cam.render(scene) for cam in cams])
    

    After (cams: Camera with cams.shape == (num_cams,)):

    img = cams[0].render(scene)  # Render only the first camera (to debug)
    imgs = cams[::2].render(scene)  # Render every other frame (for quicker iteration)
    imgs = cams.render(scene)  # Render all cameras at once
    

Installation

pip install dataclass_array

This is not an official Google product
