Skip to content

Commit

Permalink
add multihot label encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamasb committed Nov 13, 2023
1 parent 5315f7e commit 6ba482b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Config entry for the multi-hot label transform. `num_classes` is
# interpolated from the dataset config so the two always agree.
multihot_label_encoding:
  _target_: proteinworkshop.tasks.multihot_label_encoding.MultiHotLabelEncoding
  num_classes: ${dataset.num_classes}
24 changes: 24 additions & 0 deletions proteinworkshop/tasks/multihot_label_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Union

import torch
from graphein.protein.tensor.data import Protein
from torch_geometric import transforms as T
from torch_geometric.data import Data


class MultiHotLabelEncoding(T.BaseTransform):
    """
    Transform to multihot encode labels for multilabel classification.

    Replaces ``data.graph_y`` with a tensor of shape ``(1, num_classes)``
    (default floating dtype) that is ``1`` at every position selected by the
    original ``data.graph_y`` and ``0`` everywhere else.

    :param num_classes: Number of classes to encode.
    :type num_classes: int
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, data: Union[Protein, Data]) -> Union[Protein, Data]:
        """
        Multihot encode ``data.graph_y``.

        :param data: Sample whose ``graph_y`` attribute is used to index the
            class dimension (presumably the indices of the active classes;
            verify against the dataset that produces it).
        :type data: Union[Protein, Data]
        :return: The same ``data`` object, modified in place, with
            ``graph_y`` replaced by the ``(1, num_classes)`` encoding.
        :rtype: Union[Protein, Data]
        """
        targets = data.graph_y
        # Allocate the encoding on the device the index tensor lives on: a
        # tensor on the default device cannot be indexed by indices stored on
        # a different device. Non-tensor labels (ints, lists) keep the
        # default device, exactly as before.
        device = targets.device if isinstance(targets, torch.Tensor) else None
        labels = torch.zeros((1, self.num_classes), device=device)
        # NOTE: float-valued index tensors are deliberately not cast to
        # integers, so re-applying the transform to already-encoded labels
        # still fails loudly instead of silently producing wrong targets.
        labels[:, targets] = 1
        data.graph_y = labels
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_classes={self.num_classes})"

0 comments on commit 6ba482b

Please sign in to comment.