-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from chaitjo/dev
ICML camera-ready updates
- Loading branch information
Showing
30 changed files
with
1,828 additions
and
1,509 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from models.schnet import SchNetModel | ||
from models.dimenet import DimeNetPPModel | ||
from models.spherenet import SphereNetModel | ||
from models.egnn import EGNNModel | ||
from models.gvpgnn import GVPGNNModel | ||
from models.tfn import TFNModel | ||
from models.mace import MACEModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Callable, Union | ||
|
||
import torch | ||
from torch.nn import functional as F | ||
from torch_geometric.nn import DimeNetPlusPlus | ||
from torch_scatter import scatter | ||
|
||
|
||
class DimeNetPPModel(DimeNetPlusPlus): | ||
""" | ||
DimeNet model from "Directional message passing for molecular graphs". | ||
This class extends the DimeNetPlusPlus base class for PyG. | ||
""" | ||
def __init__( | ||
self, | ||
hidden_channels: int = 128, | ||
in_dim: int = 1, | ||
out_dim: int = 1, | ||
num_layers: int = 4, | ||
int_emb_size: int = 64, | ||
basis_emb_size: int = 8, | ||
out_emb_channels: int = 256, | ||
num_spherical: int = 7, | ||
num_radial: int = 6, | ||
cutoff: float = 10, | ||
max_num_neighbors: int = 32, | ||
envelope_exponent: int = 5, | ||
num_before_skip: int = 1, | ||
num_after_skip: int = 2, | ||
num_output_layers: int = 3, | ||
act: Union[str, Callable] = 'swish' | ||
): | ||
""" | ||
Initializes an instance of the DimeNetPPModel class with the provided parameters. | ||
Parameters: | ||
- hidden_channels (int): Number of channels in the hidden layers (default: 128) | ||
- in_dim (int): Input dimension of the model (default: 1) | ||
- out_dim (int): Output dimension of the model (default: 1) | ||
- num_layers (int): Number of layers in the model (default: 4) | ||
- int_emb_size (int): Embedding size for interaction features (default: 64) | ||
- basis_emb_size (int): Embedding size for basis functions (default: 8) | ||
- out_emb_channels (int): Number of channels in the output embeddings (default: 256) | ||
- num_spherical (int): Number of spherical harmonics (default: 7) | ||
- num_radial (int): Number of radial basis functions (default: 6) | ||
- cutoff (float): Cutoff distance for interactions (default: 10) | ||
- max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32) | ||
- envelope_exponent (int): Exponent of the envelope function (default: 5) | ||
- num_before_skip (int): Number of layers before the skip connections (default: 1) | ||
- num_after_skip (int): Number of layers after the skip connections (default: 2) | ||
- num_output_layers (int): Number of output layers (default: 3) | ||
- act (Union[str, Callable]): Activation function (default: 'swish' or callable) | ||
Note: | ||
- The `act` parameter can be either a string representing a built-in activation function, | ||
or a callable object that serves as a custom activation function. | ||
""" | ||
super().__init__( | ||
hidden_channels, | ||
out_dim, | ||
num_layers, | ||
int_emb_size, | ||
basis_emb_size, | ||
out_emb_channels, | ||
num_spherical, | ||
num_radial, | ||
cutoff, | ||
max_num_neighbors, | ||
envelope_exponent, | ||
num_before_skip, | ||
num_after_skip, | ||
num_output_layers, | ||
act | ||
) | ||
|
||
def forward(self, batch): | ||
|
||
i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( | ||
batch.edge_index, num_nodes=batch.atoms.size(0)) | ||
|
||
# Calculate distances. | ||
dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt() | ||
|
||
# Calculate angles. | ||
pos_i = batch.pos[idx_i] | ||
pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i | ||
a = (pos_ji * pos_ki).sum(dim=-1) | ||
b = torch.cross(pos_ji, pos_ki).norm(dim=-1) | ||
angle = torch.atan2(b, a) | ||
|
||
rbf = self.rbf(dist) | ||
sbf = self.sbf(dist, angle, idx_kj) | ||
|
||
# Embedding block. | ||
x = self.emb(batch.atoms, rbf, i, j) | ||
P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0)) | ||
|
||
# Interaction blocks. | ||
for interaction_block, output_block in zip(self.interaction_blocks, | ||
self.output_blocks[1:]): | ||
x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) | ||
P += output_block(x, rbf, i) | ||
|
||
return P.sum(dim=0) if batch is None else scatter(P, batch.batch, dim=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
from torch.nn import functional as F | ||
from torch_geometric.nn import global_add_pool, global_mean_pool | ||
|
||
from models.layers.egnn_layer import EGNNLayer | ||
|
||
|
||
class EGNNModel(torch.nn.Module): | ||
""" | ||
E-GNN model from "E(n) Equivariant Graph Neural Networks". | ||
""" | ||
def __init__( | ||
self, | ||
num_layers: int = 5, | ||
emb_dim: int = 128, | ||
in_dim: int = 1, | ||
out_dim: int = 1, | ||
activation: str = "relu", | ||
norm: str = "layer", | ||
aggr: str = "sum", | ||
pool: str = "sum", | ||
residual: bool = True, | ||
equivariant_pred: bool = False | ||
): | ||
""" | ||
Initializes an instance of the EGNNModel class with the provided parameters. | ||
Parameters: | ||
- num_layers (int): Number of layers in the model (default: 5) | ||
- emb_dim (int): Dimension of the node embeddings (default: 128) | ||
- in_dim (int): Input dimension of the model (default: 1) | ||
- out_dim (int): Output dimension of the model (default: 1) | ||
- activation (str): Activation function to be used (default: "relu") | ||
- norm (str): Normalization method to be used (default: "layer") | ||
- aggr (str): Aggregation method to be used (default: "sum") | ||
- pool (str): Global pooling method to be used (default: "sum") | ||
- residual (bool): Whether to use residual connections (default: True) | ||
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) | ||
""" | ||
super().__init__() | ||
self.equivariant_pred = equivariant_pred | ||
self.residual = residual | ||
|
||
# Embedding lookup for initial node features | ||
self.emb_in = torch.nn.Embedding(in_dim, emb_dim) | ||
|
||
# Stack of GNN layers | ||
self.convs = torch.nn.ModuleList() | ||
for _ in range(num_layers): | ||
self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr)) | ||
|
||
# Global pooling/readout function | ||
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] | ||
|
||
if self.equivariant_pred: | ||
# Linear predictor for equivariant tasks using geometric features | ||
self.pred = torch.nn.Linear(emb_dim + 3, out_dim) | ||
else: | ||
# MLP predictor for invariant tasks using only scalar features | ||
self.pred = torch.nn.Sequential( | ||
torch.nn.Linear(emb_dim, emb_dim), | ||
torch.nn.ReLU(), | ||
torch.nn.Linear(emb_dim, out_dim) | ||
) | ||
|
||
def forward(self, batch): | ||
|
||
h = self.emb_in(batch.atoms) # (n,) -> (n, d) | ||
pos = batch.pos # (n, 3) | ||
|
||
for conv in self.convs: | ||
# Message passing layer | ||
h_update, pos_update = conv(h, pos, batch.edge_index) | ||
|
||
# Update node features (n, d) -> (n, d) | ||
h = h + h_update if self.residual else h_update | ||
|
||
# Update node coordinates (no residual) (n, 3) -> (n, 3) | ||
pos = pos_update | ||
|
||
if not self.equivariant_pred: | ||
# Select only scalars for invariant prediction | ||
out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) | ||
else: | ||
out = self.pool(torch.cat([h, pos], dim=-1), batch.batch) | ||
|
||
return self.pred(out) # (batch_size, out_dim) |
Oops, something went wrong.