updated wavesplit implementation. Tested only distance loss
popcornell committed Jul 15, 2020
1 parent 8d1ca2e commit d8e7ef7
Showing 8 changed files with 154 additions and 184 deletions.
24 changes: 9 additions & 15 deletions egs/wham/WaveSplit/README.md
@@ -1,17 +1,11 @@
### WaveSplit

things currently not clear:
---
- not clear if different encoders are used for the separation and speaker stacks (the figure in the paper suggests so)
- what is the embedding dimension? It seems to be 512, but this is not explicit in the paper
- which mask is used (sigmoid?)
- when an example has fewer speakers than the sep stack has outputs, is the loss simply masked or is an embedding for silence used? (probably masked)
- is VAD used in WSJ0-2mix / WHAM to determine speech activity at frame level? Some files can have pauses of up to one second
- the loss is currently prone to go NaN, especially if we don't take the mean after computing the l2 distances

---
structure:
- train.py contains the training loop (nets instantiation: lines 48-60, training loop: lines 100-116)
- losses.py contains the WaveSplit losses
- wavesplit.py contains the sep and speaker stack nets
- wavesplitwham.py contains the dataset parsing
we train on 1-second segments for now.

tried with an embedding dimension of 256.

still does not work with oracle embeddings.

not clear how the loss at every layer of the sep stack is computed (is the same output layer used for all of them?).
Also, the paper does not mention the output layer, nor that the first conv has no skip connection.
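A likely culprit for the NaN issue noted above: the gradient of sqrt is infinite at zero, so an exactly-zero l2 distance produces NaN gradients (consistent with this commit dropping the `.sqrt()` calls in `_l_dist_speaker` below). A minimal sketch:

```python
import torch

# d/dx sqrt(x) = 1 / (2 sqrt(x)) blows up at x = 0, and autograd turns
# inf * 0 into NaN, so one zero distance poisons the whole backward pass.
x = torch.zeros(3, requires_grad=True)
dist = ((x - x.detach()) ** 2).sum().sqrt()  # l2 distance, exactly 0 here
dist.backward()
print(x.grad)  # tensor([nan, nan, nan])
```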

4 changes: 2 additions & 2 deletions egs/wham/WaveSplit/local/conf.yml
@@ -23,7 +23,7 @@ training:
num_workers: 4
half_lr: yes
early_stop: yes
gradient_clipping: 5
gradient_clipping: 5000
# Optim config
optim:
optimizer: adam
@@ -38,4 +38,4 @@ data:
nondefault_nsrc:
sample_rate: 8000
mode: min
segment: 0.750
segment: 1
10 changes: 5 additions & 5 deletions egs/wham/WaveSplit/local/preprocess_wham.py
@@ -14,8 +14,8 @@ def preprocess_task(task, in_dir, out_dir):
examples = []
for mix in mix_both:
filename = mix.split("/")[-1]
spk1_id = filename.split("_")[0][:3]
spk2_id = filename.split("_")[2][:3]
spk1_id = filename.split("_")[0]
spk2_id = filename.split("_")[2]
length = len(sf.SoundFile(mix))

noise = os.path.join(in_dir, "noise", filename)
@@ -33,8 +33,8 @@ def preprocess_task(task, in_dir, out_dir):
examples = []
for mix in mix_clean:
filename = mix.split("/")[-1]
spk1_id = filename.split("_")[0][:3]
spk2_id = filename.split("_")[2][:3]
spk1_id = filename.split("_")[0]
spk2_id = filename.split("_")[2]
length = len(sf.SoundFile(mix))

s1 = os.path.join(in_dir, "s1", filename)
@@ -51,7 +51,7 @@ def preprocess_task(task, in_dir, out_dir):
examples = []
for mix in mix_single:
filename = mix.split("/")[-1]
spk1_id = filename.split("_")[0][:3]
spk1_id = filename.split("_")[0]
length = len(sf.SoundFile(mix))

s1 = os.path.join(in_dir, "s1", filename)
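For reference, here is what the changed parsing does on a typical wsj0-2mix/WHAM mixture filename (the name below is illustrative, not taken from this dataset):

```python
# WHAM mixture names follow <utt1>_<snr1>_<utt2>_<snr2>.wav, and in WSJ0
# the first 3 characters of an utterance ID are the speaker ID.
filename = "011a0101_1.5178_022c0205_-1.5178.wav"

spk1_id = filename.split("_")[0]        # "011a0101": full utterance ID (new code)
spk1_old = filename.split("_")[0][:3]   # "011": 3-char WSJ0 speaker ID (old code)
```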
74 changes: 30 additions & 44 deletions egs/wham/WaveSplit/losses.py
@@ -3,7 +3,8 @@
import numpy as np
from torch.nn import functional as F
from itertools import permutations
from asteroid.losses.sdr import MultiSrcNegSDR
from asteroid.losses.sdr import MultiSrcNegSDR, SingleSrcNegSDR
from asteroid.losses import PITLossWrapper, PairwiseNegSDR, pairwise_neg_sisdr
import math


@@ -12,7 +13,7 @@ class ClippedSDR(nn.Module):
def __init__(self, clip_value=-30):
super(ClippedSDR, self).__init__()

self.snr = MultiSrcNegSDR("snr")
self.snr = PITLossWrapper(pairwise_neg_sisdr)
self.clip_value = float(clip_value)

def forward(self, est_targets, targets):
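The body of `forward` is collapsed in this view; assuming it is unchanged, it presumably clamps the PIT loss from below, along these lines (a sketch, not the verbatim file):

```python
def forward(self, est_targets, targets):
    # clamp the per-example negative SI-SDR from below at clip_value
    # (-30 dB), so no single easy example can dominate the batch loss
    return torch.clamp(self.snr(est_targets, targets), min=self.clip_value)
```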
@@ -23,12 +24,9 @@ def forward(self, est_targets, targets):
class SpeakerVectorLoss(nn.Module):

def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global",
weight=10, distance_reg=0.3, gaussian_reg=0.2, return_oracle=True):
weight=2, distance_reg=0.3, gaussian_reg=0.2, return_oracle=False):
super(SpeakerVectorLoss, self).__init__()


# not clear how embeddings are initialized.

self.learnable_emb = learnable_emb
self.loss_type = loss_type
self.weight = float(weight)
@@ -38,36 +36,30 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob

assert loss_type in ["distance", "global", "local"]

# I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization

spk_emb = torch.rand((n_speakers, embed_dim))
norms = torch.sum(spk_emb ** 2, -1, keepdim=True).sqrt()
spk_emb = spk_emb / norms # generate points on n-dimensional unit sphere
spk_emb = torch.eye(max(n_speakers, embed_dim)) # one-hot init works better according to Neil
spk_emb = spk_emb[:n_speakers, :embed_dim]
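# (editor's note: this only yields distinct one-hot rows when embed_dim >= n_speakers;
# with n_speakers > embed_dim the remaining rows are all-zero, and the unit-sphere
# re-projection below would turn them into NaNs)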

if learnable_emb:
self.spk_embeddings = nn.Parameter(spk_emb)
else:
self.register_buffer("spk_embeddings", spk_emb)

if loss_type != "dist":
self.alpha = nn.Parameter(torch.Tensor([1.])) # not clear how these are initialized...
if loss_type != "distance":
self.alpha = nn.Parameter(torch.Tensor([1.]))
self.beta = nn.Parameter(torch.Tensor([0.]))


### losses go to NaN if I follow the formulas strictly; maybe I am missing something...

@staticmethod
def _l_dist_speaker(c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):

utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
c_spk = c_spk_vec_perm[:, 0]
pair_dist = ((c_spk.unsqueeze(1) - c_spk_vec_perm)**2).sum(2)
pair_dist = pair_dist[:, 1:].sqrt()
distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt()
return (distance + F.relu(1. - pair_dist).sum(1).unsqueeze(1)).sum(1)
pair_dist = pair_dist[:, 1:]
distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(dim=(1,2))
return distance + F.relu(1. - pair_dist).sum(dim=1)

def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):

raise NotImplementedError
utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
alpha = torch.clamp(self.alpha, 1e-8)

@@ -79,42 +71,37 @@ def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask)
return out.sum(1)

def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):

raise NotImplementedError
utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
alpha = torch.clamp(self.alpha, 1e-8)

distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt() + self.beta
distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2) + self.beta

B, src, embed_dim, frames = c_spk_vec_perm.size()
spk_embeddings = spk_embeddings.reshape(1, spk_embeddings.shape[0], embed_dim, 1).expand(B, -1, -1, frames)
distances = alpha * ((c_spk_vec_perm.unsqueeze(1) - spk_embeddings.unsqueeze(2)) ** 2).sum(3).sqrt() + self.beta
# exp normalize trick
with torch.no_grad():
b = torch.max(distances, dim=1, keepdim=True)[0]
out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
return out.sum(1)
#with torch.no_grad():
# b = torch.max(distances, dim=1, keepdim=True)[0]
#out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
#return out.sum(1)
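The commented-out block above is the "exp-normalize" (log-sum-exp) trick; for reference, PyTorch ships it as `torch.logsumexp`, which sidesteps the underflow that makes the naive form degenerate:

```python
import torch

# exp(-200) underflows to 0 in float32, so the naive form returns -inf,
# while factoring out the max first (what torch.logsumexp does) stays finite.
d = 200.0 * torch.ones(1, 101, 4000)      # distances, shaped like the test below
naive = torch.log(torch.exp(-d).sum(1))   # tensor of -inf
stable = torch.logsumexp(-d, dim=1)       # ~ -200 + log(101) ~ -195.4
```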

def forward(self, speaker_vectors, spk_mask, spk_labels):

# spk_mask would ideally be the speaker activity at frame level. Because in WHAM there are always two speakers and both are always active, we fix this for now.
# mask of ones and zeros, shape B, SRC, FRAMES
def forward(self, speaker_vectors, spk_mask, spk_labels):

if self.gaussian_reg:
noise = torch.randn(self.spk_embeddings.size(), device=speaker_vectors.device)*math.sqrt(self.gaussian_reg)
spk_embeddings = self.spk_embeddings + noise
else:
spk_embeddings = self.spk_embeddings

if self.learnable_emb or self.gaussian_reg: # re-project on unit sphere after noise has been applied and before computing the distance reg
if self.learnable_emb or self.gaussian_reg: # re-project on unit sphere

spk_embeddings = spk_embeddings / torch.sum(spk_embeddings ** 2, -1, keepdim=True).sqrt()

if self.distance_reg:

pairwise_dist = ((spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))**2).sum(-1)
idx = torch.arange(0, pairwise_dist.shape[0])
pairwise_dist[idx, idx] = np.inf # masking with itself
pairwise_dist = pairwise_dist.sqrt()
pairwise_dist = (torch.abs(spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))).mean(-1).fill_diagonal_(np.inf)
distance_reg = -torch.sum(torch.min(torch.log(pairwise_dist), dim=-1)[0])
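# (the two lines above push each embedding away from its nearest neighbour:
# minimizing the negated min log-distance keeps the speaker embeddings spread apart)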

# speaker vectors B, n_src, dim, frames
@@ -145,10 +132,8 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
min_loss_perm = min_loss_perm.transpose(0, 1).reshape(B, n_src, 1, frames).expand(-1, -1, embed_dim, -1)
# tot_loss


spk_loss = self.weight*min_loss.mean()
if self.distance_reg:

spk_loss += self.distance_reg*distance_reg
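# gather reorders the speaker vectors frame by frame according to the best (min-loss) permutation found above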
reordered_sources = torch.gather(speaker_vectors, dim=1, index=min_loss_perm)

@@ -160,23 +145,24 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):


if __name__ == "__main__":
n_speakers = 101
emb_speaker = 256

# testing exp normalize average
distances = torch.ones((1, 101, 4000))*99
with torch.no_grad():
b = torch.max(distances, dim=1, keepdim=True)[0]
out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
out2 = - torch.log(torch.exp(-distances).sum(1))
#distances = torch.ones((1, 101, 4000))
#with torch.no_grad():
# b = torch.max(distances, dim=1, keepdim=True)[0]
#out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
#out2 = - torch.log(torch.exp(-distances).sum(1))

loss_spk = SpeakerVectorLoss(1000, 32, loss_type="distance") # 1000 speakers in training set
loss_spk = SpeakerVectorLoss(n_speakers, emb_speaker, loss_type="global")

speaker_vectors = torch.rand(2, 3, 32, 200)
speaker_vectors = torch.rand(2, 3, emb_speaker, 200)
speaker_labels = torch.from_numpy(np.array([[1, 2, 0], [5, 2, 10]]))
speaker_mask = torch.randint(0, 2, (2, 3, 200)) # silence where there are no speakers; this is just for testing
speaker_mask[:, -1, :] = speaker_mask[:, -1, :]*0
loss_spk(speaker_vectors, speaker_mask, speaker_labels)


c = ClippedSDR(-30)
a = torch.rand((2, 3, 200))
print(c(a, a))
8 changes: 2 additions & 6 deletions egs/wham/WaveSplit/run.sh
@@ -42,10 +42,8 @@ mode=min
nondefault_src= # If you want to train a network with 3 output streams for example.

# Training
batch_size=1
num_workers=8
kernel_size=16
stride=8
batch_size=4
num_workers=4
#optimizer=adam
lr=0.001
epochs=400
@@ -134,8 +132,6 @@ if [[ $stage -le 3 ]]; then
--epochs $epochs \
--batch_size $batch_size \
--num_workers $num_workers \
--kernel_size $kernel_size \
--stride $stride \
--exp_dir ${expdir}/ | tee logs/train_${tag}.log
fi

