
Penzai: A JAX research toolkit for building, editing, and visualizing neural networks.


Penzai

盆 ("pen", tray) 栽 ("zai", planting) - an ancient Chinese art of forming trees and landscapes in miniature, also called penjing and an ancestor of the Japanese art of bonsai.

Penzai is a JAX library for writing models as legible, functional pytree data structures, along with tools for visualizing, modifying, and analyzing them. Penzai focuses on making it easy to inspect and modify models after they have been trained, which makes it a great choice for research involving reverse-engineering or ablating model components, inspecting and probing internal activations, performing model surgery, debugging architectures, and more. (But if you just want to build and train a model, you can do that too!)

With Penzai, your neural networks could look like this:

Screenshot of the Gemma model in Penzai

Penzai is structured as a collection of modular tools, designed together but each usable independently:

  • A superpowered interactive Python pretty-printer:

    • Treescope (pz.ts): A drop-in replacement for the ordinary IPython/Colab renderer, originally a part of Penzai but now available as a standalone package. It's designed to help understand Penzai models and other deeply-nested JAX pytrees, with built-in support for visualizing arbitrary-dimensional NDArrays.
  • A set of JAX tree and array manipulation utilities:

    • penzai.core.selectors (pz.select): A pytree swiss-army-knife, generalizing JAX's .at[...].set(...) syntax to arbitrary type-driven pytree traversals, and making it easy to do complex rewrites or on-the-fly patching of Penzai models and other data structures.

    • penzai.core.named_axes (pz.nx): A lightweight named axis system which lifts ordinary JAX functions to vectorize over named axes, and allows you to seamlessly switch between named and positional programming styles without having to learn a new array API.

  • A declarative combinator-based neural network library, where models are represented as easy-to-modify data structures:

    • penzai.nn (pz.nn): An alternative to other neural network libraries like Flax, Haiku, Keras, or Equinox, which exposes the full structure of your model's forward pass using declarative combinators. As in Equinox, models are represented as JAX pytrees, which means you can see everything your model does by pretty-printing it and inject new runtime logic with jax.tree_util. Unlike Equinox, however, penzai.nn models may also contain mutable variables at the leaves of the tree, which lets them keep track of mutable state and share parameters.
  • A modular implementation of common Transformer architectures, to support research into interpretability, model surgery, and training dynamics:

    • penzai.models.transformer: A reference Transformer implementation that can load the pre-trained weights for the Gemma, Llama, Mistral, and GPT-NeoX / Pythia architectures. Built using modular components and named axes, to simplify complex model-manipulation workflows.

Documentation on Penzai can be found at https://penzai.readthedocs.io.

Important: Penzai 0.2 includes a number of breaking changes to the neural network API. These changes are intended to simplify common workflows by introducing first-class support for mutable state and parameter sharing and removing unnecessary boilerplate. You can read about the differences between the old "V1" API and the current "V2" API in the "Changes in the V2 API" overview.

If you are currently using the V1 API and have not yet converted to the V2 system, you can keep the old behavior by importing from the penzai.deprecated.v1 submodule, e.g.:

from penzai.deprecated.v1 import pz
from penzai.deprecated.v1.example_models import simple_mlp

Getting Started

If you haven't already installed JAX, you should do that first, since the installation process depends on your platform. You can find instructions in the JAX documentation. Afterward, you can install Penzai using

pip install penzai

and import it using

import penzai
from penzai import pz

(penzai.pz is an alias namespace, which makes it easier to reference common Penzai objects.)

When working in a Colab or IPython notebook, we recommend also configuring Treescope (Penzai's companion pretty-printer) as the default pretty printer and enabling some utilities for interactive use:

import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

Here's how you could initialize and visualize a simple neural network:

import jax
from penzai.models import simple_mlp

mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8],
)

# Models and arrays are visualized automatically when you output them from a
# Colab/IPython notebook cell:
mlp

Here's how you could capture and extract the activations after the elementwise nonlinearities:

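from typing import Any

@pz.pytree_dataclass
class AppendIntermediate(pz.nn.Layer):
  """Appends its input to a StateVariable, then passes it through unchanged."""
  saved: pz.StateVariable[list[Any]]

  def __call__(self, x: Any, **unused_side_inputs) -> Any:
    self.saved.value = self.saved.value + [x]
    return x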

var = pz.StateVariable(value=[], label="my_intermediates")

# Make a copy of the model that saves its activations:
saving_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(AppendIntermediate(var))
)

output = saving_model(pz.nx.ones({"features": 8}))
intermediates = var.value

To learn more about how to build and manipulate neural networks with Penzai, we recommend starting with the "How to Think in Penzai" tutorial or one of the other tutorials in the Penzai documentation.

Citation

If you have found Penzai to be useful for your research, please consider citing the following writeup (also available on arXiv):

@article{johnson2024penzai,
    author={Daniel D. Johnson},
    title={{Penzai} + {Treescope}: A Toolkit for Interpreting, Visualizing, and Editing Models As Data},
    year={2024},
    journal={ICML 2024 Workshop on Mechanistic Interpretability}
}

This is not an officially supported Google product.



Download files


Source Distribution

penzai-0.2.2.tar.gz (36.6 MB)

Uploaded Source

Built Distribution

penzai-0.2.2-py3-none-any.whl (314.5 kB)

Uploaded Python 3

File details

Details for the file penzai-0.2.2.tar.gz.

File metadata

  • Download URL: penzai-0.2.2.tar.gz
  • Upload date:
  • Size: 36.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.9

File hashes

Hashes for penzai-0.2.2.tar.gz
Algorithm Hash digest
SHA256 f08b1c7151ea07dfe80b99abc5c749942fb0da112ed2ba82d5588ec255f8e8be
MD5 19853b630c723588e956eee6ebf0ddbc
BLAKE2b-256 076148ce2a1d8a16a4b778837981bed648cb9ea4aefb7b91bc6490b405a17b87


File details

Details for the file penzai-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: penzai-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 314.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.9

File hashes

Hashes for penzai-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 387caf4a0af4067658528ae931648421b8fe2d62d0b32eb0c9254b2d95771e7e
MD5 c7d5c9a13c9abc0b415acbcd5ff57873
BLAKE2b-256 1189c96aa7941afceddef928e0b3ffe99c2eb25b0012b164dc5234fc049dd3aa
