
A torch implementation of GFlowNets


gfn: a Python package for GFlowNets

Installing the packages

The codebase requires Python >= 3.10.

git clone https://github.com/saleml/gfn.git
conda create -n gfn python=3.10
conda activate gfn
cd gfn
pip install .

Optionally, to run the scripts and enable wandb logging:

pip install .[scripts]
wandb login

About this repo

This repo is meant for fast prototyping of GFlowNet-related algorithms. It decouples the environment definition, the sampling process, and the parametrization used for the GFN loss.

An example script is provided here. To run the code, use one of the following:

python train.py --env HyperGrid --env.ndim 4 --env.height 8 --n_iterations 100000 --loss TB
python train.py --env DiscreteEBM --env.ndim 4 --env.alpha 0.5 --n_iterations 10000 --batch_size 64 --temperature 2.
python train.py --env HyperGrid --env.ndim 2 --env.height 64 --n_iterations 100000 --loss DB --replay_buffer_size 1000 --logit_PB.module_name Uniform --optim sgd --optim.lr 5e-3
python train.py --env HyperGrid --env.ndim 4 --env.height 8 --env.R0 0.01 --loss FM --optim adam --optim.lr 1e-4

Example, in a few lines

(⬇️ This example requires the tqdm package to run: pip install tqdm, or install all extra requirements with pip install .[scripts])

import torch
from tqdm import tqdm

from gfn import LogitPBEstimator, LogitPFEstimator, LogZEstimator
from gfn.envs import HyperGrid
from gfn.losses import TBParametrization, TrajectoryBalance
from gfn.samplers import DiscreteActionsSampler, TrajectoriesSampler

if __name__ == "__main__":

    env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

    logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")
    logit_PB = LogitPBEstimator(
        env=env,
        module_name="NeuralNet",
        torso=logit_PF.module.torso,  # To share parameters between PF and PB
    )
    logZ = LogZEstimator(torch.tensor(0.0))

    parametrization = TBParametrization(logit_PF, logit_PB, logZ)

    actions_sampler = DiscreteActionsSampler(estimator=logit_PF)
    trajectories_sampler = TrajectoriesSampler(env=env, actions_sampler=actions_sampler)

    loss_fn = TrajectoryBalance(parametrization=parametrization)

    params = [
        {
            "params": [
                val for key, val in parametrization.parameters.items() if "logZ" not in key
            ],
            "lr": 0.001,
        },
        {"params": [val for key, val in parametrization.parameters.items() if "logZ" in key], "lr": 0.1},
    ]
    optimizer = torch.optim.Adam(params)

    for i in (pbar := tqdm(range(1000))):
        trajectories = trajectories_sampler.sample(n_trajectories=16)
        optimizer.zero_grad()
        loss = loss_fn(trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            pbar.set_postfix({"loss": loss.item()})

Contributing

Before the first commit:

pip install .[dev]
pre-commit install
pre-commit run --all-files

Run pre-commit after staging and before committing. Make sure all tests pass (by running pytest). The codebase uses the black formatter.

To build the docs locally:

cd docs
make html
open build/html/index.html

Details about the codebase

Defining an environment

A pointed DAG environment (or GFN environment, or environment for short) is a representation of the pointed DAG. The abstract class Env specifies the requirements for a valid environment definition. To obtain such a representation, the environment needs to specify the following attributes, properties, or methods:

  • The action_space, which should be a gymnasium.spaces.Discrete object for discrete environments. The last action should correspond to the exit action.
  • The initial state s_0, as a torch.Tensor of arbitrary dimension.
  • (Optional) The sink state s_f, as a torch.Tensor of the same shape as s_0, used to represent complete trajectories only (within a batch of trajectories of different lengths), and never processed by any model. If not specified, it is set to torch.full_like(s_0, -float('inf')).
  • The method make_States_class, which creates a subclass of States. Instances of the resulting class represent a batch of states of arbitrary shape, which is useful to define a trajectory or a batch of trajectories. s_0 and s_f, along with a tuple called state_shape, should be defined as class variables, and the subclass (of States) should implement masking methods that specify which actions are possible in a discrete environment.
  • The methods maskless_step and maskless_backward_step, which specify how an action changes a state (going forward and backward). These functions do not need to handle masking, check whether actions are allowed, check whether a state is the sink state, etc. These checks are handled in Env.step and Env.backward_step.
  • The log_reward function, which gives the log of the nonnegative reward of every terminating state (i.e. every state that has $s_f$ as a child in the DAG). If log_reward is not implemented, reward must be implemented instead.
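To make the list above concrete, here is a minimal plain-Python sketch of those requirements for a small grid. It does not use the gfn Env or States classes. ToyGridEnv, forward_mask, backward_mask and the reward $R(s) = 1 + \sum_i s_i$ are hypothetical stand-ins chosen for illustration, and states are single tuples rather than batched tensors. It shows the split the text describes: maskless_step and maskless_backward_step only move between states, while step and backward_step do the checks.

```python
import math
from typing import List, Tuple

State = Tuple[float, ...]


class ToyGridEnv:
    """Hypothetical plain-Python analogue of the requirements above (not the gfn Env API).

    States are tuples of ints in [0, height). Action i < ndim increments coordinate i;
    the last action (index ndim) is the exit action.
    """

    def __init__(self, ndim: int = 2, height: int = 4) -> None:
        self.ndim = ndim
        self.height = height
        self.n_actions = ndim + 1  # the last action is the exit action
        self.s0: State = (0,) * ndim
        # Mirrors the default s_f = torch.full_like(s_0, -inf) described above.
        self.sf: State = (-math.inf,) * ndim

    def forward_mask(self, state: State) -> List[bool]:
        # Incrementing must stay inside the grid; exit is always allowed because
        # every state is terminating in this toy environment.
        return [state[i] + 1 < self.height for i in range(self.ndim)] + [True]

    def backward_mask(self, state: State) -> List[bool]:
        # Backward actions exclude exit, hence n_actions - 1 entries.
        return [state[i] > 0 for i in range(self.ndim)]

    def maskless_step(self, state: State, action: int) -> State:
        # No validity checks here: those live in `step`.
        return tuple(x + 1 if i == action else x for i, x in enumerate(state))

    def maskless_backward_step(self, state: State, action: int) -> State:
        return tuple(x - 1 if i == action else x for i, x in enumerate(state))

    def step(self, state: State, action: int) -> State:
        if state == self.sf or not self.forward_mask(state)[action]:
            raise ValueError(f"action {action} is not allowed in state {state}")
        if action == self.n_actions - 1:
            return self.sf  # the exit action leads to the sink state
        return self.maskless_step(state, action)

    def backward_step(self, state: State, action: int) -> State:
        if (
            state == self.sf
            or not 0 <= action < self.ndim
            or not self.backward_mask(state)[action]
        ):
            raise ValueError(f"backward action {action} is not allowed in state {state}")
        return self.maskless_backward_step(state, action)

    def log_reward(self, state: State) -> float:
        # Hypothetical reward R(s) = 1 + sum(s) > 0, returned in log space.
        return math.log(1.0 + sum(state))


if __name__ == "__main__":
    env = ToyGridEnv(ndim=2, height=4)
    s = env.step(env.s0, 0)
    print(s, env.step(s, env.n_actions - 1), env.log_reward(s))
```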

If the states (as represented in the States class) need to be transformed to another format before being processed (by neural networks for example), then the environment should define a preprocessor attribute, which should be an instance of the base preprocessor class. If no preprocessor is defined, the states are used as is (actually transformed using IdentityPreprocessor, which transforms the state tensors to FloatTensors). Implementing your own preprocessor requires defining the preprocess function, and the output_shape attribute, which is a tuple representing the shape of one preprocessed state.
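As an illustration of the preprocess / output_shape contract, the sketch below one-hot encodes each coordinate of a grid state and concatenates the results. OneHotGridPreprocessor is a hypothetical plain-Python class, not the gfn base class. It returns nested lists instead of FloatTensors so that it runs without torch.

```python
from typing import List, Sequence, Tuple


class OneHotGridPreprocessor:
    """Hypothetical preprocessor: one one-hot vector per coordinate, concatenated."""

    def __init__(self, ndim: int, height: int) -> None:
        self.ndim = ndim
        self.height = height
        # Shape of ONE preprocessed state, as required by the text.
        self.output_shape: Tuple[int, ...] = (ndim * height,)

    def preprocess(self, states: Sequence[Sequence[int]]) -> List[List[float]]:
        out = []
        for state in states:
            vec = [0.0] * self.output_shape[0]
            for i, x in enumerate(state):
                # Coordinate i occupies the slice [i * height, (i + 1) * height).
                vec[i * self.height + x] = 1.0
            out.append(vec)
        return out


if __name__ == "__main__":
    pre = OneHotGridPreprocessor(ndim=2, height=3)
    print(pre.output_shape, pre.preprocess([(0, 2), (1, 1)]))
```

The module that consumes these vectors would then need an input size equal to output_shape[0].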

Optionally, you can define a static get_states_indices method that assigns a unique integer number to each state if the environment allows it, and a n_states property that returns an integer representing the number of states (excluding $s_f$) in the environment. get_terminating_states_indices can also be implemented and serves the purpose of uniquely identifying terminating states of the environment.
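Consider a grid whose states are integer tuples with coordinates in range(height). One way to get unique indices there is a base-height (mixed-radix) encoding. The functions below are a hypothetical sketch of what get_states_indices and n_states could compute in that case. They are not the HyperGrid implementation.

```python
import itertools
from typing import List, Sequence


def get_states_indices(states: Sequence[Sequence[int]], height: int) -> List[int]:
    """Hypothetical: unique integer per grid state via a base-`height` encoding."""
    return [sum(x * height**i for i, x in enumerate(state)) for state in states]


def n_states(ndim: int, height: int) -> int:
    # Number of states excluding s_f: every grid cell is a state.
    return height**ndim


if __name__ == "__main__":
    all_states = list(itertools.product(range(3), repeat=2))
    indices = get_states_indices(all_states, height=3)
    # Every state gets a distinct index in range(n_states).
    print(sorted(indices) == list(range(n_states(2, 3))))  # prints True
```

If every state is terminating, as in HyperGrid, the same encoding could also serve for terminating states.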

For more details, take a look at HyperGrid, an environment where all states are terminating states, or at DiscreteEBM, where all trajectories are of the same length but only some states are terminating.

Other containers

Besides the States class, other containers of states are available:

  • Transitions, representing a batch of transitions $s \rightarrow s'$.
  • Trajectories, representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$.

These containers can either be instantiated using a States object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of ReplayBuffers.

They inherit from the base Container class, which provides some helpful methods.

In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and a batch of transitions can be obtained using Trajectories.to_states() and Trajectories.to_transitions(), respectively. These exclude the meaningless (padding) states and transitions that were added to the batch of trajectories to allow for efficient batching.
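The padding idea can be sketched without the library. Trajectories of different lengths are padded with copies of the sink state, and the states and transitions views drop everything that exists only because of that padding. pad_trajectories, to_transitions and to_states below are hypothetical plain-Python helpers. They mimic, under these assumptions, what Trajectories.to_states() and Trajectories.to_transitions() are described as doing. The transition into $s_f$ is kept because it is a real (exit) transition.

```python
import math
from typing import List, Tuple

State = Tuple[float, ...]
# Hypothetical sink state for 2-D states, mirroring the default full_like(s_0, -inf).
SINK: State = (-math.inf, -math.inf)


def pad_trajectories(trajs: List[List[State]]) -> List[List[State]]:
    """Pad complete trajectories (each already ending in SINK) to a common length."""
    max_len = max(len(t) for t in trajs)
    return [t + [SINK] * (max_len - len(t)) for t in trajs]


def to_transitions(padded: List[List[State]]) -> List[Tuple[State, State]]:
    """Keep real transitions, including the exit s -> s_f, but drop s_f -> s_f padding."""
    return [
        (s, s_next)
        for traj in padded
        for s, s_next in zip(traj[:-1], traj[1:])
        if s != SINK
    ]


def to_states(padded: List[List[State]]) -> List[State]:
    """Keep every non-sink state; s_f is never processed by a model."""
    return [s for traj in padded for s in traj if s != SINK]


if __name__ == "__main__":
    t1 = [(0, 0), (1, 0), SINK]
    t2 = [(0, 0), (0, 1), (1, 1), SINK]
    padded = pad_trajectories([t1, t2])
    print(len(to_transitions(padded)), len(to_states(padded)))  # prints 5 5
```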

Estimators and Modules

Training GFlowNets requires one or multiple estimators. As of now, only discrete environments are handled. All estimators are subclasses of FunctionEstimator, implementing a __call__ function that takes as input a batch of States.

  • LogEdgeFlowEstimator. It outputs a (*batch_shape, n_actions) tensor representing $\log F(s \rightarrow s')$, including when $s' = s_f$.
  • LogStateFlowEstimator. It outputs a (*batch_shape, 1) tensor representing $\log F(s)$. When used with forward_looking=True, $\log F(s)$ is parametrized as the sum of a function approximator and $\log R(s)$ - which is only possible for environments where all states are terminating.
  • LogitPFEstimator. It outputs a (*batch_shape, n_actions) tensor representing $\mathrm{logit}(s' \mid s)$, such that $P_F(s' \mid s) = \mathrm{softmax}_{s'}\, \mathrm{logit}(s' \mid s)$, including when $s' = s_f$.
  • LogitPBEstimator. It outputs a (*batch_shape, n_actions - 1) tensor representing $\mathrm{logit}(s' \mid s)$, such that $P_B(s' \mid s) = \mathrm{softmax}_{s'}\, \mathrm{logit}(s' \mid s)$.
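The relation between logits and probabilities in the list above amounts to a softmax. The sketch assumes, as is usual for discrete environments, that actions disallowed by the States masks are excluded before the softmax. Equivalently, their logits are set to $-\infty$. masked_softmax is a hypothetical helper, not part of gfn.

```python
import math
from typing import List, Sequence


def masked_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Softmax restricted to allowed actions; masked actions get probability 0."""
    allowed = [x for x, m in zip(logits, mask) if m]
    if not allowed:
        raise ValueError("at least one action must be allowed")
    shift = max(allowed)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) if m else 0.0 for x, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # ndim = 2 grid: P_F has n_actions = 3 entries (last = exit), P_B has n_actions - 1 = 2.
    pf = masked_softmax([0.3, -1.2, 0.0], mask=[False, True, True])
    pb = masked_softmax([0.0, 0.0], mask=[True, True])
    print(pf, pb)
```

A constant logit vector, such as all zeros, gives a uniform distribution over the allowed actions.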

Defining an estimator requires the environment and a module instance. Modules inherit from the GFNModule class, which can be seen as an extension of torch.nn.Module. Alternatively, the module can be created by specifying which module type to use (e.g. "NeuralNet", "Uniform", or "Zero"). A basic MLP is provided as the NeuralNet class, but any function approximator should work.

Concretely, a States object is first transformed via the environment's preprocessor into a (*batch_shape, *output_shape) float tensor. The preprocessor's output shape should match the module's input shape (if any). The preprocessed states are then passed as inputs to the module, which returns the desired output (either flows or probabilities over children in the DAG).

Each module has a named_parameters function that returns a dictionary of its learnable parameters. This function is passed on to the corresponding estimator.

Additionally, a LogZEstimator is provided, which is a container for a scalar tensor representing $\log Z$, the log-partition function, useful for the Trajectory Balance loss for example. This estimator also has a named_parameters function.

Samplers

An ActionsSampler object defines how actions are sampled at each state of the DAG. As of now, only DiscreteActionsSamplers are implemented. They require an estimator (of $P_F$, $P_B$, or edge flows) that defines the action probabilities. These estimators can contain any type of module (including random action sampling, for example). A BackwardDiscreteActionsSampler class is provided to sample parents of a state, which is helpful to sample trajectories starting from their last states.

They are at the core of TrajectoriesSamplers, which implement the sample_trajectories method that samples a batch of trajectories starting either from a given set of initial states or from $s_0$.
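A forward sampler repeatedly asks a policy for action weights at the current state. It stops when the exit action is drawn. Below is a minimal plain-Python sketch for a 2-D grid with a uniform policy over allowed actions. sample_trajectory, sample_trajectories, uniform_policy and the grid helpers are hypothetical. The real TrajectoriesSampler works on batched States and estimator outputs instead.

```python
import random
from typing import Callable, List, Sequence, Tuple

State = Tuple[int, ...]
NDIM, HEIGHT = 2, 4
N_ACTIONS = NDIM + 1  # last action = exit


def grid_mask(state: State) -> List[bool]:
    # Increment allowed inside the grid; exit always allowed (all states terminating).
    return [state[i] + 1 < HEIGHT for i in range(NDIM)] + [True]


def grid_step(state: State, action: int) -> State:
    return tuple(x + 1 if i == action else x for i, x in enumerate(state))


def uniform_policy(state: State, mask: Sequence[bool]) -> List[float]:
    # Stand-in for an estimator: equal weight on every allowed action.
    return [1.0 if m else 0.0 for m in mask]


def sample_trajectory(
    s0: State,
    policy: Callable[[State, Sequence[bool]], List[float]],
    rng: random.Random,
) -> Tuple[List[State], List[int]]:
    """Roll out one trajectory from s0 until the exit action is sampled."""
    states, actions = [s0], []
    state = s0
    while True:
        weights = policy(state, grid_mask(state))
        action = rng.choices(range(N_ACTIONS), weights=weights, k=1)[0]
        actions.append(action)
        if action == N_ACTIONS - 1:  # exit: the next state would be s_f
            return states, actions
        state = grid_step(state, action)
        states.append(state)


def sample_trajectories(n: int, policy, seed: int = 0):
    rng = random.Random(seed)
    return [sample_trajectory((0,) * NDIM, policy, rng) for _ in range(n)]


if __name__ == "__main__":
    for states, actions in sample_trajectories(3, uniform_policy):
        print(states, actions)
```

Termination is guaranteed here because every non-exit action increases the coordinate sum, and at the far corner only exit remains allowed.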

Losses

GFlowNets can be trained with different losses, each of which requires a different parametrization. A parametrization is a dataclass, which can be seen as a container of different estimators. Each parametrization defines a distribution over trajectories, via the parametrization.Pi method, and a distribution over terminating states, via the parametrization.P_T method. Both distributions should be instances of the classes defined here.

The base classes for losses and parametrizations are provided here.

Currently, the implemented losses are:

  • Flow Matching
  • Detailed Balance
  • Trajectory Balance
  • Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined here. Other strategies exist and are implemented here.
  • Log Partition Variance loss, introduced here.
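For reference, here is a plain-Python sketch of a few of these objectives for a single complete trajectory $\tau$ ending in $x$. Trajectory Balance penalizes $\left(\log Z + \sum_t \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_t \log P_B(s_t \mid s_{t+1})\right)^2$. As we understand it, the Log Partition Variance loss replaces the learned $\log Z$ by the batch variance of the per-trajectory estimates $\zeta(\tau) = \log R(x) + \sum_t \log P_B - \sum_t \log P_F$. The last function illustrates Sub-Trajectory Balance weights proportional to $\lambda^{\text{length}}$; the exact normalization used in gfn may differ. All function names are hypothetical.

```python
import math
from typing import Dict, Sequence, Tuple


def tb_loss(log_z: float, log_pf: Sequence[float], log_pb: Sequence[float], log_r: float) -> float:
    """Squared Trajectory Balance residual for one complete trajectory."""
    return (log_z + sum(log_pf) - log_r - sum(log_pb)) ** 2


def log_partition_variance_loss(
    batch: Sequence[Tuple[Sequence[float], Sequence[float], float]],
) -> float:
    """Variance over the batch of per-trajectory log Z estimates (no learned log Z)."""
    zetas = [log_r + sum(log_pb) - sum(log_pf) for log_pf, log_pb, log_r in batch]
    mean = sum(zetas) / len(zetas)
    return sum((z - mean) ** 2 for z in zetas) / len(zetas)


def geometric_subtrajectory_weights(n_states: int, lam: float) -> Dict[Tuple[int, int], float]:
    """Weight each sub-trajectory s_i -> ... -> s_j (i < j) proportionally to lam ** (j - i)."""
    raw = {(i, j): lam ** (j - i) for i in range(n_states) for j in range(i + 1, n_states)}
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}


if __name__ == "__main__":
    # A balanced toy trajectory: log Z = log 4 makes the TB residual vanish.
    print(tb_loss(math.log(4.0), [math.log(0.5)] * 2, [0.0, math.log(0.5)], math.log(2.0)))
    print(geometric_subtrajectory_weights(3, 0.5))
```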

Solving for the flows using Dynamic Programming

A simple script that propagates trajectory rewards through the DAG to define edge flows deterministically (visiting each edge only once) is provided here. Do not use this script on large environments!
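The idea behind such a script can be sketched on a tiny grid. Flows are not unique, so this sketch makes two assumptions. First, every state is terminating, so the edge $s \rightarrow s_f$ carries $R(s)$. Second, the backward policy is uniform, $P_B(s \mid c) = 1 / |\mathrm{Pa}(c)|$. It then visits states in reverse topological order, setting $F(s \rightarrow c) = F(c) P_B(s \mid c)$ and $F(s) = R(s) + \sum_c F(s \rightarrow c)$. solve_flows is a hypothetical helper and may differ from the provided script. By flow conservation, $F(s_0)$ equals the partition function $Z = \sum_x R(x)$.

```python
import itertools
from typing import Callable, Dict, Hashable, Tuple

State = Tuple[int, ...]
SINK = "s_f"  # hypothetical label for the sink state


def solve_flows(ndim: int, height: int, reward: Callable[[State], float]):
    """Propagate rewards backward through a grid DAG, visiting each edge once.

    Assumes every state is terminating (edge s -> s_f carries R(s)) and a uniform
    backward policy P_B(s | c) = 1 / (number of parents of c).
    """
    states = list(itertools.product(range(height), repeat=ndim))
    # Children have a coordinate sum one larger, so descending sum is a reverse topological order.
    states.sort(key=sum, reverse=True)
    state_flow: Dict[State, float] = {}
    edge_flow: Dict[Tuple[State, Hashable], float] = {}
    for s in states:
        edge_flow[(s, SINK)] = reward(s)
        total = reward(s)
        for i in range(ndim):
            if s[i] + 1 < height:
                child = s[:i] + (s[i] + 1,) + s[i + 1:]
                n_parents = sum(1 for x in child if x > 0)
                # F(s -> child) = F(child) * P_B(s | child); child was already processed.
                edge_flow[(s, child)] = state_flow[child] / n_parents
                total += edge_flow[(s, child)]
        state_flow[s] = total
    return state_flow, edge_flow


if __name__ == "__main__":
    flows, _ = solve_flows(ndim=2, height=3, reward=lambda s: 1.0 + s[0])
    print(flows[(0, 0)])
```

For example, with R(s) = 1 + s[0] on a 3x3 grid the nine rewards sum to 18, so flows[(0, 0)] is 18.0 up to floating-point rounding.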


Download files


Source Distribution

torchgfn-0.1.1.tar.gz (38.0 kB)

Uploaded Source

Built Distribution

torchgfn-0.1.1-py3-none-any.whl (45.6 kB)

Uploaded Python 3

File details

Details for the file torchgfn-0.1.1.tar.gz.

File metadata

  • Download URL: torchgfn-0.1.1.tar.gz
  • Upload date:
  • Size: 38.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.0 CPython/3.9.12 Darwin/22.4.0

File hashes

Hashes for torchgfn-0.1.1.tar.gz
Algorithm Hash digest
SHA256 e87205963bb55d3bf0a5463ad5bc0500ba0d5195accab22c4e536288c8d73311
MD5 36303b62db6975f4a7f96ed857e26353
BLAKE2b-256 7b9568c7b8b84e03721ea1ca38d1357f4baeb87732966b37b6fd27664fd7ce2d


File details

Details for the file torchgfn-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: torchgfn-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 45.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.0 CPython/3.9.12 Darwin/22.4.0

File hashes

Hashes for torchgfn-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 d0e0f70c96fda8ed9f26ded155d7dfad8d48b113a3d1419566bbfb211c1f20b1
MD5 94483d06d90a67b267dd4b8e98f736c4
BLAKE2b-256 0f1c9b2d0ca70cb2d28dec67895af084a843c1818102888fb065e4e0dee18ed1

