inference.py
#!/usr/bin/env python3
"""Grapheme-to-phoneme (G2P) inference over a whole sentence."""
import os
import argparse

import torch
import matplotlib.pyplot as plt

from utils.data import ParserLexicon
from model import Encoder, Decoder
from utils.config import DataConfig, ModelConfig, TestConfig
from utils.text_tools import tokenize_pt


def load_model(model_path, model):
    """Load weights onto the configured device and switch the model to eval mode."""
    model.load_state_dict(torch.load(
        model_path,
        map_location=lambda storage, loc: storage
    ))
    model.to(TestConfig.device)
    model.eval()
    return model


class G2P(object):
    def __init__(self):
        # data
        self.ds = ParserLexicon(
            DataConfig.graphemes_path,
            DataConfig.phonemes_path,
            DataConfig.lexicon_path
        )
        # model
        self.encoder_model = Encoder(
            ModelConfig.graphemes_size,
            ModelConfig.hidden_size
        )
        load_model(TestConfig.encoder_model_path, self.encoder_model)

        self.decoder_model = Decoder(
            ModelConfig.phonemes_size,
            ModelConfig.hidden_size
        )
        load_model(TestConfig.decoder_model_path, self.decoder_model)

    def __call__(self, word, visualize):
        # Wrap the word in <sos> (index 0) and <eos> (index 1) grapheme ids.
        x = [0] + [self.ds.g2idx[ch] for ch in word] + [1]
        x = torch.tensor(x).long().unsqueeze(1).to(TestConfig.device)
        with torch.no_grad():
            enc = self.encoder_model(x)

        # Greedy decoding: feed each argmax back in until <eos> (index 1).
        phonemes, att_weights = [], []
        x = torch.zeros(1, 1).long().to(TestConfig.device)
        hidden = torch.ones(
            1,
            1,
            ModelConfig.hidden_size
        ).to(TestConfig.device)
        while True:
            with torch.no_grad():
                out, hidden, att_weight = self.decoder_model(
                    x,
                    enc,
                    hidden
                )
            att_weights.append(att_weight.detach().cpu())
            max_index = out[0, 0].argmax()
            x = max_index.unsqueeze(0).unsqueeze(0)
            phonemes.append(self.ds.idx2p[max_index.item()])
            if max_index.item() == 1:
                break

        if visualize:
            # Plot the attention matrix: graphemes on the y-axis, phonemes on the x-axis.
            att_weights = torch.cat(att_weights).squeeze(1).numpy().T
            y, x = att_weights.shape
            plt.imshow(att_weights, cmap='gray')
            plt.yticks(range(y), ['<sos>'] + list(word) + ['<eos>'])
            plt.xticks(range(x), phonemes)
            out_dir = f'attention/{DataConfig.language}'
            os.makedirs(out_dir, exist_ok=True)
            plt.savefig(f'{out_dir}/{word}.png')
        return phonemes
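
# Direct-usage sketch (the word and phonemes below are illustrative, not taken
# from the real lexicon; the trailing symbol is presumably the <eos> marker at
# index 1, which inference() strips with [:-1]):
#
#     g2p = G2P()
#     g2p('teste', visualize=False)   # e.g. ['t', 'E', 's', 't', 'i', '<eos>']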


def is_punctuation(token):
    return token in ['.', '?', '!', ',', ':', ';']


def inference(sentence, char_separator='|', visualize=False):
    tokens = tokenize_pt(sentence)
    g2p = G2P()
    phone_phrase = ""
    for item in tokens:
        if is_punctuation(item):
            phone_phrase += char_separator + item + char_separator + " "
        else:
            # Drop the trailing <eos> symbol before joining the phonemes.
            result = g2p(item, visualize)[:-1]
            phoneme = char_separator + char_separator.join(result) + char_separator
            phone_phrase += phoneme + " "
    # Strip the outermost separators so the phrase starts and ends with a symbol.
    return phone_phrase.strip()[1:-1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--sentence', type=str, default='olá, vamos testar esse projeto.')
    parser.add_argument('--visualize', action='store_true')
    args = parser.parse_args()
    print(inference(args.sentence, char_separator='|', visualize=args.visualize))
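
# Example CLI invocation (a sketch; the exact phonemes depend on the trained
# model and lexicon, so the format note below is inferred from the code rather
# than from a real run):
#
#     $ python inference.py --sentence 'olá, vamos testar esse projeto.' --visualize
#
# Each word is emitted as its phonemes joined by the separator, punctuation is
# wrapped in separators, and --visualize saves one attention heatmap per word
# under attention/<language>/<word>.png.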