# conv_tasnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from .utility import models, sdr
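
# Note: the relative import above means this module is intended to be used from inside
# its package; running the file directly as a standalone script will fail on this import
# before reaching the test at the bottom.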

# Conv-TasNet
class TasNet(nn.Module):
    def __init__(self, enc_dim=512, feature_dim=128, sr=16000, win=2, layer=8, stack=3,
                 kernel=3, num_spk=2, causal=False):
        super(TasNet, self).__init__()

        # hyper parameters
        self.num_spk = num_spk
        self.enc_dim = enc_dim
        self.feature_dim = feature_dim
        self.win = int(sr * win / 1000)  # encoder window length in samples (win is given in ms)
        self.stride = self.win // 2      # 50% overlap between encoder frames
        self.layer = layer
        self.stack = stack
        self.kernel = kernel
        self.causal = causal

        # input encoder
        self.encoder = nn.Conv1d(1, self.enc_dim, self.win, bias=False, stride=self.stride)

        # TCN separator
        self.TCN = models.TCN(self.enc_dim, self.enc_dim * self.num_spk, self.feature_dim, self.feature_dim * 4,
                              self.layer, self.stack, self.kernel, causal=self.causal)
        self.receptive_field = self.TCN.receptive_field

        # output decoder
        self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.win, bias=False, stride=self.stride)
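
    # With the defaults (sr=16000, win=2) the encoder window is int(16000 * 2 / 1000) = 32
    # samples (2 ms) with a hop of 16 samples, and the TCN predicts num_spk = 2 masks of
    # size enc_dim = 512 for every encoder frame.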
    def pad_signal(self, input):
        # input is the waveforms: (B, T) or (B, 1, T)
        # reshape and padding
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")
        if input.dim() == 2:
            input = input.unsqueeze(1)

        batch_size = input.size(0)
        nsample = input.size(2)

        # pad the tail so that stride + nsample + rest is a whole number of windows
        rest = self.win - (self.stride + nsample % self.win) % self.win
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, 1, rest)).type(input.type())
            input = torch.cat([input, pad], 2)

        # add half a window of zeros at both ends so every sample is covered by a full window
        pad_aux = Variable(torch.zeros(batch_size, 1, self.stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest
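
    # Worked example of pad_signal with the defaults (win=32, stride=16): a 32000-sample
    # input gives rest = 32 - (16 + 32000 % 32) % 32 = 16, so the padded signal is
    # 16 + 32000 + 16 + 16 = 32048 samples, which the encoder splits into
    # (32048 - 32) / 16 + 1 = 2002 overlapping frames.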
    def forward(self, input):
        # padding
        output, rest = self.pad_signal(input)
        batch_size = output.size(0)

        # waveform encoder
        enc_output = self.encoder(output)  # B, N, L

        # generate masks
        masks = torch.sigmoid(self.TCN(enc_output)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

        # waveform decoder
        output = self.decoder(masked_output.view(batch_size * self.num_spk, self.enc_dim, -1))  # B*C, 1, L
        # drop the auxiliary front padding and the (rest + stride) tail padding,
        # which trims each estimate back to the original input length
        output = output[:, :, self.stride:-(rest + self.stride)].contiguous()  # B*C, 1, L
        output = output.view(batch_size, self.num_spk, -1)  # B, C, T

        return output

def test_conv_tasnet():
    x = torch.rand(2, 32000)
    nnet = TasNet()
    x = nnet(x)
    s1 = x[0]
    print(s1.shape)  # expected: torch.Size([2, 32000]), i.e. (num_spk, T) for the first batch item


if __name__ == "__main__":
    test_conv_tasnet()
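
# A minimal usage sketch (assumes the package's `utility` module is importable so the
# relative import above resolves; shapes follow the default hyperparameters):
#
#   mixture = torch.rand(4, 64000)           # (batch, samples): 4 s of 16 kHz audio
#   separator = TasNet(num_spk=2)
#   estimates = separator(mixture)           # (batch, num_spk, samples)
#   assert estimates.shape == (4, 2, 64000)  # output is trimmed back to the input length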