#!/usr/bin/env python3
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.data import ParserLexicon, collate_fn
from model import Encoder, Decoder
from utils.config import DataConfig, ModelConfig, TrainConfig
# data prep
ds = ParserLexicon(
    DataConfig.graphemes_path,
    DataConfig.phonemes_path,
    DataConfig.lexicon_path
)
dl = DataLoader(
    ds,
    collate_fn=collate_fn,
    batch_size=TrainConfig.batch_size
)
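# Each batch from `dl` is a pair of index tensors (graphemes, phonemes).
# Judging by the training loop below, they are time-major, shaped
# (sequence_length, batch_size); collate_fn presumably pads the sequences
# in a batch to a common length.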
# models
encoder_model = Encoder(
    ModelConfig.graphemes_size,
    ModelConfig.hidden_size
).to(TrainConfig.device)
decoder_model = Decoder(
    ModelConfig.phonemes_size,
    ModelConfig.hidden_size
).to(TrainConfig.device)
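# Note: the decoder is called one time step at a time below and returns a
# third value (ignored there), presumably attention weights over the
# encoder outputs.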
# log
log = SummaryWriter(TrainConfig.log_path)
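# The per-batch loss logged below can be inspected with
# `tensorboard --logdir <TrainConfig.log_path>` while training runs.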
# loss
criterion = nn.CrossEntropyLoss()
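# CrossEntropyLoss expects logits of shape (num_samples, num_classes) and
# integer class targets of shape (num_samples,); the training loop below
# flattens the time and batch dimensions to match this.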
# optimizer
optimizer = torch.optim.Adam(
    list(encoder_model.parameters()) +
    list(decoder_model.parameters()),
    lr=TrainConfig.lr
)
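# A single optimizer updates both models: concatenating the two parameter
# lists lets one Adam instance step the encoder and decoder together.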
# training loop
counter = 0
for e in range(TrainConfig.epochs):
    print('-' * 20 + f'epoch: {e+1:02d}' + '-' * 20)
    for g, p in tqdm(dl):
        g = g.to(TrainConfig.device)
        p = p.to(TrainConfig.device)
        # encode the grapheme sequence
        enc = encoder_model(g)
        # decode one time step at a time with teacher forcing:
        # the ground-truth phoneme at step t is the decoder input for step t+1
        T, N = p.size()
        outputs = []
        # initial decoder hidden state
        hidden = torch.ones(
            1,
            N,
            ModelConfig.hidden_size
        ).to(TrainConfig.device)
        for t in range(T - 1):
            out, hidden, _ = decoder_model(
                p[t:t+1],
                enc,
                hidden
            )
            outputs.append(out)
        outputs = torch.cat(outputs)
        # flatten time and batch dimensions to compute the loss
        outputs = outputs.view((T-1) * N, -1)
        p = p[1:]  # drop the first phoneme, which served only as the initial decoder input
        p = p.view(-1)
        loss = criterion(outputs, p)
        # update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        log.add_scalar('loss', loss.item(), counter)
        counter += 1
    # save per-epoch checkpoints; create the directory first so torch.save
    # does not fail on a fresh clone
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(
        encoder_model.state_dict(),
        f'checkpoints/encoder_e{e+1:02d}.pth'
    )
    torch.save(
        decoder_model.state_dict(),
        f'checkpoints/decoder_e{e+1:02d}.pth'
    )
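# A minimal sketch of restoring a saved checkpoint pair for inference
# (hypothetical epoch number; constructor arguments match the ones above):
#
#   encoder_model.load_state_dict(
#       torch.load('checkpoints/encoder_e01.pth', map_location=TrainConfig.device)
#   )
#   decoder_model.load_state_dict(
#       torch.load('checkpoints/decoder_e01.pth', map_location=TrainConfig.device)
#   )
#   encoder_model.eval()
#   decoder_model.eval()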