Skip to content

Commit

Permalink
Refactor using the contextual lemmatizer so that the training script can evaluate that lemmatizer as well
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 22, 2024
1 parent 6529df0 commit 20c76c0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
24 changes: 24 additions & 0 deletions stanza/models/lemma/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn.init as init

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import TEXT, UPOS
from stanza.models.common.foundation_cache import load_charlm
from stanza.models.common.seq2seq_model import Seq2SeqModel
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
Expand Down Expand Up @@ -171,6 +172,29 @@ def predict_contextual(self, sentence_words, sentence_tags, preds):
preds[sent_id][word_id] = pred
return preds

def update_contextual_preds(self, doc, preds):
    """
    Apply the contextual lemmatizers to a flat list of predictions.

    The flat preds are first regrouped by sentence (using the sentence
    lengths from the doc), then passed through predict_contextual, and
    finally flattened back into the single list the caller expects.
    If there are no contextual lemmatizers, preds is returned untouched.
    """
    if not self.contextual_lemmatizers:
        return preds

    sentence_words = doc.get([TEXT], as_sentences=True)
    sentence_tags = doc.get([UPOS], as_sentences=True)

    # regroup the flat predictions so they line up with sentence boundaries
    sentence_preds = []
    offset = 0
    for words in sentence_words:
        sentence_preds.append(preds[offset:offset + len(words)])
        offset += len(words)

    updated = self.predict_contextual(sentence_words, sentence_tags, sentence_preds)
    # flatten back into one list of lemmas
    return [lemma for sentence in updated for lemma in sentence]

def update_lr(self, new_lr):
    """Delegate to utils.change_lr to update this trainer's optimizer with the new learning rate."""
    utils.change_lr(self.optimizer, new_lr)

Expand Down
3 changes: 3 additions & 0 deletions stanza/models/lemmatizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ def evaluate(args):
logger.info("[Ensembling dict with seq2seq lemmatizer...]")
preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)

if trainer.has_contextual_lemmatizers():
preds = trainer.update_contextual_preds(batch.doc, preds)

# write to file and score
batch.doc.set([LEMMA], preds)
CoNLL.write_doc2conll(batch.doc, system_pred_file)
Expand Down
11 changes: 1 addition & 10 deletions stanza/pipeline/lemma_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,7 @@ def process(self, document):
preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)

if self.trainer.has_contextual_lemmatizers():
sentence_words = batch.doc.get([doc.TEXT], as_sentences=True)
sentence_tags = batch.doc.get([doc.UPOS], as_sentences=True)
sentence_preds = []
start_index = 0
for sent in sentence_words:
end_index = start_index + len(sent)
sentence_preds.append(preds[start_index:end_index])
start_index += len(sent)
preds = self.trainer.predict_contextual(sentence_words, sentence_tags, sentence_preds)
preds = [lemma for sentence in preds for lemma in sentence]
preds = self.trainer.update_contextual_preds(batch.doc, preds)

# map empty string lemmas to '_'
preds = [max([(len(x), x), (0, '_')])[1] for x in preds]
Expand Down
4 changes: 3 additions & 1 deletion stanza/utils/training/run_lemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def run_treebank(mode, paths, treebank, short_name,
'--output', save_name,
'--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
attach_lemma_classifier.main(attach_args)
# TODO: rerun dev set / test set with the attached classifier?

# now we rerun the dev set - the HI in particular demonstrates some good improvement
lemmatizer.main(dev_args)

def main():
    """Entry point: drive the lemmatizer over a treebank via the shared common.main harness."""
    common.main(
        run_treebank,
        "lemma",
        "lemmatizer",
        add_lemma_args,
        sub_argparse=lemmatizer.build_argparse(),
        build_model_filename=build_model_filename,
        choose_charlm_method=choose_lemma_charlm,
    )
Expand Down

0 comments on commit 20c76c0

Please sign in to comment.