
PyTorch implementation for FAENet from 'FAENet: Frame Averaging Equivariant GNN for Materials Modeling'


FAENet: Frame Averaging Equivariant GNN for Materials Modeling

This repository contains an implementation of the paper FAENet: Frame Averaging Equivariant GNN for Materials Modeling, accepted at ICML 2023. More precisely, you will find:

  • FrameAveraging: the transform that projects your pytorch-geometric data into a canonical space, identical for all Euclidean transformations, as defined in the paper.
  • FAENet: a GNN architecture for materials modeling.
  • model_forward: a high-level forward function that computes appropriate equivariant model predictions for the Frame Averaging method, i.e. it handles the different frames and maps their outputs to equivariant predictions.

Also: https://github.com/vict0rsch/faenet

Installation

pip install faenet

⚠️ To the best of our knowledge, the above installation requires Python >= 3.8, torch > 1.11 and torch_geometric > 2.1. The mendeleev and pandas packages are also required to derive physics-aware atom embeddings in FAENet.

Getting started

Frame Averaging Transform

FrameAveraging is a Transform method applicable to pytorch-geometric Data objects, which should be used in the get_item() function of your Dataset class. This method derives a new canonical position for the atomic graph, identical for all Euclidean symmetries, and stores it under the data attribute fa_pos. You can choose among several options for the frame averaging, ranging from Full FA to Stochastic FA (in 2D or 3D), including traditional data augmentation (DA) with rotated samples. See the full doc for more details. Note that, although this transform is specific to pytorch-geometric data objects, it can easily be extended to new settings since the core functions frame_averaging_2D() and frame_averaging_3D() generalise to other data formats.

import torch
from faenet.transforms import FrameAveraging

frame_averaging = "3D"  # symmetry preservation method used: {"3D", "2D", "DA", ""}
fa_method = "stochastic"  # the frame averaging method: {"stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", ""}
transform = FrameAveraging(frame_averaging, fa_method)
transform(data)  # transform the PyG graph data
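
To make the idea of a canonical position concrete, here is a minimal pure-Python sketch of the 2D case. It is not the library's implementation, which works on tensors through frame_averaging_2D() and frame_averaging_3D(). All helper names below (center, principal_axes_2d, canonical_positions_2d, rotate, same_frame_set) are hypothetical, and the sketch assumes that frames are built from the principal axes of the centered positions, keeping every sign choice.

```python
import math
from itertools import product


def center(points):
    """Subtract the centroid so the result does not depend on translations."""
    n = len(points)
    cx = sum(x for x, _ in points) / n
    cy = sum(y for _, y in points) / n
    return [(x - cx, y - cy) for x, y in points]


def principal_axes_2d(points):
    """Unit eigenvectors of the 2x2 covariance of centered points (major axis first)."""
    a = sum(x * x for x, _ in points)
    b = sum(x * y for x, y in points)
    c = sum(y * y for _, y in points)
    theta = 0.5 * math.atan2(2.0 * b, a - c)  # orientation of the major axis
    return (math.cos(theta), math.sin(theta)), (-math.sin(theta), math.cos(theta))


def canonical_positions_2d(points):
    """One projected point cloud per frame: 4 sign choices of the two axes."""
    centered = center(points)
    u1, u2 = principal_axes_2d(centered)
    return [
        [(s1 * (x * u1[0] + y * u1[1]), s2 * (x * u2[0] + y * u2[1])) for x, y in centered]
        for s1, s2 in product((1.0, -1.0), repeat=2)
    ]


def rotate(points, angle):
    """Rotate 2D points counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y, s * x + c * y) for x, y in points]


def same_frame_set(frames_a, frames_b, tol=1e-9):
    """True if every point cloud in frames_a matches some point cloud in frames_b."""
    def close(fa, fb):
        return all(abs(p - q) < tol for pa, pb in zip(fa, fb) for p, q in zip(pa, pb))
    return all(any(close(fa, fb) for fb in frames_b) for fa in frames_a)


# Placeholder coordinates, then the same atoms rotated and translated.
molecule = [(1.0, 0.2), (-0.5, 1.3), (2.0, -0.7), (0.3, 0.4)]
moved = [(x + 3.0, y - 1.0) for x, y in rotate(molecule, 0.7)]
invariant = same_frame_set(canonical_positions_2d(molecule), canonical_positions_2d(moved))
```

Individual frames may come out in a different order after the input is rotated, which is why the comparison is made on the set of frames rather than frame by frame.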

Model forward for Frame Averaging

model_forward() aggregates the predictions of a chosen ML model (e.g. FAENet) when Frame Averaging is applied, as stipulated by Equation (1) of the paper. Indeed, applying the model directly on canonical positions (fa_pos) would not yield equivariant predictions. This method must be used at both training and inference time to compute all model predictions. It requires batch to have pos, batch and frame averaging attributes (see the documentation).

from faenet.fa_forward import model_forward

preds = model_forward(
    batch=batch,  # batch from the dataloader
    model=model,  # FAENet(**kwargs)
    frame_averaging="3D",  # ["2D", "3D", "DA", ""]
    mode="train",  # for training
    crystal_task=True,  # for crystals, with periodic boundary conditions (pbc)
)
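
The aggregation step can be illustrated with a small pure-Python sketch in 2D. This is an assumption-laden illustration, not the code of model_forward(): toy_model is a hypothetical stand-in for a GNN, the two frames are assumed to be the rotations aligned with the principal axis of the centered positions, and all helper names are made up for this example. Invariant outputs (energy) are averaged as they are, while equivariant outputs (forces) are rotated back by each frame before averaging.

```python
import math


def rotation(angle):
    """2x2 counter-clockwise rotation matrix stored as nested tuples."""
    c, s = math.cos(angle), math.sin(angle)
    return ((c, -s), (s, c))


def transpose(m):
    return ((m[0][0], m[1][0]), (m[0][1], m[1][1]))


def apply(m, points):
    """Apply the 2x2 matrix `m` to every (x, y) point."""
    return [(m[0][0] * x + m[0][1] * y, m[1][0] * x + m[1][1] * y) for x, y in points]


def center(points):
    n = len(points)
    cx = sum(x for x, _ in points) / n
    cy = sum(y for _, y in points) / n
    return [(x - cx, y - cy) for x, y in points]


def rotation_frames(points):
    """Two rotation frames aligned with the principal axis of centered points."""
    a = sum(x * x for x, _ in points)
    b = sum(x * y for x, y in points)
    c = sum(y * y for _, y in points)
    theta = 0.5 * math.atan2(2.0 * b, a - c)
    return [rotation(theta), rotation(theta + math.pi)]


def toy_model(points):
    """Hypothetical model: neither invariant nor equivariant on its own."""
    energy = sum(x + y * y for x, y in points)
    forces = [(y, x * x) for x, y in points]
    return energy, forces


def frame_averaged_forward(points, model=toy_model):
    """Average the model over all frames, mapping forces back to the input frame."""
    centered = center(points)
    frames = rotation_frames(centered)
    energies, forces = [], []
    for frame in frames:
        fa_pos = apply(transpose(frame), centered)     # canonical positions
        energy, canonical_forces = model(fa_pos)
        energies.append(energy)                        # invariant: keep as is
        forces.append(apply(frame, canonical_forces))  # equivariant: rotate back
    n = len(frames)
    mean_forces = [
        tuple(sum(f[i][k] for f in forces) / n for k in range(2))
        for i in range(len(points))
    ]
    return sum(energies) / n, mean_forces


molecule = [(1.0, 0.2), (-0.5, 1.3), (2.0, -0.7), (0.3, 0.4)]  # placeholder coordinates
turn = rotation(1.1)
energy, forces = frame_averaged_forward(molecule)
energy_rot, forces_rot = frame_averaged_forward(apply(turn, molecule))
```

Under these assumptions, energy_rot should match energy and forces_rot should match the rotated forces up to floating-point error, even though toy_model alone has neither property.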

FAENet GNN

Implementation of the FAENet GNN model, compatible with any dataset or transform. In short, FAENet is a very simple, scalable and expressive model. Since it does not explicitly preserve data symmetries, it can process atom relative positions directly and without restrictions, which is efficient and powerful. Although it was specifically designed to be used with Frame Averaging (above), which preserves symmetries without imposing any design restrictions, it can also be applied without it. When applied with Frame Averaging, you need to use the model_forward() function above to compute model predictions; calling model(data) is not enough. Note that the training procedure is not given here; refer to the original GitHub repository for it. Check the documentation to see all input parameters.

Note that the model assumes input data (e.g. batch below) to have certain attributes, like atomic_numbers, batch, pos or edge_index. If your data does not have these attributes, you can apply custom pre-processing functions, taking pbc_preprocess or base_preprocess in utils.py as inspiration. You simply need to pass them as an argument to FAENet (preprocess).

from faenet.model import FAENet

model = FAENet(**kwargs)  # instantiate the model
preds = model(batch)  # forward pass on a batch of graphs
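
Before calling the model, it can help to fail early when a batch lacks the attributes mentioned above. The sketch below is a hypothetical helper, not part of faenet: missing_attributes and REQUIRED_ATTRIBUTES are made-up names, the attribute list only covers the examples named in this README (the model may expect more), and a SimpleNamespace with placeholder values stands in for a real pytorch-geometric batch.

```python
from types import SimpleNamespace

# Attributes named in this README; assumed, not an exhaustive list.
REQUIRED_ATTRIBUTES = ("atomic_numbers", "batch", "pos", "edge_index")


def missing_attributes(batch, required=REQUIRED_ATTRIBUTES):
    """Return the names in `required` that `batch` does not define (or sets to None)."""
    return [name for name in required if getattr(batch, name, None) is None]


# Stand-in for a PyG batch, with placeholder values and no edge_index.
batch = SimpleNamespace(
    atomic_numbers=[8, 1, 1],
    pos=[(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
    batch=[0, 0, 0],
)
missing = missing_attributes(batch)
if missing:
    print(f"Batch is missing attributes: {missing}; a custom preprocess may be needed.")
```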

FAENet architecture

Eval

The eval_model_symmetries() function helps you evaluate the equivariant, invariant and other properties of a model, as we did in the paper.
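
As a rough illustration of what such a check measures, the sketch below estimates a rotation-invariance error for two toy 2D models. It is not the implementation of eval_model_symmetries(): rotation_invariance_error, pairwise_energy and axis_energy are hypothetical names, and the metric (mean absolute change of the prediction under a few rotations) is an assumption chosen for simplicity.

```python
import math


def rotate(points, angle):
    """Rotate 2D points counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y, s * x + c * y) for x, y in points]


def pairwise_energy(points):
    """Rotation-invariant toy model: sum of all pairwise distances."""
    return sum(math.dist(p, q) for i, p in enumerate(points) for q in points[i + 1:])


def axis_energy(points):
    """Orientation-dependent toy model: sum of x coordinates."""
    return sum(x for x, _ in points)


def rotation_invariance_error(model, points, angles):
    """Mean absolute change of the model output when the input is rotated."""
    reference = model(points)
    diffs = [abs(model(rotate(points, angle)) - reference) for angle in angles]
    return sum(diffs) / len(diffs)


molecule = [(1.0, 0.2), (-0.5, 1.3), (2.0, -0.7), (0.3, 0.4)]  # placeholder coordinates
angles = [0.3, 1.1, 2.5]
invariant_error = rotation_invariance_error(pairwise_energy, molecule, angles)
biased_error = rotation_invariance_error(axis_energy, molecule, angles)
```

An invariant model should give an error close to zero, while a model that depends on orientation gives a clearly non-zero value. The same pattern extends to equivariance by comparing the prediction on rotated inputs with the rotated prediction.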

Note: you can predict any atom-level or graph-level property, although the code explicitly refers to energy and forces.

Tests

The /tests folder contains several useful unit tests. Feel free to have a look at them to explore how the model can be used. For more advanced examples, please refer to the full repository used in our ICML paper to make predictions on the OC20 IS2RE, S2EF, QM9 and QM7-X datasets.

Running the tests requires poetry. Make sure torch and torch_geometric are installed in your environment before you run them. Unfortunately, because of CUDA/torch compatibility issues, neither torch nor torch_geometric is part of the explicit dependencies, so both must be installed independently.

git clone git@github.com:vict0rsch/faenet.git
cd faenet
poetry install --with dev
pytest --cov=faenet --cov-report term-missing

When testing on Macs, you may encounter a Library Not Loaded error.

Contact

Authors: Alexandre Duval (alexandre.duval@mila.quebec) and Victor Schmidt (schmidtv@mila.quebec). We welcome your questions and feedback via email or GitHub Issues.

