diff --git a/sparselearning/funcs.py b/sparselearning/funcs.py index 562f828..2000b44 100644 --- a/sparselearning/funcs.py +++ b/sparselearning/funcs.py @@ -159,7 +159,7 @@ def magnitude_and_negativity_prune(masking, mask, weight, name): # remove the most negative weights x, idx = torch.sort(weight.data.view(-1)) - mask.data.view(-1)[idx[:math.ceil(num_remove/2.0)]] = 0.0 + mask.data.view(-1)[idx[math.ceil(num_remove/2.0):]] = 0.0 return mask