G&L GPU, WaveNet NN upsample
Rayhane-mamah authored Jan 26, 2019
1 parent 869ab79 commit ab5cb08
Showing 9 changed files with 188 additions and 84 deletions.
81 changes: 64 additions & 17 deletions datasets/audio.py
@@ -111,6 +111,39 @@ def inv_mel_spectrogram(mel_spectrogram, hparams):
else:
return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)

###########################################################################################
# tensorflow Griffin-Lim
# Thanks to @begeekmyfriend: https://github.com/begeekmyfriend/Tacotron-2/blob/mandarin-new/datasets/audio.py

def inv_linear_spectrogram_tensorflow(spectrogram, hparams):
    '''Builds computational graph to convert spectrogram to waveform using TensorFlow.
    Unlike inv_linear_spectrogram, this does NOT invert the preemphasis. The caller should call
    inv_preemphasis on the output after running the graph.
    '''
    if hparams.signal_normalization:
        D = _denormalize_tensorflow(spectrogram, hparams)
    else:
        D = spectrogram

    S = tf.pow(_db_to_amp_tensorflow(D + hparams.ref_level_db), (1 / hparams.magnitude_power))
    return _griffin_lim_tensorflow(tf.pow(S, hparams.power), hparams)

def inv_mel_spectrogram_tensorflow(mel_spectrogram, hparams):
    '''Builds computational graph to convert mel spectrogram to waveform using TensorFlow.
    Unlike inv_mel_spectrogram, this does NOT invert the preemphasis. The caller should call
    inv_preemphasis on the output after running the graph.
    '''
    if hparams.signal_normalization:
        D = _denormalize_tensorflow(mel_spectrogram, hparams)
    else:
        D = mel_spectrogram

    S = tf.pow(_db_to_amp_tensorflow(D + hparams.ref_level_db), (1 / hparams.magnitude_power))
    S = _mel_to_linear_tensorflow(S, hparams) # Convert back to linear
    return _griffin_lim_tensorflow(tf.pow(S, hparams.power), hparams)
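
# A minimal usage sketch of the two graph builders above (hedged: the placeholder
# shape and session wiring mirror tacotron/synthesizer.py later in this commit,
# the mel file path is illustrative, and hparams is assumed to come from hparams.py):
#
#   mel_inputs = tf.placeholder(tf.float32, (None, hparams.num_mels), name='mel_inputs')
#   wav_op = inv_mel_spectrogram_tensorflow(mel_inputs, hparams)
#
#   with tf.Session() as sess:
#       mel = np.load('training_data/mels/mel-LJ001-0001.npy')  # [time, num_mels]
#       wav = sess.run(wav_op, feed_dict={mel_inputs: mel})
#       # Preemphasis must be inverted outside the graph, as the docstrings note:
#       wav = inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)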

###########################################################################################

def _lws_processor(hparams):
import lws
return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
@@ -127,6 +160,21 @@ def _griffin_lim(S, hparams):
y = _istft(S_complex * angles, hparams)
return y

def _griffin_lim_tensorflow(S, hparams):
    '''TensorFlow implementation of Griffin-Lim
    Based on https://github.com/Kyubyong/tensorflow-exercises/blob/master/Audio_Processing.ipynb
    '''
    with tf.variable_scope('griffinlim'):
        # TensorFlow's stft and istft operate on a batch of spectrograms; create batch of size 1
        S = tf.expand_dims(S, 0)
        S_complex = tf.identity(tf.cast(S, dtype=tf.complex64))
        y = tf.contrib.signal.inverse_stft(S_complex, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
        for i in range(hparams.griffin_lim_iters):
            est = tf.contrib.signal.stft(y, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
            angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64)
            y = tf.contrib.signal.inverse_stft(S_complex * angles, hparams.win_size, get_hop_size(hparams), hparams.n_fft)
        return tf.squeeze(y, 0)
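
# Each iteration above keeps the known magnitude S and only re-estimates phase:
# y <- iSTFT(S * STFT(y) / |STFT(y)|). A NumPy sketch of one such iteration
# (hedged: stft_fn/istft_fn stand in for any transform pair with matching parameters):
def _griffin_lim_step_sketch(y, S_mag, stft_fn, istft_fn, eps=1e-8):
    est = stft_fn(y)                             # complex spectrogram of the current estimate
    angles = est / np.maximum(eps, np.abs(est))  # unit-magnitude phase factor
    return istft_fn(S_mag * angles)              # enforce the target magnitude, keep the new phase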

def _stft(y, hparams):
if hparams.use_lws:
return _lws_processor(hparams).stft(y).T
@@ -186,6 +234,12 @@ def _mel_to_linear(mel_spectrogram, hparams):
_inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))

def _mel_to_linear_tensorflow(mel_spectrogram, hparams):
    global _inv_mel_basis
    if _inv_mel_basis is None:
        _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
    return tf.transpose(tf.maximum(1e-10, tf.matmul(tf.cast(_inv_mel_basis, tf.float32), tf.transpose(mel_spectrogram, [1, 0]))), [1, 0])
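
# Shape note: unlike the NumPy _mel_to_linear above, which takes [num_mels, time],
# the TensorFlow version takes [time, num_mels], hence the transposes. A NumPy
# equivalent of the same math (hedged: an illustrative helper, not in the original file):
def _mel_to_linear_numpy_sketch(mel, mel_basis):
    '''mel: [time, num_mels]; mel_basis: [num_mels, num_freq] -> [time, num_freq]'''
    inv_basis = np.linalg.pinv(mel_basis)
    return np.maximum(1e-10, np.dot(inv_basis, mel.T)).T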

def _build_mel_basis(hparams):
assert hparams.fmax <= hparams.sample_rate // 2
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
@@ -198,6 +252,9 @@ def _amp_to_db(x, hparams):
def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _db_to_amp_tensorflow(x):
    return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)
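
# This inverts the 20 * log10(x) mapping in _amp_to_db, since 10**(x * 0.05) == 10**(x / 20).
# A quick sanity check (hedged: ignores the ref/min level offsets applied elsewhere):
#   amp = np.array([0.001, 0.1, 1.0, 10.0])
#   db = 20.0 * np.log10(amp)
#   assert np.allclose(np.power(10.0, db * 0.05), amp)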

def _normalize(S, hparams):
if hparams.allow_clipping_in_normalization:
if hparams.symmetric_mels:
@@ -226,26 +283,16 @@ def _denormalize(D, hparams):
else:
return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

def normalize_tf(S, hparams):
    #[0, 1]
    if hparams.normalize_for_wavenet:
        if hparams.allow_clipping_in_normalization:
            return tf.minimum(tf.maximum((S - hparams.min_level_db) / (-hparams.min_level_db),
                -hparams.max_abs_value), hparams.max_abs_value)
        else:
            return (S - hparams.min_level_db) / (-hparams.min_level_db)

    #[-max, max] or [0, max]
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return tf.minimum(tf.maximum((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
                -hparams.max_abs_value), hparams.max_abs_value)
        else:
            return tf.minimum(tf.maximum(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0), hparams.max_abs_value)

    assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
    if hparams.symmetric_mels:
        return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
    else:
        return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))

def _denormalize_tensorflow(D, hparams):
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return (((tf.clip_by_value(D, -hparams.max_abs_value,
                hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
                + hparams.min_level_db)
        else:
            return ((tf.clip_by_value(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

    if hparams.symmetric_mels:
        return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
    else:
        return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
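
# Round-trip check of the symmetric, clipping-enabled branch (hedged: max_abs_value = 4.
# and min_level_db = -100 are assumed, matching this repo's defaults):
#   max_abs, min_db = 4.0, -100.0
#   S = np.linspace(min_db, 0.0, 5)                                 # dB values in [min_level_db, 0]
#   norm = (2 * max_abs) * ((S - min_db) / (-min_db)) - max_abs     # _normalize, symmetric branch -> [-4, 4]
#   denorm = ((norm + max_abs) * -min_db / (2 * max_abs)) + min_db  # _denormalize_tensorflow math
#   assert np.allclose(denorm, S)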
2 changes: 1 addition & 1 deletion datasets/wavenet_preprocessor.py
@@ -151,4 +151,4 @@ def _process_utterance(mel_dir, wav_dir, index, wav_path, hparams):
speaker_id = '<no_g>'

# Return a tuple describing this training example
return (audio_filename, mel_filename, '_', speaker_id, time_steps, mel_frames)
return (audio_filename, mel_filename, mel_filename, speaker_id, time_steps, mel_frames)
4 changes: 2 additions & 2 deletions griffin_lim_synthesis_tool.ipynb
@@ -21,7 +21,7 @@
"os.makedirs(out_dir, exist_ok=True)\n",
"\n",
"#mel_file = os.path.join(mel_folder, mel_file)\n",
"mel_file = 'training_data/mels/mel-JRE1169-0000.npy'\n",
"mel_file = 'training_data/mels/mel-LJ001-0001.npy'\n",
"mel_spectro = np.load(mel_file)\n",
"mel_spectro.shape"
]
@@ -55,7 +55,7 @@
"metadata": {},
"outputs": [],
"source": [
"lin_file = 'training_data/linear/linear-JRE1169-0000.npy'\n",
"lin_file = 'training_data/linear/linear-LJ001-0001.npy'\n",
"lin_spectro = np.load(lin_file)\n",
"lin_spectro.shape"
]
15 changes: 8 additions & 7 deletions hparams.py
@@ -113,6 +113,7 @@
#Griffin Lim
power = 1.5, #Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice.
griffin_lim_iters = 60, #Number of G&L iterations, typically 30 is enough but we use 60 to ensure convergence.
GL_on_GPU = True, #Whether to run Griffin-Lim on GPU as part of the tensorflow graph. (Usually much faster than the CPU version, but with slightly lower quality.)
###########################################################################################################################################

#Tacotron
@@ -194,7 +195,7 @@
log_scale_min=float(np.log(1e-14)), #Mixture of logistic distributions minimal log scale
log_scale_min_gauss = float(np.log(1e-7)), #Gaussian distribution minimal allowed log scale
#Loss type
cdf_loss = True, #Whether to use CDF loss in Gaussian modeling. Advantages: non-negative loss term and more training stability. (Automatically True for MoL)
cdf_loss = False, #Whether to use CDF loss in Gaussian modeling. Advantages: non-negative loss term and more training stability. (Automatically True for MoL)

#model parameters
#To use Gaussian distribution as output distribution instead of mixture of logistics, set "out_channels = 2" instead of "out_channels = 10 * 3". (UNDER TEST)
@@ -208,17 +209,17 @@

#Upsampling parameters (local conditioning)
cin_channels = 80, #Set this to -1 to disable local conditioning, else it must be equal to num_mels!!
upsample_conditional_features = True, #Whether to repeat conditional features or upsample them (The latter is recommended)
#Upsample types: ('1D', '2D', 'Resize', 'SubPixel')
#Upsample types: ('1D', '2D', 'Resize', 'SubPixel', 'NearestNeighbor')
#All upsampling initializations/kernel sizes are chosen to avoid checkerboard artifacts as much as possible. (Resize avoids them by design.)
#To be specific, all initial upsample weights/biases (when NN_init=True) ensure that the upsampling layers act as a "Nearest neighbor upsample" of size "hop_size" (checkerboard free).
#1D spans all frequency bands for each frame (channel-wise) while 2D spans "freq_axis_kernel_size" bands at a time. Both are vanilla transpose convolutions.
#Resize is a 2D convolution that follows a Nearest Neighbor (NN) resize. For reference, this is: "NN resize->convolution".
#Finally, SubPixel (2D) is the ICNR version (initialized to be equivalent to "convolution->NN resize") of Sub-Pixel convolutions, also called "checkerboard artifact free sub-pixel conv".
upsample_type = 'SubPixel', #Type of the upsampling deconvolution. Can be ('1D' or '2D', 'Resize', 'SubPixel').
#SubPixel (2D) is the ICNR version (initialized to be equivalent to "convolution->NN resize") of Sub-Pixel convolutions, also called "checkerboard artifact free sub-pixel conv".
#Finally, NearestNeighbor is a non-trainable upsampling layer that just expands each frame (or "pixel") to the equivalent hop size. It ignores all other upsampling parameters.
upsample_type = 'SubPixel', #Type of the upsampling deconvolution. Can be ('1D' or '2D', 'Resize', 'SubPixel' or simple 'NearestNeighbor').
upsample_activation = 'Relu', #Activation function used during upsampling. Can be ('LeakyRelu', 'Relu' or None)
upsample_scales = [11, 25], #prod(upsample_scales) should be equal to hop_size
freq_axis_kernel_size = 2, #Only used for 2D upsampling types. This is the number of frequency bands that are spanned at a time for each frame.
freq_axis_kernel_size = 3, #Only used for 2D upsampling types. This is the number of frequency bands that are spanned at a time for each frame.
leaky_alpha = 0.4, #slope of the negative portion of LeakyRelu (LeakyRelu: y=x if x>0 else y=alpha * x)
NN_init = True, #Determines whether to initialize the upsampling kernels/biases so that the layers start out as an exact Nearest Neighbor upsample. (Mostly for debug)
NN_scaler = 0.3, #Determines the initial Nearest Neighbor upsample values scale. i.e.: upscaled_input_values = input_values * NN_scaler (set to 1. to disable)
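
#Sanity check for the upsample_scales/hop_size constraint above (hedged: hop_size = 275
#assumes this repo's 22050 Hz sample rate and 12.5 ms frame shift):
#   import numpy as np
#   assert np.prod([11, 25]) == 275  # each mel frame must expand to exactly hop_size samples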
@@ -351,7 +352,7 @@
'He reads books.',
'He thought it was time to present the present.',
'Thisss isrealy awhsome.',
'The big brown fox jumped over the lazy dog.',
'The big brown fox jumps over the lazy dog.',
'Did the big brown fox jump over the lazy dog?',
"Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick?",
"She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure.",
25 changes: 22 additions & 3 deletions tacotron/synthesizer.py
@@ -35,6 +35,13 @@ def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
self.stop_token_prediction = self.model.tower_stop_token_prediction
self.targets = targets

if hparams.GL_on_GPU:
    self.GLGPU_mel_inputs = tf.placeholder(tf.float32, (None, hparams.num_mels), name='GLGPU_mel_inputs')
    self.GLGPU_lin_inputs = tf.placeholder(tf.float32, (None, hparams.num_freq), name='GLGPU_lin_inputs')

    self.GLGPU_mel_outputs = audio.inv_mel_spectrogram_tensorflow(self.GLGPU_mel_inputs, hparams)
    self.GLGPU_lin_outputs = audio.inv_linear_spectrogram_tensorflow(self.GLGPU_lin_inputs, hparams)

self.gta = gta
self._hparams = hparams
#pad input sequences with the <pad_token> 0 ( _ )
@@ -154,7 +161,11 @@ def synthesize(self, texts, basenames, out_dir, log_dir, mel_filenames):

if basenames is None:
    #Generate wav and read it
    wav = audio.inv_mel_spectrogram(mels[0].T, hparams)
    if hparams.GL_on_GPU:
        wav = self.session.run(self.GLGPU_mel_outputs, feed_dict={self.GLGPU_mel_inputs: mels[0]})
        wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
    else:
        wav = audio.inv_mel_spectrogram(mels[0].T, hparams)
    audio.save_wav(wav, 'temp.wav', sr=hparams.sample_rate) #Find a better way

if platform.system() == 'Linux':
@@ -191,7 +202,11 @@ def synthesize(self, texts, basenames, out_dir, log_dir, mel_filenames):

if log_dir is not None:
    #save wav (mel -> wav)
    wav = audio.inv_mel_spectrogram(mel.T, hparams)
    if hparams.GL_on_GPU:
        wav = self.session.run(self.GLGPU_mel_outputs, feed_dict={self.GLGPU_mel_inputs: mel})
        wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
    else:
        wav = audio.inv_mel_spectrogram(mel.T, hparams)
    audio.save_wav(wav, os.path.join(log_dir, 'wavs/wav-{}-mel.wav'.format(basenames[i])), sr=hparams.sample_rate)

#save alignments
@@ -204,7 +219,11 @@ def synthesize(self, texts, basenames, out_dir, log_dir, mel_filenames):

if hparams.predict_linear:
    #save wav (linear -> wav)
    wav = audio.inv_linear_spectrogram(linears[i].T, hparams)
    if hparams.GL_on_GPU:
        wav = self.session.run(self.GLGPU_lin_outputs, feed_dict={self.GLGPU_lin_inputs: linears[i]})
        wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
    else:
        wav = audio.inv_linear_spectrogram(linears[i].T, hparams)
    audio.save_wav(wav, os.path.join(log_dir, 'wavs/wav-{}-linear.wav'.format(basenames[i])), sr=hparams.sample_rate)

#save linear spectrogram plot
32 changes: 28 additions & 4 deletions tacotron/train.py
@@ -168,6 +168,14 @@ def train(log_dir, args, hparams):

char_embedding_meta = char_embedding_meta.replace(log_dir, '..')

#Potential Griffin-Lim GPU setup
if hparams.GL_on_GPU:
    GLGPU_mel_inputs = tf.placeholder(tf.float32, (None, hparams.num_mels), name='GLGPU_mel_inputs')
    GLGPU_lin_inputs = tf.placeholder(tf.float32, (None, hparams.num_freq), name='GLGPU_lin_inputs')

    GLGPU_mel_outputs = audio.inv_mel_spectrogram_tensorflow(GLGPU_mel_inputs, hparams)
    GLGPU_lin_outputs = audio.inv_linear_spectrogram_tensorflow(GLGPU_lin_inputs, hparams)

#Book keeping
step = 0
time_window = ValueWindow(100)
@@ -256,7 +264,11 @@
linear_losses.append(linear_loss)
linear_loss = sum(linear_losses) / len(linear_losses)

wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
if hparams.GL_on_GPU:
    wav = sess.run(GLGPU_lin_outputs, feed_dict={GLGPU_lin_inputs: lin_p})
    wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
else:
    wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-wave-from-linear.wav'.format(step)), sr=hparams.sample_rate)

else:
@@ -278,7 +290,11 @@

log('Saving eval log to {}..'.format(eval_dir))
#Save some log to monitor model improvement on same unseen sequence
wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
if hparams.GL_on_GPU:
    wav = sess.run(GLGPU_mel_outputs, feed_dict={GLGPU_mel_inputs: mel_p})
    wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
else:
    wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-wave-from-mel.wav'.format(step)), sr=hparams.sample_rate)

plot.plot_alignment(align, os.path.join(eval_plot_dir, 'step-{}-eval-align.png'.format(step)),
@@ -319,7 +335,11 @@
np.save(os.path.join(linear_dir, linear_filename), linear_prediction.T, allow_pickle=False)

#save griffin lim inverted wav for debug (linear -> wav)
wav = audio.inv_linear_spectrogram(linear_prediction.T, hparams)
if hparams.GL_on_GPU:
    wav = sess.run(GLGPU_lin_outputs, feed_dict={GLGPU_lin_inputs: linear_prediction})
    wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
else:
    wav = audio.inv_linear_spectrogram(linear_prediction.T, hparams)
audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-linear.wav'.format(step)), sr=hparams.sample_rate)

#Save real and predicted linear-spectrogram plot to disk (control purposes)
@@ -341,7 +361,11 @@
np.save(os.path.join(mel_dir, mel_filename), mel_prediction.T, allow_pickle=False)

#save griffin lim inverted wav for debug (mel -> wav)
wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
if hparams.GL_on_GPU:
    wav = sess.run(GLGPU_mel_outputs, feed_dict={GLGPU_mel_inputs: mel_prediction})
    wav = audio.inv_preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
else:
    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-mel.wav'.format(step)), sr=hparams.sample_rate)

#save alignment plot to disk (control purposes)
4 changes: 0 additions & 4 deletions wavenet_vocoder/models/__init__.py
@@ -7,10 +7,6 @@ def create_model(name, hparams, init=False):
    if hparams.out_channels != hparams.quantize_channels:
        raise RuntimeError(
            "out_channels must be equal to quantize_channels if input_type is 'mulaw-quantize'")
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    if name == 'WaveNet':
        return WaveNet(hparams, init)
25 changes: 19 additions & 6 deletions wavenet_vocoder/models/modules.py
@@ -520,6 +522,22 @@ def step(self, x, c, g, is_incremental, queue=None):
x = (x + residual)
return x, s, queue


class NearestNeighborUpsample:
    def __init__(self, strides):
        #Save upsample params
        self.resize_strides = strides

    def __call__(self, inputs):
        #inputs are expected to be [batch_size, freq, time_steps, channels]
        outputs = tf.image.resize_images(
            inputs,
            size=[inputs.shape[1] * self.resize_strides[0], tf.shape(inputs)[2] * self.resize_strides[1]],
            method=1) #BILINEAR = 0, NEAREST_NEIGHBOR = 1, BICUBIC = 2, AREA = 3

        return outputs
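
# A shape-level usage sketch (hedged: the batch/channel sizes and the [1, hop_size]
# strides are illustrative):
#   upsample = NearestNeighborUpsample(strides=[1, 275])  # keep the freq axis, expand time by hop_size
#   c = tf.placeholder(tf.float32, [1, 80, None, 1])      # [batch_size, freq, time_steps, channels]
#   out = upsample(c)                                     # -> [1, 80, time_steps * 275, 1]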


class SubPixelConvolution(tf.layers.Conv2D):
'''Sub-Pixel Convolutions are vanilla convolutions followed by Periodic Shuffle.
@@ -656,20 +672,17 @@ def __init__(self, filters, kernel_size, padding, strides, NN_init, NN_scaler, u
data_format='channels_last',
name=name, **kwargs)

self.resize_strides = strides
self.resize_layer = NearestNeighborUpsample(strides=strides)
self.scope = 'ResizeConvolution' if name is None else name

def call(self, inputs):
    with tf.variable_scope(self.scope) as scope:
        #Inputs are expected to be [batch_size, freq, time_steps, channels]
        resized = tf.image.resize_images(
            inputs,
            size=[inputs.shape[1] * self.resize_strides[0], tf.shape(inputs)[2] * self.resize_strides[1]],
            method=1) #BILINEAR = 0, NEAREST_NEIGHBOR = 1, BICUBIC = 2, AREA = 3
        resized = self.resize_layer(inputs)

        return super(ResizeConvolution, self).call(resized)

def _init_kernel(kernel_size, strides):
def _init_kernel(self, kernel_size, strides):
'''Nearest Neighbor Upsample (Checkerboard free) init kernel size
'''
overlap = kernel_size[1] // strides[1]

1 comment on commit ab5cb08

@luis-vera

Hello. I have downloaded your code, but unfortunately it has been very difficult to execute because it runs on TensorFlow 1 and I work in TensorFlow 2. I have modified some scripts, but the training stage has been impossible. I'd like to know if you have an updated version of your code. Thanks a lot.
