A torch implementation of GFlowNets
Project description
gfn: a Python package for GFlowNets
Installing the package
The codebase requires python >= 3.10
git clone https://github.com/saleml/gfn.git
conda create -n gfn python=3.10
conda activate gfn
cd gfn
pip install .
Optionally, to run the example scripts and to use wandb logging:
pip install .[scripts]
wandb login
About this repo
This repo is intended for fast prototyping of GFlowNet-related algorithms. It decouples the environment definition, the sampling process, and the parametrization used for the GFN loss.
An example script is provided here. To run the code, use one of the following:
python train.py --env HyperGrid --env.ndim 4 --env.height 8 --n_iterations 100000 --loss TB
python train.py --env DiscreteEBM --env.ndim 4 --env.alpha 0.5 --n_iterations 10000 --batch_size 64 --temperature 2.
python train.py --env HyperGrid --env.ndim 2 --env.height 64 --n_iterations 100000 --loss DB --replay_buffer_size 1000 --logit_PB.module_name Uniform --optim sgd --optim.lr 5e-3
python train.py --env HyperGrid --env.ndim 4 --env.height 8 --env.R0 0.01 --loss FM --optim adam --optim.lr 1e-4
Example, in a few lines
(⬇️ This example requires the tqdm package to run: pip install tqdm, or install all extra requirements with pip install .[scripts].)
import torch
from tqdm import tqdm

from gfn import LogitPBEstimator, LogitPFEstimator, LogZEstimator
from gfn.envs import HyperGrid
from gfn.losses import TBParametrization, TrajectoryBalance
from gfn.samplers import DiscreteActionsSampler, TrajectoriesSampler

if __name__ == "__main__":
    env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

    logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")
    logit_PB = LogitPBEstimator(
        env=env,
        module_name="NeuralNet",
        torso=logit_PF.module.torso,  # To share parameters between PF and PB
    )
    logZ = LogZEstimator(torch.tensor(0.0))

    parametrization = TBParametrization(logit_PF, logit_PB, logZ)

    actions_sampler = DiscreteActionsSampler(estimator=logit_PF)
    trajectories_sampler = TrajectoriesSampler(env=env, actions_sampler=actions_sampler)

    loss_fn = TrajectoryBalance(parametrization=parametrization)

    params = [
        {
            "params": [
                val for key, val in parametrization.parameters.items() if "logZ" not in key
            ],
            "lr": 0.001,
        },
        {
            "params": [
                val for key, val in parametrization.parameters.items() if "logZ" in key
            ],
            "lr": 0.1,
        },
    ]
    optimizer = torch.optim.Adam(params)

    for i in (pbar := tqdm(range(1000))):
        trajectories = trajectories_sampler.sample(n_trajectories=16)
        optimizer.zero_grad()
        loss = loss_fn(trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            pbar.set_postfix({"loss": loss.item()})
Contributing
Before the first commit:
pip install .[dev]
pre-commit install
pre-commit run --all-files
Run pre-commit after staging, and before committing. Make sure all the tests pass (by running pytest). The codebase uses the black formatter.
To make the docs locally:
cd docs
make html
open build/html/index.html
Details about the codebase
Defining an environment
A pointed DAG environment (or GFN environment, or environment for short) is a representation of a pointed DAG. The abstract class Env specifies the requirements for a valid environment definition. To obtain such a representation, the environment needs to specify the following attributes, properties, or methods (a minimal sketch follows the list below):
- The action_space, which should be a gymnasium.spaces.Discrete object for discrete environments. The last action should correspond to the exit action.
- The initial state s_0, as a torch.Tensor of arbitrary dimension.
- (Optional) The sink state s_f, as a torch.Tensor of the same shape as s_0, used to represent complete trajectories only (within a batch of trajectories of different lengths), and never processed by any model. If not specified, it is set to torch.full_like(s_0, -float('inf')).
- The method make_States_class that creates a subclass of States. Instances of the resulting class should represent a batch of states of arbitrary shape, which is useful to define a trajectory or a batch of trajectories. s_0 and s_f, along with a tuple called state_shape, should be defined as class variables, and the subclass (of States) should implement masking methods that specify, in a discrete environment, which actions are possible.
- The methods maskless_step and maskless_backward_step that specify how an action changes a state (going forward and backward). These functions do not need to handle masking, check whether actions are allowed, check whether a state is the sink state, etc. These checks are handled in Env.step and Env.backward_step.
- The log_reward function that assigns a nonnegative reward to every terminating state (i.e. a state with $s_f$ as a child in the DAG). If log_reward is not implemented, reward needs to be.
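Putting these pieces together, below is a minimal sketch of a custom environment. Everything in it is hypothetical: the ChainEnv class, the constructor signature, and the in-place semantics of the step methods are assumptions patterned on the requirements above rather than the exact Env API, and the States subclass returned by make_States_class is elided.

import torch
from gymnasium.spaces import Discrete

from gfn.envs import Env


class ChainEnv(Env):
    """Hypothetical 1-D chain: states are positions 0, ..., height - 1.

    Action 0 moves one step right; action 1 (the last action) exits.
    """

    def __init__(self, height: int = 8):
        self.height = height
        s_0 = torch.zeros(1)
        # The keyword arguments accepted by Env.__init__ are an assumption.
        super().__init__(action_space=Discrete(2), s_0=s_0)

    def make_States_class(self):
        # Should return a States subclass with s_0, s_f, and state_shape as
        # class variables, plus the masking methods; elided for brevity.
        raise NotImplementedError

    def maskless_step(self, states, actions):
        # Move right; validity checks are handled by Env.step.
        states.states_tensor[..., 0] += 1

    def maskless_backward_step(self, states, actions):
        # Undo a forward move; checks are handled by Env.backward_step.
        states.states_tensor[..., 0] -= 1

    def log_reward(self, final_states):
        # Larger final positions get larger rewards: R(s) = 1 + s.
        return torch.log1p(final_states.states_tensor[..., 0])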
If the states (as represented in the States class) need to be transformed to another format before being processed (by neural networks, for example), then the environment should define a preprocessor attribute, which should be an instance of the base preprocessor class. If no preprocessor is defined, the states are used as is (actually transformed using IdentityPreprocessor, which casts the state tensors to FloatTensors). Implementing your own preprocessor requires defining the preprocess function and the output_shape attribute, which is a tuple representing the shape of one preprocessed state.
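As an illustration, here is a hedged sketch of a custom preprocessor for integer grid states. The base-class import path and the OneHotPreprocessor class are assumptions; only the preprocess function and output_shape attribute follow the description above.

import torch

from gfn.envs.preprocessors import Preprocessor  # import path is an assumption


class OneHotPreprocessor(Preprocessor):
    """Hypothetical preprocessor: one-hot encode integer grid coordinates."""

    def __init__(self, ndim: int, height: int):
        self.ndim = ndim
        self.height = height
        # Shape of one preprocessed state, as required above.
        self.output_shape = (ndim * height,)

    def preprocess(self, states):
        # (*batch_shape, ndim) integer coordinates
        #   -> (*batch_shape, ndim * height) floats.
        one_hot = torch.nn.functional.one_hot(
            states.states_tensor.long(), num_classes=self.height
        )
        return one_hot.flatten(start_dim=-2).float()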
Optionally, you can define a static get_states_indices method that assigns a unique integer to each state if the environment allows it, and an n_states property that returns an integer representing the number of states (excluding $s_f$) in the environment. get_terminating_states_indices can also be implemented and serves the purpose of uniquely identifying terminating states of the environment.
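For example, in a hypergrid the coordinates of a state can be read as the digits of a base-height integer to obtain a unique index. The standalone function below is only a sketch of that idea, not the actual HyperGrid implementation:

import torch

def get_states_indices(states_tensor: torch.Tensor, height: int) -> torch.Tensor:
    """Map (*batch_shape, ndim) integer coordinates to unique flat indices."""
    ndim = states_tensor.shape[-1]
    # Coordinates are digits of a base-`height` number, most significant first.
    weights = height ** torch.arange(ndim - 1, -1, -1)
    return (states_tensor.long() * weights).sum(-1)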
For more details, take a look at HyperGrid, an environment where all states are terminating states, or at DiscreteEBM, where all trajectories are of the same length but only some states are terminating.
Other containers
Besides the States class, other containers of states are available:
- Transitions, representing a batch of transitions $s \rightarrow s'$.
- Trajectories, representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$.
These containers can either be instantiated using a States object, or can be initialized as empty containers that are populated on the fly, allowing the use of replay buffers. They inherit from the base Container class, which provides some helpful methods.
In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and a batch of transitions can be obtained using Trajectories.to_transitions() and Trajectories.to_states(). These exclude meaningless transitions and states that were added to the batch of trajectories to allow for efficient batching.
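For example, reusing the trajectories sampled in the training loop above:

# Convert a batch of complete trajectories into other containers.
transitions = trajectories.to_transitions()  # batch of s -> s' transitions
states = trajectories.to_states()  # flat batch of the visited states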
Estimators and Modules
Training GFlowNets requires one or multiple estimators. As of now, only discrete environments are handled. All estimators are subclasses of FunctionEstimator, implementing a __call__ function that takes a batch of States as input.
- LogEdgeFlowEstimator. It outputs a (*batch_shape, n_actions) tensor representing $\log F(s \rightarrow s')$, including when $s' = s_f$.
- LogStateFlowEstimator. It outputs a (*batch_shape, 1) tensor representing $\log F(s)$. When used with forward_looking=True, $\log F(s)$ is parametrized as the sum of a function approximator and $\log R(s)$, which is only possible for environments where all states are terminating.
- LogitPFEstimator. It outputs a (*batch_shape, n_actions) tensor representing $\mathrm{logit}(s' \mid s)$, such that $P_F(s' \mid s) = \mathrm{softmax}_{s'}\, \mathrm{logit}(s' \mid s)$, including when $s' = s_f$.
- LogitPBEstimator. It outputs a (*batch_shape, n_actions - 1) tensor representing $\mathrm{logit}(s' \mid s)$, such that $P_B(s' \mid s) = \mathrm{softmax}_{s'}\, \mathrm{logit}(s' \mid s)$.
Defining an estimator requires the environment and a module instance. Modules inherit from the GFNModule class, which can be seen as an extension of torch.nn.Module. Alternatively, an estimator can be created by specifying which module type to use (e.g. "NeuralNet", "Uniform", or "Zero"). A basic MLP is provided as the NeuralNet class, but any function approximator should be possible.
Said differently, a States object is first transformed via the environment's preprocessor to a (*batch_shape, *output_shape) float tensor. The preprocessor's output shape should match the module's input shape (if any). The preprocessed states are then passed as inputs to the module, returning the desired output (either flows or probabilities over children in the DAG).
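For instance, reusing the env from the example above, together with the "Uniform" module name shown in the training commands (the env.reset call below is an assumption, used only to build a batch of states):

from gfn import LogitPBEstimator, LogitPFEstimator

logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")  # learned MLP
logit_PB = LogitPBEstimator(env=env, module_name="Uniform")  # fixed uniform backward policy

states = env.reset(batch_shape=(4,))  # a batch of 4 copies of s_0 (signature assumed)
logits = logit_PF(states)  # a (4, n_actions) tensor of forward logits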
Each module has a named_parameters function that returns a dictionary of the learnable parameters. This attribute is transferred to the corresponding estimator.
Additionally, a LogZEstimator is provided, which is a container for a scalar tensor representing $\log Z$, the log-partition function, useful for the Trajectory Balance loss for example. This estimator also has a named_parameters function.
Samplers
An ActionsSampler object defines how actions are sampled at each state of the DAG. As of now, only DiscreteActionsSamplers are implemented. They require an estimator (of $P_F$, $P_B$, or edge flows) that defines the action probabilities. These estimators can contain any type of module (which enables, for example, random action sampling). A BackwardDiscreteActionsSampler class is provided to sample parents of a state, which is helpful for sampling trajectories starting from their last states.
They are at the core of TrajectoriesSamplers, which implement the sample_trajectories method that samples a batch of trajectories starting from a given set of initial states or from $s_0$.
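A hedged sketch of backward sampling, patterned on the forward sampler in the example above; the exact keyword for providing the starting states is an assumption:

from gfn.samplers import BackwardDiscreteActionsSampler, TrajectoriesSampler

backward_actions_sampler = BackwardDiscreteActionsSampler(estimator=logit_PB)
backward_trajectories_sampler = TrajectoriesSampler(
    env=env, actions_sampler=backward_actions_sampler
)
# `terminating_states` is a States object of terminating states (hypothetical here).
trajectories = backward_trajectories_sampler.sample_trajectories(states=terminating_states)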
Losses
GFlowNets can be trained with different losses, each of which requires a different parametrization. A parametrization is a dataclass, which can be seen as a container of different estimators. Each parametrization defines a distribution over trajectories, via the parametrization.Pi method, and a distribution over terminating states, via the parametrization.P_T method. Both distributions should be instances of the classes defined here.
The base classes for losses and parametrizations are provided here.
Currently, the implemented losses are:
- Flow Matching
- Detailed Balance
- Trajectory Balance
- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined here. Other strategies exist and are implemented here.
- Log Partition Variance loss. Introduced here
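As an illustration, swapping the Trajectory Balance loss of the earlier example for Detailed Balance could look like the sketch below. The DBParametrization and DetailedBalance names and signatures are assumptions patterned on the TB example above:

from gfn import LogitPBEstimator, LogitPFEstimator, LogStateFlowEstimator
from gfn.losses import DBParametrization, DetailedBalance

logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")
logit_PB = LogitPBEstimator(env=env, module_name="NeuralNet", torso=logit_PF.module.torso)
logF = LogStateFlowEstimator(env=env, module_name="NeuralNet")

parametrization = DBParametrization(logit_PF, logit_PB, logF)
loss_fn = DetailedBalance(parametrization=parametrization)

# Detailed Balance is a transition-level loss, so trajectories are first
# converted to transitions (see "Other containers" above).
loss = loss_fn(trajectories.to_transitions())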
Solving for the flows using Dynamic Programming
A simple script that propagates trajectory rewards through the DAG to define edge flows in a deterministic way (by visiting each edge only once) is provided here. Do not use this script on large environments!