diff --git a/attacut/dataloaders.py b/attacut/dataloaders.py index 21c3f07..a84c72a 100644 --- a/attacut/dataloaders.py +++ b/attacut/dataloaders.py @@ -175,7 +175,7 @@ def _process_line(line): def collate_fn(batch): total_samples = len(batch) - seq_lengths = np.array(list(map(lambda x: x[0][1], batch))) + seq_lengths = np.array(list(map(lambda x: x[0][1], batch)), dtype=np.int64) max_length = np.max(seq_lengths) features = np.zeros((total_samples, 2, max_length), dtype=np.int64) diff --git a/scripts/attacut-cli b/scripts/attacut-cli index f4a5a34..a75e171 100755 --- a/scripts/attacut-cli +++ b/scripts/attacut-cli @@ -165,8 +165,10 @@ if __name__ == "__main__": pred = preds[after_sorting_ix, :] token = tokens[ori_ix] - - words = preprocessing.find_words_from_preds(token, pred) + if len(token) != 0: + words = preprocessing.find_words_from_preds(token, pred) + else: + words = "" fout.write("%s\n" % SEP.join(words)) tq.update(n=preds.shape[0]) \ No newline at end of file