
A torch implementation of GFlowNets

Project description


Documentation ~ Code ~ Paper

torchgfn: a Python package for GFlowNets

Please cite this paper if you are using the library for your research.

Installing the package

The codebase requires Python >= 3.10.

To install the latest stable version:

pip install torchgfn

Optionally, to run scripts:

pip install torchgfn[scripts]

To install the cutting edge version (from the main branch):

git clone https://github.com/saleml/torchgfn.git
conda create -n gfn python=3.10
conda activate gfn
cd torchgfn
pip install .

About this repo

This repo is intended for fast prototyping of GFlowNet-related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to compute the GFN loss.

Example scripts and notebooks are provided here.

Standalone example

This example, which shows how to use the library for a simple discrete environment, requires the tqdm package to run. Install it with pip install tqdm, or install all extra requirements with pip install .[scripts] or pip install torchgfn[scripts].

import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet

if __name__ == "__main__":

    env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

    module_PF = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions
    )
    module_PB = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        torso=module_PF.torso
    )

    pf_estimator = DiscretePolicyEstimator(env, module_PF, forward=True)
    pb_estimator = DiscretePolicyEstimator(env, module_PB, forward=False)

    gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator)

    sampler = Sampler(estimator=pf_estimator)

    # Policy parameters have their own LR.
    non_logz_params = [v for k, v in dict(gfn.named_parameters()).items() if k != "logZ"]
    optimizer = torch.optim.Adam(non_logz_params, lr=1e-3)

    # Log Z gets dedicated learning rate (typically higher).
    logz_params = [dict(gfn.named_parameters())["logZ"]]
    optimizer.add_param_group({"params": logz_params, "lr": 1e-2})

    for i in (pbar := tqdm(range(1000))):
        trajectories = sampler.sample_trajectories(n_trajectories=16)
        optimizer.zero_grad()
        loss = gfn.loss(trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            pbar.set_postfix({"loss": loss.item()})

Contributing

Before the first commit:

pip install -e .[dev,scripts]
pre-commit install
pre-commit run --all-files

Run pre-commit after staging and before committing. Make sure all the tests pass (by running pytest). The codebase uses the black formatter.

To make the docs locally:

cd docs
make html
open build/html/index.html

Details about the codebase

Defining an environment

See here

States

States are the primitive building blocks for GFlowNet objects such as transitions and trajectories, on which losses operate.

An abstract States class is provided, but each environment needs its own States subclass. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching: if a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object whose tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a state as a string, a numpy array, a graph, etc.), but these representations cannot be batched unless the user specifies a function that transforms these raw states to tensors.

The batch_shape attribute is required to keep track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, n_trajectories).

Because multiple trajectories can have different lengths, batching requires appending a dummy state to trajectories that are shorter than the longest one. The dummy state is the $s_f$ attribute of the environment (e.g. [-1, ..., -1], or [-inf, ..., -inf], etc.). It is never processed and is used only to pad the batch of states.
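
The padding scheme described above can be sketched in a few lines of plain torch (this is an illustration of the idea, not the library's internal code; the tensors and shapes are made up for the example):

```python
import torch

# Two trajectories of different lengths over 2-D states, padded with a
# dummy sink state s_f = [-1, -1] so they stack into one tensor of shape
# (max_len, n_trajectories, *state_shape).
state_shape = (2,)
sf = torch.full(state_shape, -1)  # dummy sink state

traj_a = torch.tensor([[0, 0], [0, 1], [1, 1]])  # length 3
traj_b = torch.tensor([[0, 0], [1, 0]])          # length 2

max_len = max(traj_a.shape[0], traj_b.shape[0])

def pad(traj: torch.Tensor) -> torch.Tensor:
    n_missing = max_len - traj.shape[0]
    if n_missing == 0:
        return traj
    padding = sf.expand(n_missing, *state_shape)
    return torch.cat([traj, padding], dim=0)

batch = torch.stack([pad(traj_a), pad(traj_b)], dim=1)
print(batch.shape)  # torch.Size([3, 2, 2])
```

The dummy rows exist purely for batching; downstream code masks them out before any computation.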

For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the last action (index $n_{actions} - 1$) always corresponds to the exit or terminate action, i.e. the one resulting in a transition of the type $s \rightarrow s_f$. Not all actions are possible at all states, so each States object is endowed with two extra attributes: forward_masks and backward_masks, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the DiscreteStates abstract subclass of States. The forward_masks tensor is of shape (*batch_shape, n_{actions}), and backward_masks is of shape (*batch_shape, n_{actions} - 1). Each subclass of DiscreteStates needs to implement the update_masks function, which uses the environment's logic to fill the two tensors.
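
As a concrete (and purely illustrative) sketch of what such masks look like, consider a toy 1-D chain of length H with two actions {0: move right, 1: exit}; the environment and names here are invented for the example and are not the library's API:

```python
import torch

H = 4
states = torch.tensor([[0], [2], [3]])  # batch_shape = (3,), state_shape = (1,)
n_actions = 2

# forward_masks: which actions are allowed at each state.
forward_masks = torch.ones(states.shape[0], n_actions, dtype=torch.bool)
forward_masks[:, 0] = states[:, 0] < H - 1  # cannot move right past the top
# The exit action (last column) stays allowed everywhere in this toy example.

# backward_masks: which actions could have led to each state.
backward_masks = torch.ones(states.shape[0], n_actions - 1, dtype=torch.bool)
backward_masks[:, 0] = states[:, 0] > 0     # s = 0 has no parent via "move right"
```

An update_masks implementation would fill these two tensors from the environment's transition rules, exactly as done by hand above.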

Actions

Actions should be thought of as internal actions of an agent building a compositional object. They correspond to transitions $s \rightarrow s'$. An abstract Actions class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise.

Similar to States objects, a batch of actions is represented by a tensor of shape (*batch_shape, *action_shape). For discrete environments, for instance, action_shape = (1,), representing an integer between $0$ and $n_{actions} - 1$.

Additionally, each subclass needs to define two class-variable tensors:

  • dummy_action: A tensor that is padded to sequences of actions in the shorter trajectories of a batch of trajectories. It is [-1] for discrete environments.
  • exit_action: A tensor that corresponds to the termination action. It is [n_{actions} - 1] for discrete environments.

Containers

Containers are collections of States, along with other information, such as reward values, or densities $p(s' \mid s)$. Two containers are available:

  • Transitions, representing a batch of transitions $s \rightarrow s'$.
  • Trajectories, representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$.

These containers can either be instantiated using a States object, or initialized as empty containers that are populated on the fly, allowing the usage of the ReplayBuffer class.

They inherit from the base Container class, which provides some helpful methods.

In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and a batch of transitions can be obtained using Trajectories.to_states() and Trajectories.to_transitions(). These exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching.
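
The extraction step can be pictured as follows; this is a hand-rolled sketch of the idea (dropping dummy states, then pairing consecutive states), not the library's implementation:

```python
import torch

# One trajectory of 2-D states, padded to length 4 with the dummy
# sink state s_f = [-1, -1].
sf = torch.tensor([-1, -1])
traj = torch.tensor([[0, 0], [0, 1], [1, 1], [-1, -1]])

# Drop the dummy rows, then pair consecutive states into transitions.
is_dummy = (traj == sf).all(dim=-1)
valid = traj[~is_dummy]                          # the real states
transitions = list(zip(valid[:-1], valid[1:]))   # consecutive (s, s') pairs
print(len(transitions))  # 2
```

Only the two genuine transitions (s0, s1) and (s1, s2) survive; the padding never reaches the loss.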

Modules

Training GFlowNets requires one or multiple estimators, called GFNModules; GFNModule is an abstract subclass of torch.nn.Module. In addition to the usual forward function, GFNModules need to implement a required_output_dim attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a to_probability_distribution function. They take the environment env as an input at initialization.

  • DiscretePolicyEstimator is a GFNModule that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When backward=False, the required output dimension is n = env.n_actions, and when backward=True, it is n = env.n_actions - 1. These n numbers represent the logits of a Categorical distribution. Additionally, it includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. Naturally, before the Categorical distributions are defined, forbidden actions (encoded in the DiscreteStates' masks attributes) are given probability 0 by setting the corresponding logits to $-\infty$.
  • ScalarModule is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$.
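
The masking step described in the first bullet can be sketched directly with torch primitives (an illustration of the mechanism, not the estimator's actual code):

```python
import torch
from torch.distributions import Categorical

# Raw logits for a batch of 2 states over 3 actions, plus the
# forward masks saying which actions are allowed at each state.
logits = torch.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 1.2]])
forward_masks = torch.tensor([[True, False, True], [True, True, True]])

# Forbidden actions get a logit of -inf, so their probability is exactly 0.
masked_logits = logits.masked_fill(~forward_masks, float("-inf"))
dist = Categorical(logits=masked_logits)
probs = dist.probs
# probs[0, 1] == 0: the forbidden action can never be sampled.
```

Since the softmax of a $-\infty$ logit is 0, sampling from dist can never return a masked action.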

For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a States object), should return the batched parameters of a torch.Distribution. The distribution depends on the environment. The to_probability_distribution function handles the conversion of the parameter outputs to an actual batched Distribution object, which implements at least the sample and log_prob functions. An example is provided here, for a square environment in which the forward policy has support either on a quarter disk or on an arc-circle, such that the angle, and the radius (for the quarter-disk part), are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and user-defined environments are not expected to need this level of detail.
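
A much simpler version of the same pattern, assuming a made-up module name and a single Beta-distributed action in (0, 1) (everything here is illustrative, not part of the library):

```python
import torch
from torch import nn
from torch.distributions import Beta

class BetaPolicyModule(nn.Module):
    """Hypothetical module: outputs the two Beta concentration parameters."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, 16), nn.ReLU(), nn.Linear(16, 2))

    def forward(self, states_tensor: torch.Tensor) -> torch.Tensor:
        # Softplus + 1 keeps both concentrations positive (and >= 1,
        # so samples stay well inside (0, 1)).
        return nn.functional.softplus(self.net(states_tensor)) + 1.0

def to_probability_distribution(params: torch.Tensor) -> Beta:
    # Convert the batched parameter tensor into a batched Distribution.
    return Beta(params[..., 0], params[..., 1])

torch.manual_seed(0)
module = BetaPolicyModule(input_dim=3)
params = module(torch.randn(5, 3))        # batch of 5 states
dist = to_probability_distribution(params)
actions = dist.sample()                   # shape (5,), values in (0, 1)
log_probs = dist.log_prob(actions)
```

The key contract is that the returned Distribution supports sample and log_prob on the whole batch at once.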

In all GFNModules, note that the input of the forward function is a States object, meaning that the states first need to be transformed to tensors. However, states.tensor does not necessarily have the structure that a neural network can use to generalize. It is common in these scenarios to transform these raw tensor states into ones where the structure is clearer, via a Preprocessor object that is part of the environment. More on this here. The default preprocessor of an environment is the identity preprocessor. The forward pass thus first calls the environment's preprocessor attribute on the States object, before performing any transformation.
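
For intuition, a typical preprocessing step for grid-like states is a one-hot encoding of each coordinate; the function below is a standalone sketch of that idea (the actual Preprocessor is a class attached to the environment, not this free function):

```python
import torch

def one_hot_preprocessor(states_tensor: torch.Tensor, height: int) -> torch.Tensor:
    # (*batch_shape, ndim) integer coordinates
    #   -> (*batch_shape, ndim * height) float one-hot features.
    one_hot = torch.nn.functional.one_hot(states_tensor, num_classes=height)
    return one_hot.flatten(start_dim=-2).float()

states = torch.tensor([[0, 2], [1, 1]])   # two 2-D grid states, height 3
out = one_hot_preprocessor(states, height=3)
print(out.shape)  # torch.Size([2, 6])
```

The output dimension of the preprocessor (here ndim * height) is what the policy network's input_dim must match, which is why the standalone example uses env.preprocessor.output_dim.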

For discrete environments, tabular modules are provided, where a lookup table is used instead of a neural network. Additionally, a UniformPB module is provided, implementing a uniform backward policy.

Samplers

A Sampler object defines how actions are sampled at each state (sample_actions()) and how trajectories are sampled (sample_trajectories()), either starting from a given set of initial states or from $s_0$. It requires a GFNModule that implements the to_probability_distribution function.
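
Conceptually, trajectory sampling is a loop that repeatedly builds the policy distribution at the current state, samples an action, and stops at the exit action. The toy loop below illustrates that on a 1-D chain (the environment and masking are invented for the example; the real Sampler delegates to the GFNModule's to_probability_distribution):

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
H = 5                      # chain of states 0 .. H-1
state, trajectory = 0, [0]
done = False
while not done:
    # Two actions {0: move right, 1: exit}; forbid moving right at the
    # top of the chain by giving that action a -inf logit.
    logits = torch.tensor([0.0 if state < H - 1 else float("-inf"), 0.0])
    action = Categorical(logits=logits).sample().item()
    if action == 1:        # exit action -> transition to s_f
        done = True
    else:
        state += 1
        trajectory.append(state)
```

At state H-1 only the exit action remains, so the loop always terminates with a complete trajectory ending in $s_f$.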

Losses

GFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a GFlowNet. A GFlowNet is a GFNModule that includes one or multiple GFNModules, at least one of which implements a to_probability_distribution function. It also needs to implement a loss function that takes as input states, transitions, or trajectories, depending on the loss.

Currently, the implemented losses are:

  • Flow Matching
  • Detailed Balance (and its modified variant).
  • Trajectory Balance
  • Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined here. Other strategies exist and are implemented here.
  • Log Partition Variance loss, introduced here.
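
To make the geometric weighting mentioned for Sub-Trajectory Balance concrete, here is a minimal sketch: each sub-trajectory length k gets an unnormalized weight $\lambda^k$, normalized within the trajectory (lamda is an assumed hyperparameter name, and the weighting is simplified to one weight per length for illustration):

```python
import torch

def geometric_weights(max_len: int, lamda: float = 0.9) -> torch.Tensor:
    # Weight of a sub-trajectory of length k is proportional to lamda**k,
    # normalized so the weights within one trajectory sum to 1.
    lengths = torch.arange(1, max_len + 1, dtype=torch.float)
    unnormalized = lamda ** lengths
    return unnormalized / unnormalized.sum()

w = geometric_weights(max_len=4)
# With lamda < 1, longer sub-trajectories receive geometrically smaller weight.
```

Setting lamda closer to 1 spreads the weight more evenly across sub-trajectory lengths, interpolating between Detailed Balance-like and Trajectory Balance-like behavior.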

Scripts

Example scripts are provided here. They can be used to reproduce published results in the HyperGrid environment, and the Box environment.
