
stdereka/egnn

EGNN

A simple and clean implementation of Graph Neural Networks (GNN) and E(n) Equivariant Graph Neural Networks (EGNN) from the paper "E(n) Equivariant Graph Neural Networks" (Satorras et al., 2021).
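For reference, one EGNN layer in the paper updates the node embeddings h_i and coordinates x_i of M nodes as follows. Here a_ij are edge attributes, φ_e, φ_x and φ_h are learned MLPs, and C = 1/(M-1). A plain GNN layer performs the same message passing without the coordinate terms, i.e. m_ij = φ_e(h_i, h_j, a_ij). This is the paper's formulation; whether this package uses exactly the same normalization is an assumption, so check help(EGNN).

```latex
\begin{aligned}
\mathbf{m}_{ij} &= \phi_e\left(\mathbf{h}_i^{l}, \mathbf{h}_j^{l}, \lVert \mathbf{x}_i^{l} - \mathbf{x}_j^{l} \rVert^2, a_{ij}\right) \\
\mathbf{x}_i^{l+1} &= \mathbf{x}_i^{l} + C \sum_{j \neq i} \left(\mathbf{x}_i^{l} - \mathbf{x}_j^{l}\right) \phi_x\left(\mathbf{m}_{ij}\right) \\
\mathbf{m}_i &= \sum_{j \neq i} \mathbf{m}_{ij} \\
\mathbf{h}_i^{l+1} &= \phi_h\left(\mathbf{h}_i^{l}, \mathbf{m}_i\right)
\end{aligned}
```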

Quick Start

Install the package

git clone git@github.com:stdereka/egnn.git
cd egnn
pip install -e .

Download the NBody and QM9 datasets and unpack them.

Then run the egnn package as a Python 3 module. Note: first check that the dataset root directory in the .yaml config points to the unpacked data.

python -m egnn -c config/qm9_egnn_cv.yaml

This command trains an EGNN model on the QM9 dataset and stores TensorBoard logs in ./logs. Other example configs are in ./config.

Using as a Library

For more details, see help() for the GNN and EGNN classes.

import torch
from egnn import GNN

gnn = GNN(
    input_node_dim=3,
    input_edge_dim=2,
    output_dim=1,
    hidden_dim=64,
    num_layers=3,
)

# 3 nodes with 3 features
node_features = torch.tensor(
    [[0.1, 0.2, 0.3],
     [23.0, 0.0, 1.0],
     [0.0, 0.0, 10.1]]
)

# 2 edges: 0-1 and 1-2
edge_ids = [
    torch.tensor([0, 1]),
    torch.tensor([1, 2]),
]

# Each edge has 2 features
edge_features = torch.tensor(
    [[1.1, 0.2],
     [2.0, 0.0]],
)

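# See help(GNN) for the meaning of the last argument (1) and of the returned tuple
out = gnn(node_features, edge_ids, edge_features, 1)[0]
# Model prediction for each node (exact values depend on the random weight initialization)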
# tensor([[-0.0889],
#         [-4.1104],
#         [ 1.2860]], grad_fn=<AddmmBackward0>)
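The example above only exercises the plain GNN. To illustrate what an equivariant layer does with node coordinates, here is a minimal, from-scratch PyTorch sketch of a single EGNN layer that follows the paper's update equations. It is not the package's EGNN class (see help(EGNN) for its actual signature). The class name EGNNLayerSketch, the helper fully_connected_edge_ids, the SiLU MLPs, and the convention that edge_ids[0] holds receiving nodes and edge_ids[1] sending nodes are all assumptions made for illustration. It reuses the two-tensor edge_ids layout from the example above.

```python
import torch
from torch import nn


class EGNNLayerSketch(nn.Module):
    """Hypothetical single E(n)-equivariant layer following the paper's update
    equations; not the egnn package's EGNN class."""

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int):
        super().__init__()
        # phi_e: message from (h_i, h_j, ||x_i - x_j||^2, a_ij)
        self.phi_e = nn.Sequential(
            nn.Linear(2 * node_dim + 1 + edge_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
        )
        # phi_x: one scalar weight per relative position vector
        self.phi_x = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1),
        )
        # phi_h: new node features from (h_i, aggregated message m_i)
        self.phi_h = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, node_dim),
        )

    def forward(self, h, x, edge_ids, edge_attr):
        # Assumed convention: messages flow from edge_ids[1] (j) to edge_ids[0] (i)
        row, col = edge_ids
        diff = x[row] - x[col]                          # x_i - x_j, shape (E, n)
        dist2 = (diff ** 2).sum(dim=-1, keepdim=True)   # ||x_i - x_j||^2, shape (E, 1)
        m_ij = self.phi_e(torch.cat([h[row], h[col], dist2, edge_attr], dim=-1))

        num_nodes = h.shape[0]
        # Coordinate update: sum of weighted relative vectors, scaled by C = 1 / (M - 1)
        weighted = diff * self.phi_x(m_ij)              # (E, n) * (E, 1)
        x_new = x + torch.zeros_like(x).index_add_(0, row, weighted) / (num_nodes - 1)

        # Feature update: sum incoming messages per node, then apply phi_h
        m_i = m_ij.new_zeros((num_nodes, m_ij.shape[-1])).index_add_(0, row, m_ij)
        h_new = self.phi_h(torch.cat([h, m_i], dim=-1))
        return h_new, x_new


def fully_connected_edge_ids(num_nodes):
    """Hypothetical helper: every ordered pair (i, j) with i != j, as two index tensors."""
    idx = torch.arange(num_nodes)
    row = idx.repeat_interleave(num_nodes)
    col = idx.repeat(num_nodes)
    keep = row != col
    return [row[keep], col[keep]]


torch.manual_seed(0)
layer = EGNNLayerSketch(node_dim=3, edge_dim=2, hidden_dim=16)
h = torch.randn(4, 3)                       # 4 nodes, 3 invariant features each
x = torch.randn(4, 3)                       # 4 nodes in 3-D space
edge_ids = fully_connected_edge_ids(4)      # 12 directed edges
edge_attr = torch.randn(edge_ids[0].shape[0], 2)

# Random orthogonal matrix Q (rotation or reflection) and translation t
Q, _ = torch.linalg.qr(torch.randn(3, 3))
t = torch.randn(1, 3)

h1, x1 = layer(h, x, edge_ids, edge_attr)
h2, x2 = layer(h, x @ Q.T + t, edge_ids, edge_attr)
print(torch.allclose(h1, h2, atol=1e-4))            # node features unchanged?
print(torch.allclose(x1 @ Q.T + t, x2, atol=1e-4))  # coordinates transformed the same way?
```

Because φ_e only sees squared distances, the node features are invariant to rotations, reflections, and translations of x. The coordinate update is a weighted sum of relative vectors x_i - x_j, so it rotates with the input. The two checks at the end test exactly these two properties.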

About

E(n) Equivariant Graph Neural Networks
