
PyTorch implementation for FAENet from 'FAENet: Frame Averaging Equivariant GNN for Materials Modeling'

Project description

💻  Code   •   Docs  📑



FAENet: Frame Averaging Equivariant GNN for Materials Modeling

This repository contains an implementation of the paper FAENet: Frame Averaging Equivariant GNN for Materials modeling, accepted at ICML 2023. More precisely, you will find:

  • FrameAveraging: the transform that projects your pytorch-geometric data into the canonical space defined in the paper.
  • FAENet: the GNN model for materials modeling.
  • model_forward: a high-level forward function that computes appropriate model predictions for the Frame Averaging method, i.e. handling the different frames and mapping to equivariant predictions.

Source code: https://github.com/vict0rsch/faenet

Installation

pip install faenet

⚠️ The above installation requires Python >= 3.8, torch > 1.11 and torch_geometric > 2.1, to the best of our knowledge. Both the mendeleev and pandas packages are also required to derive physics-aware atom embeddings in FAENet.

Getting started

Frame Averaging Transform

FrameAveraging is a Transform applicable to pytorch-geometric Data objects. You can choose among several options, ranging from Full FA to Stochastic FA (in 2D or 3D), including data augmentation (DA). This transform should be applied in the get_item() function of your Dataset class. Note that although it is specific to pytorch-geometric data objects, it can easily be extended to new settings, since the core functions frame_averaging_2D() and frame_averaging_3D() generalise to other data formats.

import torch
from faenet.transform import FrameAveraging

frame_averaging = "3D"  # symmetry preservation method used: {"3D", "2D", "DA", ""}
fa_method = "stochastic"  # the frame averaging method: {"stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", ""}
transform = FrameAveraging(frame_averaging, fa_method)
transform(g)  # transform the PyG graph g 
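
If your Dataset follows the standard pytorch-geometric pattern, you can also pass the transform to the dataset constructor instead of calling it manually. A minimal sketch, assuming QM9 as an illustrative dataset (any PyG dataset that accepts a transform argument works the same way):

from torch_geometric.datasets import QM9
from faenet.transform import FrameAveraging

# the transform is applied to each Data object when it is accessed
dataset = QM9(root="data/qm9", transform=FrameAveraging("3D", "stochastic"))
g = dataset[0]  # g now carries the frame-averaged attributes added by FrameAveraging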

Model forward for Frame Averaging

model_forward() aggregates model predictions over the different frames when Frame Averaging is applied, as stipulated by Equation (1) of the paper. It must be used in place of a direct model(batch) call whenever Frame Averaging is applied.

from faenet.fa_forward import model_forward

preds = model_forward(
    batch=batch,   # batch from the dataloader
    model=model,  # FAENet(**kwargs)
    frame_averaging="3D", # ["2D", "3D", "DA", ""]
    mode="train",  # for training 
    crystal_task=True,  # for crystals, with periodic boundary conditions (pbc)
)
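
Below is a minimal training-step sketch showing where model_forward() fits in a loop. The optimizer, the MSE loss, the "energy" output key and the batch.y target are illustrative assumptions; adapt them to your task and check the returned prediction dict for the exact keys.

import torch
from faenet.fa_forward import model_forward

# model: a FAENet instance, loader: a pytorch-geometric DataLoader of frame-averaged graphs
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in loader:
    optimizer.zero_grad()
    preds = model_forward(
        batch=batch,
        model=model,
        frame_averaging="3D",
        mode="train",
        crystal_task=True,
    )
    # the "energy" key and batch.y target are assumptions: adapt to your dataset/task
    loss = torch.nn.functional.mse_loss(preds["energy"].view(-1), batch.y.view(-1))
    loss.backward()
    optimizer.step()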

FAENet GNN

Implementation of the FAENet GNN model, compatible with any dataset or transform. In short, FAENet is a simple, scalable and expressive model. Since it does not explicitly preserve data symmetries, it can process atom relative positions directly and without restriction, which is very efficient. Note that the training procedure is not given here; see the training-step sketch above for a minimal example.

from faenet.model import FAENet

model = FAENet(**kwargs)
preds = model(batch)
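
To illustrate the kwargs, here is a minimal instantiation and inference sketch. The cutoff and hidden_channels keyword names below are assumptions about the constructor (instantiating FAENet() with defaults also works); check the model documentation for the full list of options.

import torch
from faenet.model import FAENet

# cutoff and hidden_channels are assumed keyword names; FAENet() with defaults also works
model = FAENet(cutoff=6.0, hidden_channels=128)

model.eval()
with torch.no_grad():
    preds = model(batch)  # batch: a (frame-averaged) pytorch-geometric Batch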

FAENet architecture

Eval

The eval_model_symmetries() function helps you evaluate the equivariance, invariance and other symmetry properties of a model, as done in the paper.
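
As a rough illustration only, a hypothetical call is sketched below: the module path and the argument names are assumptions, so check the actual signature of eval_model_symmetries() in the faenet source before using it.

from faenet.eval import eval_model_symmetries  # module path is an assumption

# hypothetical call: argument names are assumptions, check the real signature
results = eval_model_symmetries(
    loader=val_loader,      # a pytorch-geometric DataLoader (assumption)
    model=model,            # a trained FAENet instance
    frame_averaging="3D",
    fa_method="stochastic",
    crystal_task=True,
)
print(results)  # e.g. invariance / equivariance metrics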

Tests

The /tests folder contains several useful unit tests. Feel free to have a look at them to explore how the model can be used. For more advanced examples, please refer to the full repository used in our ICML paper to make predictions on the OC20 IS2RE, S2EF, QM9 and QM7-X datasets.

Running the tests requires Poetry. Make sure torch and torch_geometric are installed in your environment beforehand: because of CUDA/torch compatibility constraints, neither is part of the explicit dependencies, so both must be installed independently.

git clone git@github.com:vict0rsch/faenet.git
poetry install --with dev
pytest --cov=faenet --cov-report term-missing

When testing on Macs, you may encounter a Library Not Loaded error.

Contact

Authors: Alexandre Duval (alexandre.duval@mila.quebec) and Victor Schmidt (schmidtv@mila.quebec). We welcome your questions and feedback via email or GitHub Issues.

