Haiku Geometric is a collection of graph neural networks (GNNs) implemented using JAX. It aims to provide object-oriented, easy-to-use modules for GNNs.
Haiku Geometric is built on top of Haiku and Jraph and is deeply inspired by PyTorch Geometric. Where possible, it replicates the API of PyTorch Geometric so that code can be shared between the two.
Haiku Geometric is still under development and I would advise against using it in production.
Haiku Geometric can be installed from source:
pip install git+https://github.com/alexOarga/haiku-geometric.git
Alternatively, you can install the latest release from PyPI using pip:
pip install haiku-geometric
As an example, we can create a simple graph convolutional network (GCN) with two convolutional layers and a linear output layer:
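import jax
import haiku as hk
from haiku_geometric.nn import GCNConv

class GCN(hk.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Two graph convolution layers followed by a linear output layer.
        self.conv1 = GCNConv(hidden_channels)
        self.conv2 = GCNConv(hidden_channels)
        self.linear = hk.Linear(out_channels)

    def __call__(self, nodes, senders, receivers):
        # nodes: node features; senders/receivers: source and target node of each edge.
        x = self.conv1(nodes, senders, receivers)
        x = jax.nn.relu(x)
        x = self.conv2(x, senders, receivers)
        # Apply the output layer to the convolved features, not to the raw input nodes.
        x = self.linear(x)
        return x

def forward(nodes, senders, receivers):
    # Haiku modules must be instantiated inside the function that is later transformed.
    gcn = GCN(16, 7)
    return gcn(nodes, senders, receivers)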
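To show what a single graph convolution computes on the senders/receivers edge-list representation, here is a minimal pure-JAX sketch of the standard GCN propagation rule, D^{-1/2} (A + I) D^{-1/2} X W (self-loops plus symmetric degree normalization). The function name gcn_propagate is hypothetical and this is not Haiku Geometric's GCNConv implementation; the library's layer may differ in details such as bias terms, weight initialization and normalization options.

```python
import jax
import jax.numpy as jnp


def gcn_propagate(x, senders, receivers, weight):
    """Hypothetical helper: one GCN step, D^{-1/2} (A + I) D^{-1/2} X W.

    x:         [num_nodes, in_features] node features
    senders:   [num_edges] source node index of each edge
    receivers: [num_edges] target node index of each edge
    weight:    [in_features, out_features] dense weight matrix
    """
    num_nodes = x.shape[0]
    # Add a self-loop to every node (the "+ I" term).
    loops = jnp.arange(num_nodes)
    senders = jnp.concatenate([senders, loops])
    receivers = jnp.concatenate([receivers, loops])
    # Degree of each node counted over incoming edges, including its self-loop,
    # so every degree is at least 1 and the inverse square root is safe.
    deg = jax.ops.segment_sum(
        jnp.ones_like(receivers, dtype=x.dtype), receivers, num_segments=num_nodes
    )
    # Symmetric normalization coefficient 1 / sqrt(deg_src * deg_dst) per edge.
    norm = jax.lax.rsqrt(deg[senders]) * jax.lax.rsqrt(deg[receivers])
    # Transform features, then sum the normalized messages at each receiver.
    h = x @ weight
    messages = h[senders] * norm[:, None]
    return jax.ops.segment_sum(messages, receivers, num_segments=num_nodes)


# Usage on a hypothetical 3-node path graph 0 - 1 - 2 (both edge directions listed).
senders = jnp.array([0, 1, 1, 2])
receivers = jnp.array([1, 0, 2, 1])
x = jnp.ones((3, 4))
weight = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
out = gcn_propagate(x, senders, receivers, weight)  # shape (3, 16)
```

Aggregating with jax.ops.segment_sum over receiver indices avoids building a dense adjacency matrix, which is why edge lists are the usual input format for message-passing layers.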
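The GNN that we have defined is a Haiku module. To convert it into a function that can be used with JAX, we transform it using hk.transform, as described in the Haiku documentation. Because the forward pass does not use any randomness, we also wrap the result with hk.without_apply_rng so that model.apply does not require an RNG key.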
model = hk.transform(forward)
model = hk.without_apply_rng(model)
rng = jax.random.PRNGKey(42)
params = model.init(rng, nodes=nodes, senders=senders, receivers=receivers)
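The calls above assume that nodes, senders and receivers already exist. A minimal sketch of a hypothetical toy graph in that format follows, assuming the Jraph convention that edge i points from node senders[i] to node receivers[i] and that node features are a dense [num_nodes, num_features] array. The graph and the feature values are placeholders.

```python
import jax.numpy as jnp

# Hypothetical toy graph: a path 0 - 1 - 2 - 3, given as undirected edges.
undirected_edges = jnp.array([[0, 1], [1, 2], [2, 3]])

# Edge i carries messages from senders[i] to receivers[i]; listing every
# undirected edge in both directions lets information flow both ways.
senders = jnp.concatenate([undirected_edges[:, 0], undirected_edges[:, 1]])
receivers = jnp.concatenate([undirected_edges[:, 1], undirected_edges[:, 0]])

# Dense node feature matrix of shape [num_nodes, num_features] (placeholder values).
nodes = jnp.ones((4, 8))
```

Arrays built this way could then be passed as the nodes, senders and receivers arguments of model.init and model.apply, assuming the layers accept integer index arrays for the edges.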
We can now run a forward pass on the model:
output = model.apply(params=params, nodes=nodes, senders=senders, receivers=receivers)
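Since model.apply is a pure function of the parameters and the inputs, it can typically be compiled with jax.jit, for example jax.jit(model.apply). To make the pattern behind hk.transform concrete without depending on Haiku, the sketch below hand-writes a hypothetical init/apply pair for a single linear layer: init builds an explicit parameter dictionary from an example input, and apply takes no RNG key, mirroring hk.without_apply_rng. It illustrates the idea only and is not how Haiku is implemented.

```python
import jax
import jax.numpy as jnp


def init(rng, x, out_features=2):
    """Hypothetical init: create parameters whose shapes depend on an example input."""
    in_features = x.shape[-1]
    w = jax.random.normal(rng, (in_features, out_features)) / jnp.sqrt(in_features)
    return {"w": w, "b": jnp.zeros((out_features,))}


def apply(params, x):
    """Hypothetical apply: a pure function of params and inputs that needs no RNG key."""
    return x @ params["w"] + params["b"]


params = init(jax.random.PRNGKey(42), jnp.ones((4, 8)))

# Because apply has no hidden state, it can be compiled with jax.jit.
fast_apply = jax.jit(apply)
out = fast_apply(params, jnp.ones((4, 8)))  # shape (4, 2)
```

Keeping parameters outside the function is what lets JAX transformations such as jax.jit and jax.grad treat the model as an ordinary function of its inputs.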
The documentation for Haiku Geometric is linked from the project's GitHub repository.
Haiku Geometric comes with a few examples that showcase the usage of the library. The following examples are available:
Example |
---|
Quickstart Example |
Graph Convolutional Networks with the Karate Club dataset |
Graph Attention Networks with the CORA dataset |
TopKPooling and GraphConv with the PROTEINS dataset |
Currently, Haiku Geometric includes the following GNN modules:
The following positional encodings are currently available:
Encoding | Description |
---|---|
LaplacianEncoder | Laplacian positional encoding from the paper "Rethinking Graph Transformers with Spectral Attention". |
MagLaplacianEncoder | Magnetic Laplacian positional encoding from the paper "Transformers Meet Directed Graphs". |
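Laplacian positional encodings are built from eigenvectors (and often eigenvalues) of a graph Laplacian. As a rough illustration of that starting point, the pure-JAX sketch below computes the eigenpairs of the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2} for a small undirected graph and keeps the k smallest non-trivial ones. The function name laplacian_eigvecs is hypothetical, and the library's LaplacianEncoder may select, normalize or post-process the eigenvectors differently. Note that eigenvector signs are arbitrary, which a learned encoder has to account for.

```python
import jax.numpy as jnp


def laplacian_eigvecs(senders, receivers, num_nodes, k):
    """Hypothetical helper: the k eigenpairs of the symmetric normalized Laplacian
    with the smallest eigenvalues, skipping the first (trivial) one.

    Assumes an undirected graph given as a symmetric edge list (both directions).
    """
    # Dense adjacency matrix built from the edge list.
    adj = jnp.zeros((num_nodes, num_nodes)).at[senders, receivers].set(1.0)
    deg = adj.sum(axis=1)
    # D^{-1/2}, with isolated nodes mapped to 0 to avoid division by zero.
    safe_deg = jnp.where(deg > 0, deg, 1.0)
    d_inv_sqrt = jnp.where(deg > 0, 1.0 / jnp.sqrt(safe_deg), 0.0)
    lap = jnp.eye(num_nodes) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order for a symmetric matrix.
    eigvals, eigvecs = jnp.linalg.eigh(lap)
    return eigvals[1:k + 1], eigvecs[:, 1:k + 1]


# Usage on a hypothetical 4-node cycle 0 - 1 - 2 - 3 - 0 (both directions listed).
senders = jnp.array([0, 1, 2, 3, 1, 2, 3, 0])
receivers = jnp.array([1, 2, 3, 0, 0, 1, 2, 3])
vals, vecs = laplacian_eigvecs(senders, receivers, num_nodes=4, k=2)
# vecs has one row per node: a k-dimensional positional feature for each node.
```

For this 4-cycle the normalized Laplacian has eigenvalues 0, 1, 1 and 2, so the two selected eigenvalues are both 1.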
If you encounter any problems, please open an issue on GitHub.
Haiku Geometric can be tested using pytest by running the following command:
python -m pytest test/