"""
Infer trained model.
"""
import argparse
import os
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from s2s_hf_transformers import T5_PREFIX, T5_MODELS
from s2s_pt_transformer import Seq2SeqTransformer, \
NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMBED_SIZE, NHEAD, DIM_FFN_HID
from s2s_pt_transformer import translate as pt_transformer_translate
from s2s_pt_transformer import construct_dataset_meta as pt_transformer_construct_dataset_meta
from dataset import load_split_dataset
from evaluation import evaluate_lang_from_file
from utils import count_params


class Seq2Seq:
    def __init__(self, model_type, **kwargs):
        self.model_type = model_type
        if model_type in T5_MODELS:  # https://huggingface.co/docs/transformers/model_doc/t5
            # Load a fine-tuned T5 checkpoint saved under model/<model_type>
            model_dir = f"model/{model_type}"
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        elif model_type == "pt_transformer":
            self.model = Seq2SeqTransformer(kwargs["src_vocab_sz"], kwargs["tar_vocab_sz"],
                                            NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMBED_SIZE, NHEAD,
                                            DIM_FFN_HID)
            self.model_translate = pt_transformer_translate
            self.vocab_transform = kwargs["vocab_transform"]
            self.text_transform = kwargs["text_transform"]
            self.model.load_state_dict(torch.load(kwargs["fpath_load"]))
        else:
            raise ValueError(f'ERROR: unrecognized model: {model_type}')

    def translate(self, queries):
        if self.model_type in T5_MODELS:
            inputs = [f"{T5_PREFIX}{query}" for query in queries]  # prepend the T5 task prefix
            inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
            output_tokens = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                do_sample=False,  # greedy decoding
            )
            ltls = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        elif self.model_type == "pt_transformer":
            # pt_transformer_translate handles one utterance at a time
            ltls = [self.model_translate(self.model, self.vocab_transform, self.text_transform, queries[0])]
        else:
            raise ValueError(f'ERROR: unrecognized model: {self.model_type}')
        return ltls

    def parameters(self):
        return self.model.parameters()
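

# Minimal usage sketch (hypothetical query; assumes a fine-tuned checkpoint under model/t5-base):
#   s2s = Seq2Seq("t5-base")
#   ltls = s2s.translate(["go to the red room then the green room"])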


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--split_dataset_fpath', type=str, default='data/split_symbolic_no_perm_batch1_utt_0.2_42.pkl',
                        help='complete file path, or prefix of file paths, to the train/test split dataset')
    parser.add_argument('--model', type=str, default="t5-base", choices=["t5-base", "t5-small", "pt_transformer"],
                        help='name of the supervised seq2seq model')
    args = parser.parse_args()

    if "pkl" in args.split_dataset_fpath:  # complete file path, e.g. data/split_symbolic_no_perm_batch1_utt_0.2_42.pkl
        split_dataset_fpaths = [args.split_dataset_fpath]
    else:  # prefix of file paths, e.g. split_symbolic_no_perm_batch1_utt
        split_dataset_fpaths = [os.path.join("data", fpath) for fpath in os.listdir("data") if args.split_dataset_fpath in fpath]
    for split_dataset_fpath in split_dataset_fpaths:
        # Load train, test data
        train_iter, train_meta, valid_iter, valid_meta = load_split_dataset(split_dataset_fpath)

        # Load trained model
        if args.model in T5_MODELS:  # fine-tuned T5 from Hugging Face
            s2s = Seq2Seq(args.model)
        elif args.model == "pt_transformer":  # trained seq2seq transformer implemented in PyTorch
            vocab_transform, text_transform, src_vocab_size, tar_vocab_size = pt_transformer_construct_dataset_meta(train_iter)
            model_params = f"model/s2s_{args.model}_{Path(split_dataset_fpath).stem}.pth"
            s2s = Seq2Seq(args.model,
                          vocab_transform=vocab_transform, text_transform=text_transform,
                          src_vocab_sz=src_vocab_size, tar_vocab_sz=tar_vocab_size, fpath_load=model_params)
        else:
            raise ValueError(f'ERROR: unrecognized model: {args.model}')
        print(f"Number of trainable parameters in {args.model}: {count_params(s2s)}")
        print(f"Number of training samples: {len(train_iter)}")
        print(f"Number of validation samples: {len(valid_iter)}")

        # Evaluation
        result_log_fpath = f"results/s2s_{args.model}_{Path(split_dataset_fpath).stem}_log.csv"
        analysis_fpath = "data/analysis_batch1.csv"
        acc_fpath = f"results/s2s_{args.model}_{Path(split_dataset_fpath).stem}_acc.csv"
        evaluate_lang_from_file(s2s, split_dataset_fpath, analysis_fpath, result_log_fpath, acc_fpath)