From d019c4c8815cb782f0079957808bdfb0055e4901 Mon Sep 17 00:00:00 2001 From: Eleven Liu Date: Fri, 26 Apr 2024 11:36:47 +0800 Subject: [PATCH 1/2] Remove embed_vecs from Model and PLTModel. Ignore word_dict and network from save_parameters --- libmultilabel/nn/attentionxml.py | 3 --- libmultilabel/nn/model.py | 5 +---- libmultilabel/nn/nn_utils.py | 1 - 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py index c327cbea1..a5166baa6 100644 --- a/libmultilabel/nn/attentionxml.py +++ b/libmultilabel/nn/attentionxml.py @@ -381,7 +381,6 @@ def fit(self, datasets): model_1 = PLTModel( classes=self.classes, word_dict=self.word_dict, - embed_vecs=self.embed_vecs, network=network, log_path=self.log_path, learning_rate=self.learning_rate, @@ -521,7 +520,6 @@ def __init__( self, classes, word_dict, - embed_vecs, network, loss_function="binary_cross_entropy_with_logits", log_path=None, @@ -530,7 +528,6 @@ def __init__( super().__init__( classes=classes, word_dict=word_dict, - embed_vecs=embed_vecs, network=network, loss_function=loss_function, log_path=log_path, diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 040d35fb8..bfe3344c3 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -182,7 +182,6 @@ class Model(MultiLabelModel): Args: classes (list): List of class names. word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices. - embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN). loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits, cross_entropy). Defaults to 'binary_cross_entropy_with_logits'. @@ -193,7 +192,6 @@ def __init__( self, classes, word_dict, - embed_vecs, network, loss_function="binary_cross_entropy_with_logits", log_path=None, @@ -201,10 +199,9 @@ def __init__( ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( - ignore=["log_path"] + ignore=["log_path", "word_dict", "network"] ) # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir). self.word_dict = word_dict - self.embed_vecs = embed_vecs self.classes = classes self.network = network self.configure_loss_function(loss_function) diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py index b75ecb4f7..a4ac82c22 100644 --- a/libmultilabel/nn/nn_utils.py +++ b/libmultilabel/nn/nn_utils.py @@ -100,7 +100,6 @@ def init_model( model = Model( classes=classes, word_dict=word_dict, - embed_vecs=embed_vecs, network=network, log_path=log_path, learning_rate=learning_rate, From 5e7148b3157aa2f67377d2acba7363f94a092d0d Mon Sep 17 00:00:00 2001 From: Eleven Liu Date: Fri, 26 Apr 2024 12:04:00 +0800 Subject: [PATCH 2/2] Cannot ignore word_dict and network for now. --- libmultilabel/nn/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index bfe3344c3..1f0ab95f8 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -199,7 +199,7 @@ def __init__( ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( - ignore=["log_path", "word_dict", "network"] + ignore=["log_path"] ) # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir). self.word_dict = word_dict self.classes = classes