-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_transformer.py
26 lines (22 loc) · 1.27 KB
/
train_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"Train Transformer"
import argparse
import os
import subprocess
def build_train_command(config):
    """Build the ``fairseq-train`` argument vector for the Transformer recipe.

    Args:
        config: Object (e.g. an ``argparse.Namespace``) exposing ``lr``,
            ``dropout``, ``max_epoch``, ``data_path`` and ``ckpt``.

    Returns:
        list[str]: The full command, one element per argument, ready to be
        executed without a shell.
    """
    return [
        "fairseq-train", str(config.data_path),
        "--arch", "transformer",
        "--share-all-embeddings",
        "--optimizer", "adam",
        "--lr", str(config.lr),
        "--label-smoothing", "0.1",
        # The same probability drives both regular and attention dropout.
        "--dropout", str(config.dropout),
        "--attention-dropout", str(config.dropout),
        "--max-tokens", "4000",
        "--min-lr", "1e-09",
        "--lr-scheduler", "inverse_sqrt",
        "--weight-decay", "0.0001",
        "--criterion", "label_smoothed_cross_entropy",
        "--max-epoch", str(config.max_epoch),
        "--warmup-updates", "4000",
        "--warmup-init-lr", "1e-07",
        # Passed as one argv element; no shell is involved, so no quoting.
        "--adam-betas", "(0.9, 0.98)",
        "--save-interval-updates", "5000",
        "--clip-norm", "0.1",
        "--share-decoder-input-output-embed",
        "--layernorm-embedding",
        "--save-dir", str(config.ckpt),
    ]


def train(config):
    """Train a Transformer by launching ``fairseq-train``.

    The command is executed directly rather than through a shell, so
    ``data_path``/``ckpt`` values containing spaces or shell metacharacters
    are passed through verbatim instead of being split or interpreted.

    Args:
        config: Same attributes as required by ``build_train_command``.

    Returns:
        int: Exit status of ``fairseq-train`` (0 on success; negative on
        POSIX if the process was killed by a signal).

    Raises:
        FileNotFoundError: If the ``fairseq-train`` executable is not on PATH.
    """
    command = build_train_command(config)
    completed = subprocess.run(command, check=False)
    return completed.returncode
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and launch training.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--lr', type=float, default=0.001,
                        help="learning rate (passed to fairseq-train as --lr)")
    parser.add_argument('--dropout', type=float, default=0.1,
                        help="probability used for both --dropout and --attention-dropout")
    parser.add_argument('--max_epoch', type=int, default=10,
                        help="maximum number of epochs (--max-epoch)")
    parser.add_argument('--data_path', type=str, default='data-bin',
                        help="data directory passed to fairseq-train")
    parser.add_argument('--ckpt', type=str, default='ckpt_transformer',
                        help="checkpoint directory (--save-dir)")
    args = parser.parse_args()
    # Propagate the trainer's exit status so shells/CI see failures;
    # SystemExit(None) exits with status 0.
    raise SystemExit(train(args))