Skip to content

Commit

Permalink
Simplify lemma data preparation & usage - no need for both an 'in' and a 'gold' file, at least not at present
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 30, 2024
1 parent 59117ba commit 8d1e040
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 18 deletions.
5 changes: 5 additions & 0 deletions stanza/models/lemma/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
Utils and wrappers for scoring lemmatizers.
"""

import logging

from stanza.models.common.utils import ud_scores

logger = logging.getLogger('stanza')

def score(system_conllu_file, gold_conllu_file):
""" Wrapper for lemma scorer. """
logger.debug("Evaluating system file %s vs gold file %s", system_conllu_file, gold_conllu_file)
evaluation = ud_scores(gold_conllu_file, system_conllu_file)
el = evaluation["Lemmas"]
p, r, f = el.precision, el.recall, el.f1
Expand Down
13 changes: 5 additions & 8 deletions stanza/models/lemmatizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
def build_argparse():
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
parser.add_argument('--train_file', type=str, default=None, help='Training input file for data loader.')
parser.add_argument('--eval_file', type=str, default=None, help='Evaluation input file for data loader.')
parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.')

parser.add_argument('--mode', default='train', choices=['train', 'predict'])
parser.add_argument('--shorthand', type=str, help='Shorthand for the dataset to use. lang_dataset')
Expand Down Expand Up @@ -147,7 +146,7 @@ def train(args):

# pred and gold path
system_pred_file = args['output_file']
gold_file = args['gold_file']
gold_file = args['eval_file']

utils.print_config(args)

Expand Down Expand Up @@ -258,7 +257,6 @@ def train(args):
def evaluate(args):
# file paths
system_pred_file = args['output_file']
gold_file = args['gold_file']
model_file = build_model_filename(args)

# load model
Expand Down Expand Up @@ -304,10 +302,9 @@ def evaluate(args):
# write to file and score
batch.doc.set([LEMMA], preds)
CoNLL.write_doc2conll(batch.doc, system_pred_file)
if gold_file is not None:
_, _, score = scorer.score(system_pred_file, gold_file)

logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))
_, _, score = scorer.score(system_pred_file, args['eval_file'])
logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))

if __name__ == '__main__':
main()
3 changes: 0 additions & 3 deletions stanza/utils/datasets/prepare_tokenizer_treebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def copy_conllu_treebank(treebank, model_type, paths, dest_dir, postprocess=None
postprocess(tokenizer_dir, "train.gold", dest_dir, "train.in", short_name)
postprocess(tokenizer_dir, "dev.gold", dest_dir, "dev.in", short_name)
postprocess(tokenizer_dir, "test.gold", dest_dir, "test.in", short_name)
if model_type is not common.ModelType.POS and model_type is not common.ModelType.DEPPARSE:
copy_conllu_file(dest_dir, "dev.in", dest_dir, "dev.gold", short_name)
copy_conllu_file(dest_dir, "test.in", dest_dir, "test.gold", short_name)

def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):
# set the seed for each data file so that the results are the same
Expand Down
7 changes: 0 additions & 7 deletions stanza/utils/training/run_lemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ def run_treebank(mode, paths, treebank, short_name,
lemma_dir = paths["LEMMA_DATA_DIR"]
train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
dev_gold_file = f"{lemma_dir}/{short_name}.dev.gold.conllu"
dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
test_gold_file = f"{lemma_dir}/{short_name}.test.gold.conllu"
test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"

charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
Expand All @@ -99,15 +97,13 @@ def run_treebank(mode, paths, treebank, short_name,
train_args = ["--train_file", train_file,
"--eval_file", dev_in_file,
"--output_file", dev_pred_file,
"--gold_file", dev_gold_file,
"--shorthand", short_name]
logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
identity_lemmatizer.main(train_args)
elif mode == Mode.SCORE_TEST:
train_args = ["--train_file", train_file,
"--eval_file", test_in_file,
"--output_file", test_pred_file,
"--gold_file", test_gold_file,
"--shorthand", short_name]
logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
identity_lemmatizer.main(train_args)
Expand All @@ -122,7 +118,6 @@ def run_treebank(mode, paths, treebank, short_name,
train_args = ["--train_file", train_file,
"--eval_file", dev_in_file,
"--output_file", dev_pred_file,
"--gold_file", dev_gold_file,
"--shorthand", short_name,
"--num_epoch", num_epochs,
"--mode", "train"]
Expand All @@ -133,7 +128,6 @@ def run_treebank(mode, paths, treebank, short_name,
if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
dev_args = ["--eval_file", dev_in_file,
"--output_file", dev_pred_file,
"--gold_file", dev_gold_file,
"--shorthand", short_name,
"--mode", "predict"]
dev_args = dev_args + charlm_args + extra_args
Expand All @@ -143,7 +137,6 @@ def run_treebank(mode, paths, treebank, short_name,
if mode == Mode.SCORE_TEST:
test_args = ["--eval_file", test_in_file,
"--output_file", test_pred_file,
"--gold_file", test_gold_file,
"--shorthand", short_name,
"--mode", "predict"]
test_args = test_args + charlm_args + extra_args
Expand Down

0 comments on commit 8d1e040

Please sign in to comment.