PyTorch implementation for FAENet from 'FAENet: Frame Averaging Equivariant GNN for Materials Modeling'
FAENet: Frame Averaging Equivariant GNN for Materials modeling
This repository contains an implementation of the paper FAENet: Frame Averaging Equivariant GNN for Materials modeling, accepted at ICML 2023. More precisely, you will find:
- `FrameAveraging`: the transform that projects your pytorch-geometric data into the canonical space defined in the paper.
- `FAENet`: the GNN model for materials modeling.
- `model_forward`: a high-level forward function that computes appropriate model predictions for the Frame Averaging method, i.e. handling the different frames and mapping to equivariant predictions.
Source code: https://github.com/vict0rsch/faenet
Installation
```bash
pip install faenet
```
⚠️ To the best of our knowledge, the above installation requires Python >= 3.8, `torch` > 1.11 and `torch_geometric` > 2.1. The `mendeleev` and `pandas` packages are also required to derive physics-aware atom embeddings in FAENet.
Getting started
Frame Averaging Transform
`FrameAveraging` is a transform applicable to pytorch-geometric `Data` objects. You can choose among several options, ranging from Full FA to Stochastic FA (in 2D or 3D), as well as data augmentation (DA). This transform should be applied in the `__getitem__()` method of your `Dataset` class. Note that although this transform is specific to pytorch-geometric data objects, it can easily be extended to new settings, since the core functions `frame_averaging_2D()` and `frame_averaging_3D()` generalise to other data formats.
```python
from faenet.transform import FrameAveraging

frame_averaging = "3D"  # symmetry-preservation method: {"3D", "2D", "DA", ""}
fa_method = "stochastic"  # frame averaging method: {"stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", ""}
transform = FrameAveraging(frame_averaging, fa_method)
g = transform(g)  # g: a PyG Data object, e.g. one item of your dataset
```
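To make the transform more concrete, here is a minimal NumPy sketch of PCA-based frames in the spirit of the paper: the principal axes of the centred atom positions define a frame, and because each axis is only defined up to its sign, there are 2^d candidate frames (8 in 3D). Full FA keeps all of them, stochastic FA samples one, and an SE(3) variant keeps only proper rotations (determinant +1). The helpers `pca_frames()`, `project()` and `stochastic_frame()` are hypothetical; this is an illustration under those assumptions, not the library's `frame_averaging_3D()`.

```python
import itertools

import numpy as np


def pca_frames(pos, se3=False):
    """Candidate frames for a point cloud `pos` (n_atoms x d) from PCA.

    Hypothetical helper for illustration; not part of the faenet API.
    """
    centroid = pos.mean(axis=0)
    centered = pos - centroid
    # Principal axes = eigenvectors (columns) of the covariance matrix.
    _, eigvecs = np.linalg.eigh(centered.T @ centered)
    frames = []
    # Each principal axis is only defined up to its sign: 2**d choices.
    for signs in itertools.product([1.0, -1.0], repeat=pos.shape[1]):
        frame = eigvecs * np.array(signs)  # flip column i by signs[i]
        if se3 and np.linalg.det(frame) < 0:
            continue  # SE(3) variant: keep proper rotations only
        frames.append(frame)
    return frames, centroid


def project(pos, frame, centroid):
    """Express positions in the canonical coordinates of `frame`."""
    return (pos - centroid) @ frame


def stochastic_frame(frames, rng):
    """Stochastic FA: sample a single frame instead of using all of them."""
    return frames[rng.integers(len(frames))]


rng = np.random.default_rng(0)
pos = rng.normal(size=(5, 3))
frames, centroid = pca_frames(pos)
print(len(frames), len(pca_frames(pos, se3=True)[0]))  # 8 4
canonical = project(pos, stochastic_frame(frames, rng), centroid)
```

Because rotating or translating the input rotates the principal axes along with it, the set of projected positions is the same for `pos` and for any rotated copy of it (assuming distinct eigenvalues), which is what makes the canonical space useful.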
Model forward for Frame Averaging
`model_forward()` aggregates model predictions when Frame Averaging is applied, as stipulated by Equation (1) of the paper. It must be used to compute predictions whenever Frame Averaging is applied.
```python
from faenet.fa_forward import model_forward

preds = model_forward(
    batch=batch,  # a batch from your DataLoader
    model=model,  # e.g. FAENet(**kwargs)
    frame_averaging="3D",  # one of {"2D", "3D", "DA", ""}
    mode="train",  # training mode
    crystal_task=True,  # True for crystals, with periodic boundary conditions (PBC)
)
```
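Frame averaging averages the model's outputs over all frames: invariant targets (e.g. energies) are averaged directly, while equivariant targets (e.g. forces) are first mapped back from each canonical frame to the original coordinates. The NumPy sketch below illustrates that aggregation with toy models that are deliberately not symmetric on their own. `frame_average()`, `toy_energy()` and `toy_forces()` are hypothetical names; this is not the code behind `model_forward()`.

```python
import numpy as np


def frame_average(model, pos, frames, centroid, equivariant=False):
    """Average `model` outputs over `frames` (hypothetical helper).

    `model` maps canonical positions (n_atoms x d) either to a scalar
    (invariant target, e.g. energy) or to an (n_atoms x d) array of
    vectors expressed in the canonical frame (equivariant target, e.g. forces).
    """
    outputs = []
    for frame in frames:
        out = model((pos - centroid) @ frame)  # predict in canonical space
        if equivariant:
            out = out @ frame.T  # rotate vectors back to the input coordinates
        outputs.append(out)
    return np.mean(outputs, axis=0)


def toy_energy(x):
    # Not rotation invariant on its own.
    return float(np.sum(x[:, 0] ** 2 + 2 * x[:, 1] ** 2 + x[:, 0] * x[:, 2]))


def toy_forces(x):
    # Not rotation equivariant on its own.
    return x**3 + x**2
```

Here `frames` and `centroid` are assumed to come from a PCA-based frame construction (all sign combinations of the principal axes). With such frames, the averaged energy is unchanged by a rotation R of the input, and the averaged forces rotate with it.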
FAENet GNN
Implementation of the FAENet GNN model, compatible with any dataset or transform. In short, FAENet is a very simple, scalable and expressive model. Since it does not explicitly preserve data symmetries, it can process atoms' relative positions directly and without restriction, which is very efficient. Note that the training procedure is not provided here.
```python
from faenet.model import FAENet

model = FAENet(**kwargs)  # kwargs: model hyperparameters
preds = model(batch)
```
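As a rough illustration of what working with relative positions means, the sketch below builds a cutoff graph over a point cloud and computes per-edge relative position vectors and distances. These features are translation invariant by construction, while the vectors rotate with the input (which frame averaging then takes care of). The helpers `radius_graph_edges()` and `edge_features()` and the cutoff value are hypothetical; FAENet's actual graph construction, layers and periodic boundary handling are not reproduced here.

```python
import numpy as np


def radius_graph_edges(pos, cutoff):
    """Directed edges (src[k], dst[k]) between distinct atoms closer than `cutoff`.

    Hypothetical O(n^2) helper: no periodic boundary conditions, no neighbour cap.
    """
    diff = pos[None, :, :] - pos[:, None, :]  # diff[i, j] = pos[j] - pos[i]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)  # drop self-loops
    src, dst = np.nonzero(mask)
    return src, dst


def edge_features(pos, src, dst):
    """Relative position vector and length for every edge."""
    rel = pos[dst] - pos[src]  # unchanged if all atoms are translated
    return rel, np.linalg.norm(rel, axis=-1)


pos = np.random.default_rng(0).uniform(0.0, 3.0, size=(8, 3))
src, dst = radius_graph_edges(pos, cutoff=2.0)
rel, dist = edge_features(pos, src, dst)
```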
Eval
The `eval_model_symmetries()` function helps you evaluate the equivariance, invariance and other symmetry properties of a model, as we did in the paper.
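A simple way to measure rotation invariance empirically, in the spirit of (but not identical to) `eval_model_symmetries()`, is to compare a model's predictions on an input and on randomly rotated copies of it. In the NumPy sketch below, `random_rotation()` and `invariance_error()` are hypothetical helpers, and the metric (mean absolute difference) is an assumption.

```python
import numpy as np


def random_rotation(rng):
    """Random 3D rotation matrix (det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so q is uniformly distributed
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # turn a reflection into a proper rotation
    return q


def invariance_error(predict, pos, n_rotations=10, seed=0):
    """Mean |predict(x R^T) - predict(x)| over random rotations R."""
    rng = np.random.default_rng(seed)
    reference = predict(pos)
    errors = [abs(predict(pos @ random_rotation(rng).T) - reference)
              for _ in range(n_rotations)]
    return float(np.mean(errors))
```

A prediction built only from pairwise distances should give an error close to zero. For an equivariant output such as forces, one would instead compare `predict(pos @ R.T)` with `predict(pos) @ R.T`.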
Tests
The `/tests` folder contains several useful unit tests. Feel free to have a look at them to explore how the model can be used. For more advanced examples, please refer to the full repository used in our ICML paper to make predictions on the OC20 IS2RE and S2EF, QM9 and QM7-X datasets.
Running the tests requires `poetry`. Make sure `torch` and `torch_geometric` are installed in your environment before running them: because of CUDA/torch compatibility issues, neither `torch` nor `torch_geometric` is listed as an explicit dependency, so both must be installed independently.
```bash
git clone git@github.com:vict0rsch/faenet.git
cd faenet
poetry install --with dev
poetry run pytest --cov=faenet --cov-report term-missing
```
When testing on macOS, you may encounter a `Library not loaded` error.
Contact
Authors: Alexandre Duval (alexandre.duval@mila.quebec) and Victor Schmidt (schmidtv@mila.quebec). We welcome your questions and feedback via email or GitHub Issues.