From f4dc893e750ef982a40f86544506e00ecb2d456c Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Sun, 28 Jan 2024 23:42:09 +0000
Subject: [PATCH 01/44] VQT outline with base args.

---
 src/torchaudio/transforms/_transforms.py | 53 ++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index 802cbd3d77..f48f4dddcf 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -621,6 +621,59 @@ def forward(self, waveform: Tensor) -> Tensor:
         return mel_specgram
 
 
+class VQT(torch.nn.Module):
+    r"""Create variable Q-transform for a raw audio signal.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Sources
+        * https://librosa.org/doc/main/_modules/librosa/core/constantq.html
+        * https://www.aes.org/e-lib/online/browse.cfm?elib=17112
+        * https://newt.phys.unsw.edu.au/jw/notes.html
+
+    Args:
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``)
+        f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``)
+        n_bins (int, optional): Number of VQT frequency bins, starting at ``f_min``. (Default: ``84``)
+        gamma (float, optional): Offset that controls VQT filter lengths. Larger values
+            increase the time resolution at lower frequencies. (Default: ``0``)
+        bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
+
+    Example
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+        >>> transform = transforms.VQT(sample_rate)
+        >>> vqt_specgram = transform(waveform)  # (channel, n_bins, time)
+    """
+    __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"]
+
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        hop_length: int = 400,
+        f_min: float = 32.703,
+        n_bins: int = 84,
+        gamma: float = 0.,
+        bins_per_octave: int = 12,
+    ) -> None:
+        super(VQT, self).__init__()
+        torch._C._log_api_usage_once("torchaudio.transforms.VQT")
+
+        pass
+
+    def forward(self, waveform: Tensor) -> Tensor:
+        r"""
+        Args:
+            waveform (Tensor): Tensor of audio of dimension (..., time).
+
+        Returns:
+            Tensor: VQT spectrogram of size (..., ``n_bins``, time).
+        """
+        pass
+
+
 class MFCC(torch.nn.Module):
     r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

From bd75d4d8abfa6088020e8f15aef83d01d0ccbcd4 Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Wed, 29 May 2024 15:41:31 +0000
Subject: [PATCH 02/44] Equal temperament frequencies set.

---
 src/torchaudio/transforms/_transforms.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index f48f4dddcf..0345c440b7 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -2,6 +2,7 @@
 import math
 import warnings
+import numpy as np
 from typing import Callable, Optional, Sequence, Tuple, Union
 
 import torch
@@ -636,7 +637,7 @@ class VQT(torch.nn.Module):
     Args:
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
         hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``)
-        f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``)
+        f_min (float, optional): Minimum frequency, which corresponds to first note.
(Default: ``32.703``, or the frequency of C1 in Hz) n_bins (int, optional): Number of VQT frequency bins, starting at ``f_min``. (Default: ``84``) gamma (float, optional): Offset that controls VQT filter lengths. Larger values increase the time resolution at lower frequencies. (Default: ``0``) @@ -647,7 +648,7 @@ class VQT(torch.nn.Module): >>> transform = transforms.VQT(sample_rate) >>> vqt_specgram = transform(waveform) # (channel, n_bins, time) """ - __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"] + __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave"] def __init__( self, @@ -660,8 +661,20 @@ def __init__( ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") - - pass + + self.n_bins = n_bins + self.bins_per_octave = bins_per_octave + self.f_min = f_min + + self.n_octaves = math.ceil(self.n_bins / self.bins_per_octave) + n_filters = min(self.bins_per_octave, self.n_bins) + + frequencies = self.get_frequencies() + + def get_frequencies(self) -> list[float]: + r"""Return a set of frequencies that assumes an equal temperament tuning system.""" + ratios = 2.0 ** (np.arange(0, self.bins_per_octave * self.n_octaves) / self.bins_per_octave) + return self.f_min * ratios[:self.n_bins] def forward(self, waveform: Tensor) -> Tensor: r""" From 40d65816f470b8f1d069fd45d974b08171f07705 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 29 May 2024 17:05:53 +0000 Subject: [PATCH 03/44] Raise error if max frequency is superior to Nyquist. --- src/torchaudio/transforms/_transforms.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 0345c440b7..5cba38d2fa 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -671,6 +671,12 @@ def __init__( frequencies = self.get_frequencies() + if frequencies[-1] > sample_rate / 2: + raise ValueError( + f"Maximum bin center frequency is {frequencies[-1]} and superior to the Nyquist frequency {sample_rate/2}. " + "Try to reduce the number of frequency bins." + ) + def get_frequencies(self) -> list[float]: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" ratios = 2.0 ** (np.arange(0, self.bins_per_octave * self.n_octaves) / self.bins_per_octave) From 3f2689e0b431d1702a6754b98db7990b39df1540 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 6 Jun 2024 17:50:44 +0000 Subject: [PATCH 04/44] VQT wavelet filter creation. 
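The lengths follow librosa's wavelet filter construction with filter_scale fixed to 1:
Q = 1 / alpha, and bin k spans Q * sr / (f_k + gamma / alpha_k) samples. A minimal
standalone sketch of that computation (function and variable names here are
illustrative, not the final API; the Glasberg-Moore fallback mirrors the diff below):

    import torch

    def sketch_wavelet_lengths(freqs, sr, alpha, gamma=None):
        if gamma is None:
            # ERB-based fallback from Glasberg & Moore (1990), as in the diff
            gamma = alpha * 24.7 / 0.108
        Q = 1.0 / alpha  # filter_scale assumed to be 1
        return Q * sr / (freqs + gamma / alpha)

    freqs = 32.703 * 2.0 ** (torch.arange(84, dtype=torch.float64) / 12)
    r = 2.0 ** (1 / 12)
    alpha = torch.full_like(freqs, (r ** 2 - 1) / (r ** 2 + 1))  # relative bandwidths
    print(sketch_wavelet_lengths(freqs, 16000, alpha)[:3])  # lowest bins, longest filters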
--- src/torchaudio/transforms/_transforms.py | 59 +++++++++++++++++++++--- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 5cba38d2fa..4908a09346 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -656,7 +656,7 @@ def __init__( hop_length: int = 400, f_min: float = 32.703, n_bins: int = 84, - gamma: float = 0., + gamma: Optional[float] = None, bins_per_octave: int = 12, ) -> None: super(VQT, self).__init__() @@ -665,22 +665,69 @@ def __init__( self.n_bins = n_bins self.bins_per_octave = bins_per_octave self.f_min = f_min + self.gamma = gamma + self.sample_rate = sample_rate self.n_octaves = math.ceil(self.n_bins / self.bins_per_octave) n_filters = min(self.bins_per_octave, self.n_bins) - frequencies = self.get_frequencies() + self.frequencies = self.get_frequencies() - if frequencies[-1] > sample_rate / 2: + if self.frequencies[-1] > sample_rate / 2: raise ValueError( - f"Maximum bin center frequency is {frequencies[-1]} and superior to the Nyquist frequency {sample_rate/2}. " + f"Maximum bin center frequency is {self.frequencies[-1]} and superior to the Nyquist frequency {sample_rate/2}. " "Try to reduce the number of frequency bins." ) - def get_frequencies(self) -> list[float]: + self.alpha = self.compute_alpha() + self.wav_lengths = self.wavelet_lengths() + + def get_frequencies(self) -> Tensor: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" - ratios = 2.0 ** (np.arange(0, self.bins_per_octave * self.n_octaves) / self.bins_per_octave) + ratios = 2.0 ** (torch.arange(0, self.bins_per_octave * self.n_octaves, dtype=float) / self.bins_per_octave) return self.f_min * ratios[:self.n_bins] + + def compute_alpha(self) -> Tensor: + r"""Compute relative bandwidths for specified frequencies.""" + if self.n_bins > 1: + # Approximate local octave resolution around each frequency + bandpass_octave = torch.empty_like(self.frequencies) + log_frequencies = torch.log2(self.frequencies) + + # Reflect at the lowest and highest frequencies + bandpass_octave[0] = 1 / (log_frequencies[1] - log_frequencies[0]) + bandpass_octave[-1] = 1 / (log_frequencies[-1] - log_frequencies[-2]) + + # Centered difference + bandpass_octave[1:-1] = 2 / (log_frequencies[2:] - log_frequencies[:-2]) + + alpha = (2. ** (2 / bandpass_octave) - 1) / (2. ** (2 / bandpass_octave) + 1) + else: + # Special case when single basis frequency is used + rel_band_coeff = 2. ** (1. / self.bins_per_octave) + alpha = torch.atleast_1d((rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)) + + return alpha + + def wavelet_lengths(self): + r"""Length of each filter in a wavelet basis.""" + if self.gamma is None: + # Specify gamma_ as: gamma[k] = 24.7 * alpha[k] / 0.108 when not defined + # From: Glasberg, Brian R., and Brian CJ Moore. + # "Derivation of auditory filter shapes from notched-noise data." + # Hearing research 47.1-2 (1990): 103-138. + gamma_ = self.alpha * 24.7 / 0.108 + else: + gamma_ = self.gamma + + # We assume filter_scale (librosa param) is 1 + Q = 1. / self.alpha + + # Convert frequencies to filter lengths + lengths = Q * self.sample_rate / (self.frequencies + gamma_ / self.alpha) + + return lengths + def forward(self, waveform: Tensor) -> Tensor: r""" From f13b46d1a88d45393205610d75c4d6276caa8aa4 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 6 Jun 2024 21:12:25 +0000 Subject: [PATCH 05/44] Top bin filter cutoff frequencies. 
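The Nyquist check now compares against the top filter's upper cutoff rather than its
center frequency. A hedged sketch of the quantity involved, assuming freqs, Q and gamma
from the wavelet construction above and a Hann-like window bandwidth:

    window_bandwidth = 1.50018310546875  # Hann equivalent noise bandwidth
    cutoff_freq = (freqs * (1 + 0.5 * window_bandwidth / Q) + 0.5 * gamma).max()
    if cutoff_freq > sample_rate / 2:
        raise ValueError("Top filter passes the Nyquist frequency; reduce n_bins.")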
--- src/torchaudio/transforms/_transforms.py | 31 +++++++++++++++++------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 4908a09346..494a816c3c 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -673,15 +673,15 @@ def __init__( self.frequencies = self.get_frequencies() - if self.frequencies[-1] > sample_rate / 2: + self.alpha = self.compute_alpha() + self.wav_lengths, cutoff_freq = self.wavelet_lengths() + + if cutoff_freq > sample_rate / 2: raise ValueError( - f"Maximum bin center frequency is {self.frequencies[-1]} and superior to the Nyquist frequency {sample_rate/2}. " + f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the Nyquist frequency {sample_rate/2}. " "Try to reduce the number of frequency bins." ) - self.alpha = self.compute_alpha() - self.wav_lengths = self.wavelet_lengths() - def get_frequencies(self) -> Tensor: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" ratios = 2.0 ** (torch.arange(0, self.bins_per_octave * self.n_octaves, dtype=float) / self.bins_per_octave) @@ -709,8 +709,19 @@ def compute_alpha(self) -> Tensor: return alpha - def wavelet_lengths(self): - r"""Length of each filter in a wavelet basis.""" + def wavelet_lengths(self, window_bandwidth: float = 1.50018310546875) -> Tuple[Tensor, float]: + r"""Length of each filter in a wavelet basis. + + Sources: + * https://librosa.org/doc/main/_modules/librosa/filters.html + + Args: + window_bandwifth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. (Default: ``1.50018310546875``, or the Hann window value) + + Returns: + Tensor: filter lengths. + float: cutoff frequency of highest bin. + """ if self.gamma is None: # Specify gamma_ as: gamma[k] = 24.7 * alpha[k] / 0.108 when not defined # From: Glasberg, Brian R., and Brian CJ Moore. @@ -723,11 +734,13 @@ def wavelet_lengths(self): # We assume filter_scale (librosa param) is 1 Q = 1. / self.alpha + # Output cutoff frequency + cutoff_freq = max(self.frequencies * (1 + 0.5 * window_bandwidth / Q) + 0.5 * gamma_) + # Convert frequencies to filter lengths lengths = Q * self.sample_rate / (self.frequencies + gamma_ / self.alpha) - return lengths - + return lengths, cutoff_freq def forward(self, waveform: Tensor) -> Tensor: r""" From bfcca8a5e37605fd410a32fd6417fc98f231ae64 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Fri, 7 Jun 2024 13:52:03 +0000 Subject: [PATCH 06/44] Warnings for hop length and sample rate values. --- src/torchaudio/transforms/_transforms.py | 28 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 494a816c3c..f9f8bebdba 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -642,6 +642,7 @@ class VQT(torch.nn.Module): gamma (float, optional): Offset that controls VQT filter lengths. Larger values increase the time resolution at lower frequencies. (Default: ``0``) bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) + window_bandwidth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. 
(Default: ``1.50018310546875``, or the Hann window value) Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) @@ -658,6 +659,7 @@ def __init__( n_bins: int = 84, gamma: Optional[float] = None, bins_per_octave: int = 12, + window_bandwidth: float = 1.50018310546875, ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -674,14 +676,30 @@ def __init__( self.frequencies = self.get_frequencies() self.alpha = self.compute_alpha() - self.wav_lengths, cutoff_freq = self.wavelet_lengths() + self.wav_lengths, cutoff_freq = self.wavelet_lengths(window_bandwidth) + nyquist = sample_rate / 2 - if cutoff_freq > sample_rate / 2: + if cutoff_freq > nyquist: raise ValueError( - f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the Nyquist frequency {sample_rate/2}. " + f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the Nyquist frequency {nyquist}. " "Try to reduce the number of frequency bins." ) + # Number of zeros after first 1 in binary gives number of divisions by 2 before number becomes odd + num_hop_downsamples = len(str(bin(hop_length)).split('1')[-1]) + + if num_hop_downsamples > self.n_octaves: + warnings.warn( + f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming odd. " + f"The VQT is however being computed for {self.n_octaves} octaves. Consider lowering the hop length or increasing the number of bins for more accurate results." + ) + + if nyquist / cutoff_freq > 4: + warnings.warn( + f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's cutoff frequency {cutoff_freq}. " + "Consider resampling your signal to a lower sample rate or increasing the number of bins before VQT computation for more accurate results." + ) + def get_frequencies(self) -> Tensor: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" ratios = 2.0 ** (torch.arange(0, self.bins_per_octave * self.n_octaves, dtype=float) / self.bins_per_octave) @@ -709,14 +727,14 @@ def compute_alpha(self) -> Tensor: return alpha - def wavelet_lengths(self, window_bandwidth: float = 1.50018310546875) -> Tuple[Tensor, float]: + def wavelet_lengths(self, window_bandwidth: float) -> Tuple[Tensor, float]: r"""Length of each filter in a wavelet basis. Sources: * https://librosa.org/doc/main/_modules/librosa/filters.html Args: - window_bandwifth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. (Default: ``1.50018310546875``, or the Hann window value) + window_bandwidth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. (Default: ``1.50018310546875``, or the Hann window value) Returns: Tensor: filter lengths. From 213e2d8004394ec23b81484310b7f3f20ba38a57 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Fri, 7 Jun 2024 16:26:48 +0000 Subject: [PATCH 07/44] Forward loop outline. 
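The forward pass walks octaves from highest to lowest, halving the sample rate and hop
length while the hop stays even. A self-contained sketch of just that multirate
recursion (shapes only, no filtering yet; values are illustrative):

    import torch

    waveform, sr, hop, n_octaves = torch.randn(1, 1, 16000), 16000, 400, 7
    for _ in range(n_octaves):
        print(waveform.shape, sr, hop)
        if hop % 2 == 0:
            # Halve the rate with a stride-2 average pool, as in the diff
            waveform = torch.nn.functional.avg_pool1d(
                waveform, kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False
            )
            sr //= 2
            hop //= 2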
--- src/torchaudio/transforms/_transforms.py | 33 ++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index f9f8bebdba..42897cd5bd 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -2,7 +2,6 @@ import math import warnings -import numpy as np from typing import Callable, Optional, Sequence, Tuple, Union import torch @@ -669,15 +668,16 @@ def __init__( self.f_min = f_min self.gamma = gamma self.sample_rate = sample_rate + self.hop_length = hop_length self.n_octaves = math.ceil(self.n_bins / self.bins_per_octave) - n_filters = min(self.bins_per_octave, self.n_bins) + self.n_filters = min(self.bins_per_octave, self.n_bins) self.frequencies = self.get_frequencies() self.alpha = self.compute_alpha() self.wav_lengths, cutoff_freq = self.wavelet_lengths(window_bandwidth) - nyquist = sample_rate / 2 + nyquist = self.sample_rate / 2 if cutoff_freq > nyquist: raise ValueError( @@ -686,7 +686,7 @@ def __init__( ) # Number of zeros after first 1 in binary gives number of divisions by 2 before number becomes odd - num_hop_downsamples = len(str(bin(hop_length)).split('1')[-1]) + num_hop_downsamples = len(str(bin(self.hop_length)).split('1')[-1]) if num_hop_downsamples > self.n_octaves: warnings.warn( @@ -768,7 +768,30 @@ def forward(self, waveform: Tensor) -> Tensor: Returns: Tensor: VQT spectrogram of size (..., ``n_bins``, time). """ - pass + temp_waveform, temp_sr, temp_hop = waveform, self.sample_rate, self.hop_length + + # Iterate down the octaves + for oct_index in range(self.n_octaves - 1, -1, -1): + print(f"{temp_waveform.shape} -- {temp_sr} -- {temp_hop}") + indices = slice(self.n_filters * oct_index, self.n_filters * (oct_index + 1)) + + octave_freqs = self.frequencies[indices] + octave_alphas = self.alpha[indices] + + # Resampling + if temp_hop % 2 == 0: + temp_waveform = torch.nn.functional.avg_pool1d( + temp_waveform, + kernel_size=2, + stride=2, + ceil_mode=True, + # padding=0, + count_include_pad=False, ## Prevents edge effects + ) + temp_sr //= 2. + temp_hop //= 2 + + return class MFCC(torch.nn.Module): From 517e6c0b787454ba3c3aa84f8fd069bd5b496e63 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Fri, 7 Jun 2024 23:07:30 +0000 Subject: [PATCH 08/44] Wavelet basis function implemented. 
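Each basis filter is a windowed complex sinusoid at the bin's center frequency,
L1-normalized and later zero-padded to a common power-of-two length. A minimal sketch
of a single un-padded filter (standalone, not the final API):

    import math
    import torch

    def sketch_filter(freq, ilen, sr, window_fn=torch.hann_window):
        t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * math.pi * freq / sr
        sig = torch.cos(t) + 1j * torch.sin(t)  # complex sinusoid at freq Hz
        sig = sig * window_fn(len(sig))         # taper with the window
        return torch.nn.functional.normalize(sig, p=1.0, dim=0)  # L1 normalization

    print(sketch_filter(440.0, 512, 16000).shape)  # torch.Size([512])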
--- src/torchaudio/transforms/_transforms.py | 83 +++++++++++++++++------- 1 file changed, 61 insertions(+), 22 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 42897cd5bd..ae7d63a701 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -656,9 +656,9 @@ def __init__( hop_length: int = 400, f_min: float = 32.703, n_bins: int = 84, - gamma: Optional[float] = None, + gamma: float = 0., bins_per_octave: int = 12, - window_bandwidth: float = 1.50018310546875, + window_fn: Callable[..., Tensor] = torch.hann_window, ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -670,13 +670,16 @@ def __init__( self.sample_rate = sample_rate self.hop_length = hop_length + # Function that creates new window function with length x + self.window_fn = lambda x: window_fn(x) + self.n_octaves = math.ceil(self.n_bins / self.bins_per_octave) self.n_filters = min(self.bins_per_octave, self.n_bins) self.frequencies = self.get_frequencies() self.alpha = self.compute_alpha() - self.wav_lengths, cutoff_freq = self.wavelet_lengths(window_bandwidth) + self.wav_lengths, cutoff_freq = self.wavelet_lengths(self.frequencies, self.sample_rate, self.alpha) nyquist = self.sample_rate / 2 if cutoff_freq > nyquist: @@ -727,38 +730,64 @@ def compute_alpha(self) -> Tensor: return alpha - def wavelet_lengths(self, window_bandwidth: float) -> Tuple[Tensor, float]: + def wavelet_lengths(self, freqs, sr, alpha) -> Tuple[Tensor, float]: r"""Length of each filter in a wavelet basis. Sources: * https://librosa.org/doc/main/_modules/librosa/filters.html - - Args: - window_bandwidth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. (Default: ``1.50018310546875``, or the Hann window value) - + Returns: Tensor: filter lengths. float: cutoff frequency of highest bin. """ - if self.gamma is None: - # Specify gamma_ as: gamma[k] = 24.7 * alpha[k] / 0.108 when not defined - # From: Glasberg, Brian R., and Brian CJ Moore. - # "Derivation of auditory filter shapes from notched-noise data." - # Hearing research 47.1-2 (1990): 103-138. - gamma_ = self.alpha * 24.7 / 0.108 - else: - gamma_ = self.gamma - # We assume filter_scale (librosa param) is 1 - Q = 1. / self.alpha + Q = 1. 
/ alpha - # Output cutoff frequency - cutoff_freq = max(self.frequencies * (1 + 0.5 * window_bandwidth / Q) + 0.5 * gamma_) + # Output upper bound cutoff frequency + # 3.0 > all common window function bandwidths + # https://librosa.org/doc/main/_modules/librosa/filters.html + cutoff_freq = max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * self.gamma) # Convert frequencies to filter lengths - lengths = Q * self.sample_rate / (self.frequencies + gamma_ / self.alpha) + lengths = Q * sr / (freqs + self.gamma / alpha) return lengths, cutoff_freq + + def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Tensor, Tensor]: + """Wavelet filterbank constructed from set of center frequencies.""" + # First get filter lengths + lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha) + + # Next power of 2 + pad_to_size = 1<<(int(max(lengths))-1).bit_length() + + filters = None + + for ilen, freq in zip(lengths, freqs): + # Build filter with length ceil(ilen) + t = torch.arange(-ilen // 2, ilen // 2, dtype=float) * 2 * torch.pi * freq / sr + sig = torch.cos(t) + 1j * torch.sin(t) + + # Multiply with window + sig_len = len(sig) + sig = sig * self.window_fn(sig_len) + + # L1 normalize + sig = torch.nn.functional.normalize(sig, p=1., dim=0) + + # Pad signal left and right to correct size + l_pad = math.ceil((pad_to_size - sig_len) / 2) + r_pad = math.floor((pad_to_size - sig_len) / 2) + sig = torch.nn.functional.pad(sig, (l_pad, r_pad), mode='constant', value=0.) + sig = sig.unsqueeze(0) + + if filters is None: + filters = sig + + else: + filters = torch.cat([filters, sig], dim=0) + + return filters, lengths def forward(self, waveform: Tensor) -> Tensor: r""" @@ -772,12 +801,22 @@ def forward(self, waveform: Tensor) -> Tensor: # Iterate down the octaves for oct_index in range(self.n_octaves - 1, -1, -1): - print(f"{temp_waveform.shape} -- {temp_sr} -- {temp_hop}") indices = slice(self.n_filters * oct_index, self.n_filters * (oct_index + 1)) octave_freqs = self.frequencies[indices] octave_alphas = self.alpha[indices] + basis, lengths = self.wavelet(octave_freqs, temp_sr, octave_alphas) + + # STFT matrix + D = torch.stft( + temp_waveform, + n_fft=512, + hop_length=self.hop_length, + pad_mode='constant', + return_complex=True, + ) + # Resampling if temp_hop % 2 == 0: temp_waveform = torch.nn.functional.avg_pool1d( From 1ed2271eeeddfc9e3a8f3603649ac8fda0ff065a Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sat, 8 Jun 2024 02:15:53 +0000 Subject: [PATCH 09/44] First shot at entire VQT done. --- src/torchaudio/transforms/_transforms.py | 47 ++++++++++++++++++------ 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index ae7d63a701..96a27a491c 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -641,14 +641,15 @@ class VQT(torch.nn.Module): gamma (float, optional): Offset that controls VQT filter lengths. Larger values increase the time resolution at lower frequencies. (Default: ``0``) bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) - window_bandwidth (float, optional): Equivalent noise bandwidth (ENBW) of a window function. (Default: ``1.50018310546875``, or the Hann window value) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. 
(Default: ``torch.hann_window``) Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) >>> transform = transforms.VQT(sample_rate) - >>> vqt_specgram = transform(waveform) # (channel, n_bins, time) + >>> vqt = transform(waveform) # (channel, n_bins, time) """ - __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave"] + __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn"] def __init__( self, @@ -753,7 +754,7 @@ def wavelet_lengths(self, freqs, sr, alpha) -> Tuple[Tensor, float]: return lengths, cutoff_freq - def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Tensor, Tensor]: + def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Optional[Tensor], Tensor]: """Wavelet filterbank constructed from set of center frequencies.""" # First get filter lengths lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha) @@ -765,7 +766,8 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Tensor, Tensor for ilen, freq in zip(lengths, freqs): # Build filter with length ceil(ilen) - t = torch.arange(-ilen // 2, ilen // 2, dtype=float) * 2 * torch.pi * freq / sr + # Use float32 in order to output complex(float) numbers later + t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * torch.pi * freq / sr sig = torch.cos(t) + 1j * torch.sin(t) # Multiply with window @@ -776,8 +778,8 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Tensor, Tensor sig = torch.nn.functional.normalize(sig, p=1., dim=0) # Pad signal left and right to correct size - l_pad = math.ceil((pad_to_size - sig_len) / 2) - r_pad = math.floor((pad_to_size - sig_len) / 2) + l_pad = math.floor((pad_to_size - sig_len) / 2) + r_pad = math.ceil((pad_to_size - sig_len) / 2) sig = torch.nn.functional.pad(sig, (l_pad, r_pad), mode='constant', value=0.) sig = sig.unsqueeze(0) @@ -789,7 +791,7 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Tensor, Tensor return filters, lengths - def forward(self, waveform: Tensor) -> Tensor: + def forward(self, waveform: Tensor) -> Optional[Tensor]: r""" Args: waveform (Tensor): Tensor of audio of dimension (..., time). @@ -798,6 +800,7 @@ def forward(self, waveform: Tensor) -> Tensor: Tensor: VQT spectrogram of size (..., ``n_bins``, time). 
""" temp_waveform, temp_sr, temp_hop = waveform, self.sample_rate, self.hop_length + vqt = None # Iterate down the octaves for oct_index in range(self.n_octaves - 1, -1, -1): @@ -807,16 +810,32 @@ def forward(self, waveform: Tensor) -> Tensor: octave_alphas = self.alpha[indices] basis, lengths = self.wavelet(octave_freqs, temp_sr, octave_alphas) + n_fft = basis.shape[1] + + # Normalize wrt FFT window length and compute for basis + factors = lengths.unsqueeze(1) / float(n_fft) + basis *= factors + fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] + fft_basis[:] *= math.sqrt(self.sample_rate / temp_sr) # STFT matrix - D = torch.stft( + dft = torch.stft( temp_waveform, - n_fft=512, - hop_length=self.hop_length, + n_fft=n_fft, + window=self.window_fn(n_fft), + hop_length=temp_hop, pad_mode='constant', return_complex=True, ) + # Compute octave vqt + temp_vqt = torch.matmul(fft_basis.unsqueeze(0), dft) + + if vqt is None: + vqt = temp_vqt + else: + vqt = torch.cat([temp_vqt, vqt], dim=-2) + # Resampling if temp_hop % 2 == 0: temp_waveform = torch.nn.functional.avg_pool1d( @@ -830,7 +849,11 @@ def forward(self, waveform: Tensor) -> Tensor: temp_sr //= 2. temp_hop //= 2 - return + # Scale VQT by square-root of the length of each channel's filter + expanded_lengths = self.wav_lengths.unsqueeze(0).unsqueeze(-1) + vqt /= torch.sqrt(expanded_lengths) + + return vqt class MFCC(torch.nn.Module): From eeba783457003b71eecf24f11f01763a8634d61d Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 9 Jun 2024 01:01:15 +0000 Subject: [PATCH 10/44] Sparsified rows. --- src/torchaudio/transforms/_transforms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 96a27a491c..7af5523751 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -643,6 +643,7 @@ class VQT(torch.nn.Module): bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + sparsity_quantile (float, optional): Quantile under which wavelet filter basis magnitudes are zeroed. 
(Default: ``0.01``) Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) @@ -660,6 +661,7 @@ def __init__( gamma: float = 0., bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, + sparsity_quantile: float = 0.01, ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -670,6 +672,7 @@ def __init__( self.gamma = gamma self.sample_rate = sample_rate self.hop_length = hop_length + self.sparsity_quantile = sparsity_quantile # Function that creates new window function with length x self.window_fn = lambda x: window_fn(x) @@ -790,6 +793,13 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Optional[Tenso filters = torch.cat([filters, sig], dim=0) return filters, lengths + + def sparsify_basis(self, basis: Tensor) -> Tensor: + """Set basis magnitudes under sparsity quantile to zero.""" + magnitudes = torch.abs(basis) + mag_sums = magnitudes.sum(dim=-1, keepdim=True) + zeroed_values = torch.ge(magnitudes, mag_sums * self.sparsity_quantile) + return basis * zeroed_values def forward(self, waveform: Tensor) -> Optional[Tensor]: r""" @@ -816,6 +826,7 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: factors = lengths.unsqueeze(1) / float(n_fft) basis *= factors fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] + fft_basis = self.sparsify_basis(fft_basis) fft_basis[:] *= math.sqrt(self.sample_rate / temp_sr) # STFT matrix From 76afee0a47dc1fa93ba3a21617c1e4acd0e7fe69 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 9 Jun 2024 01:39:49 +0000 Subject: [PATCH 11/44] Removed sparsity and matched stft to librosa vqt implementation. --- src/torchaudio/transforms/_transforms.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 7af5523751..5193a31797 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -643,7 +643,6 @@ class VQT(torch.nn.Module): bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) - sparsity_quantile (float, optional): Quantile under which wavelet filter basis magnitudes are zeroed. 
(Default: ``0.01``) Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) @@ -661,7 +660,6 @@ def __init__( gamma: float = 0., bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, - sparsity_quantile: float = 0.01, ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -672,7 +670,6 @@ def __init__( self.gamma = gamma self.sample_rate = sample_rate self.hop_length = hop_length - self.sparsity_quantile = sparsity_quantile # Function that creates new window function with length x self.window_fn = lambda x: window_fn(x) @@ -793,13 +790,6 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Optional[Tenso filters = torch.cat([filters, sig], dim=0) return filters, lengths - - def sparsify_basis(self, basis: Tensor) -> Tensor: - """Set basis magnitudes under sparsity quantile to zero.""" - magnitudes = torch.abs(basis) - mag_sums = magnitudes.sum(dim=-1, keepdim=True) - zeroed_values = torch.ge(magnitudes, mag_sums * self.sparsity_quantile) - return basis * zeroed_values def forward(self, waveform: Tensor) -> Optional[Tensor]: r""" @@ -826,14 +816,13 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: factors = lengths.unsqueeze(1) / float(n_fft) basis *= factors fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] - fft_basis = self.sparsify_basis(fft_basis) fft_basis[:] *= math.sqrt(self.sample_rate / temp_sr) # STFT matrix dft = torch.stft( temp_waveform, n_fft=n_fft, - window=self.window_fn(n_fft), + window=torch.ones(n_fft), hop_length=temp_hop, pad_mode='constant', return_complex=True, From 86d462c664b16963d3caea09c7111c10c2adad42 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 9 Jun 2024 12:05:14 +0000 Subject: [PATCH 12/44] Fixing dot product operation. --- src/torchaudio/transforms/_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 5193a31797..1231898706 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -829,7 +829,7 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: ) # Compute octave vqt - temp_vqt = torch.matmul(fft_basis.unsqueeze(0), dft) + temp_vqt = torch.einsum('ij,...jk->...ik', fft_basis, dft) if vqt is None: vqt = temp_vqt From 7aa9b437b67f03fc417c79870b0b68c590cb060b Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 9 Jun 2024 19:24:34 +0000 Subject: [PATCH 13/44] Fixed resampling. --- src/torchaudio/transforms/_transforms.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 1231898706..7608575e3b 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -643,13 +643,15 @@ class VQT(torch.nn.Module): bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + resampling_method (str, optional): The resampling method to use. 
+ Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``) Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) >>> transform = transforms.VQT(sample_rate) >>> vqt = transform(waveform) # (channel, n_bins, time) """ - __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn"] + __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn", "resampling_method"] def __init__( self, @@ -660,6 +662,7 @@ def __init__( gamma: float = 0., bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, + resampling_method: str = "sinc_interp_hann", ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -704,6 +707,8 @@ def __init__( "Consider resampling your signal to a lower sample rate or increasing the number of bins before VQT computation for more accurate results." ) + self.resample = Resample(2, 1, resampling_method) + def get_frequencies(self) -> Tensor: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" ratios = 2.0 ** (torch.arange(0, self.bins_per_octave * self.n_octaves, dtype=float) / self.bins_per_octave) @@ -838,15 +843,8 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: # Resampling if temp_hop % 2 == 0: - temp_waveform = torch.nn.functional.avg_pool1d( - temp_waveform, - kernel_size=2, - stride=2, - ceil_mode=True, - # padding=0, - count_include_pad=False, ## Prevents edge effects - ) - temp_sr //= 2. + temp_waveform = self.resample(temp_waveform) + temp_sr /= 2. temp_hop //= 2 # Scale VQT by square-root of the length of each channel's filter From b6f8b3cfa02f58f7d7383fde429e65cd11fd5359 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Tue, 11 Jun 2024 00:10:11 +0000 Subject: [PATCH 14/44] Object-oriented optimizations! --- src/torchaudio/transforms/_transforms.py | 128 ++++++++++++----------- 1 file changed, 66 insertions(+), 62 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 7608575e3b..c1bb39848c 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -639,7 +639,7 @@ class VQT(torch.nn.Module): f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz) n_bins (int, optional): Number of VQT frequency bins, starting at ``f_min``. (Default: ``84``) gamma (float, optional): Offset that controls VQT filter lengths. Larger values - increase the time resolution at lower frequencies. (Default: ``0``) + increase the time resolution at lower frequencies. (Default: ``0.``) bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. 
(Default: ``torch.hann_window``) @@ -649,7 +649,7 @@ class VQT(torch.nn.Module): Example >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) >>> transform = transforms.VQT(sample_rate) - >>> vqt = transform(waveform) # (channel, n_bins, time) + >>> vqt = transform(waveform) # (..., n_bins, time) """ __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn", "resampling_method"] @@ -667,24 +667,14 @@ def __init__( super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") - self.n_bins = n_bins - self.bins_per_octave = bins_per_octave - self.f_min = f_min - self.gamma = gamma - self.sample_rate = sample_rate - self.hop_length = hop_length - - # Function that creates new window function with length x - self.window_fn = lambda x: window_fn(x) + n_octaves = math.ceil(n_bins / bins_per_octave) + n_filters = min(bins_per_octave, n_bins) - self.n_octaves = math.ceil(self.n_bins / self.bins_per_octave) - self.n_filters = min(self.bins_per_octave, self.n_bins) + frequencies = self.get_frequencies(bins_per_octave, n_octaves, f_min, n_bins) - self.frequencies = self.get_frequencies() - - self.alpha = self.compute_alpha() - self.wav_lengths, cutoff_freq = self.wavelet_lengths(self.frequencies, self.sample_rate, self.alpha) - nyquist = self.sample_rate / 2 + alpha = self.compute_alpha(frequencies, bins_per_octave, n_bins) + wav_lengths, cutoff_freq = self.wavelet_lengths(frequencies, sample_rate, alpha, gamma) + nyquist = sample_rate / 2 if cutoff_freq > nyquist: raise ValueError( @@ -693,12 +683,12 @@ def __init__( ) # Number of zeros after first 1 in binary gives number of divisions by 2 before number becomes odd - num_hop_downsamples = len(str(bin(self.hop_length)).split('1')[-1]) + num_hop_downsamples = len(str(bin(hop_length)).split('1')[-1]) - if num_hop_downsamples > self.n_octaves: + if num_hop_downsamples > n_octaves: warnings.warn( f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming odd. " - f"The VQT is however being computed for {self.n_octaves} octaves. Consider lowering the hop length or increasing the number of bins for more accurate results." + f"The VQT is however being computed for {n_octaves} octaves. Consider lowering the hop length or increasing the number of bins for more accurate results." ) if nyquist / cutoff_freq > 4: @@ -709,34 +699,66 @@ def __init__( self.resample = Resample(2, 1, resampling_method) - def get_frequencies(self) -> Tensor: + # Pre-compute wavelet filter bases + self.forward_params = [] + temp_sr, temp_hop = sample_rate, hop_length + register_index = 0 + + for oct_index in range(n_octaves - 1, -1, -1): + indices = slice(n_filters * oct_index, n_filters * (oct_index + 1)) + + octave_freqs = frequencies[indices] + octave_alphas = alpha[indices] + + basis, lengths = self.wavelet(octave_freqs, temp_sr, octave_alphas, gamma, window_fn) + n_fft = basis.shape[1] + + factors = lengths.unsqueeze(1) / float(n_fft) + basis *= factors + + fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] + fft_basis[:] *= math.sqrt(sample_rate / temp_sr) + + self.register_buffer(f"fft_basis_{register_index}", fft_basis) + self.forward_params.append((temp_hop, n_fft)) + + register_index += 1 + + if temp_hop % 2 == 0: + temp_sr /= 2. 
+ temp_hop //= 2 + + self.register_buffer("expanded_lengths", wav_lengths.unsqueeze(0).unsqueeze(-1)) + self.ones = lambda x: torch.ones(x, device=self.expanded_lengths.device) + + def get_frequencies(self, bins_per_octave: int, n_octaves: int, f_min: float, n_bins: int) -> Tensor: r"""Return a set of frequencies that assumes an equal temperament tuning system.""" - ratios = 2.0 ** (torch.arange(0, self.bins_per_octave * self.n_octaves, dtype=float) / self.bins_per_octave) - return self.f_min * ratios[:self.n_bins] + ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=float) / bins_per_octave) + return f_min * ratios[:n_bins] - def compute_alpha(self) -> Tensor: + def compute_alpha(self, freqs: Tensor, bins_per_octave: int, n_bins: int) -> Tensor: r"""Compute relative bandwidths for specified frequencies.""" - if self.n_bins > 1: + if n_bins > 1: # Approximate local octave resolution around each frequency - bandpass_octave = torch.empty_like(self.frequencies) - log_frequencies = torch.log2(self.frequencies) + bandpass_octave = torch.empty_like(freqs) + log_freqs = torch.log2(freqs) # Reflect at the lowest and highest frequencies - bandpass_octave[0] = 1 / (log_frequencies[1] - log_frequencies[0]) - bandpass_octave[-1] = 1 / (log_frequencies[-1] - log_frequencies[-2]) + bandpass_octave[0] = 1 / (log_freqs[1] - log_freqs[0]) + bandpass_octave[-1] = 1 / (log_freqs[-1] - log_freqs[-2]) # Centered difference - bandpass_octave[1:-1] = 2 / (log_frequencies[2:] - log_frequencies[:-2]) + bandpass_octave[1:-1] = 2 / (log_freqs[2:] - log_freqs[:-2]) alpha = (2. ** (2 / bandpass_octave) - 1) / (2. ** (2 / bandpass_octave) + 1) else: # Special case when single basis frequency is used - rel_band_coeff = 2. ** (1. / self.bins_per_octave) + rel_band_coeff = 2. ** (1. / bins_per_octave) alpha = torch.atleast_1d((rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)) return alpha - def wavelet_lengths(self, freqs, sr, alpha) -> Tuple[Tensor, float]: + def wavelet_lengths(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: r"""Length of each filter in a wavelet basis. 
Sources: @@ -752,17 +774,17 @@ def wavelet_lengths(self, freqs, sr, alpha) -> Tuple[Tensor, float]: # Output upper bound cutoff frequency # 3.0 > all common window function bandwidths # https://librosa.org/doc/main/_modules/librosa/filters.html - cutoff_freq = max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * self.gamma) + cutoff_freq = max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * gamma) # Convert frequencies to filter lengths - lengths = Q * sr / (freqs + self.gamma / alpha) + lengths = Q * sr / (freqs + gamma / alpha) return lengths, cutoff_freq - def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Optional[Tensor], Tensor]: + def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Optional[Tensor], Tensor]: """Wavelet filterbank constructed from set of center frequencies.""" # First get filter lengths - lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha) + lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha, gamma=gamma) # Next power of 2 pad_to_size = 1<<(int(max(lengths))-1).bit_length() @@ -777,7 +799,7 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor) -> Tuple[Optional[Tenso # Multiply with window sig_len = len(sig) - sig = sig * self.window_fn(sig_len) + sig = sig * window_fn(sig_len) # L1 normalize sig = torch.nn.functional.normalize(sig, p=1., dim=0) @@ -804,37 +826,22 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: Returns: Tensor: VQT spectrogram of size (..., ``n_bins``, time). """ - temp_waveform, temp_sr, temp_hop = waveform, self.sample_rate, self.hop_length vqt = None # Iterate down the octaves - for oct_index in range(self.n_octaves - 1, -1, -1): - indices = slice(self.n_filters * oct_index, self.n_filters * (oct_index + 1)) - - octave_freqs = self.frequencies[indices] - octave_alphas = self.alpha[indices] - - basis, lengths = self.wavelet(octave_freqs, temp_sr, octave_alphas) - n_fft = basis.shape[1] - - # Normalize wrt FFT window length and compute for basis - factors = lengths.unsqueeze(1) / float(n_fft) - basis *= factors - fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] - fft_basis[:] *= math.sqrt(self.sample_rate / temp_sr) - + for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): # STFT matrix dft = torch.stft( - temp_waveform, + waveform, n_fft=n_fft, - window=torch.ones(n_fft), hop_length=temp_hop, + window=self.ones(n_fft), pad_mode='constant', return_complex=True, ) # Compute octave vqt - temp_vqt = torch.einsum('ij,...jk->...ik', fft_basis, dft) + temp_vqt = torch.einsum('ij,...jk->...ik', getattr(self, f"fft_basis_{register_index}"), dft) if vqt is None: vqt = temp_vqt @@ -843,13 +850,10 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: # Resampling if temp_hop % 2 == 0: - temp_waveform = self.resample(temp_waveform) - temp_sr /= 2. - temp_hop //= 2 + waveform = self.resample(waveform) # Scale VQT by square-root of the length of each channel's filter - expanded_lengths = self.wav_lengths.unsqueeze(0).unsqueeze(-1) - vqt /= torch.sqrt(expanded_lengths) + vqt /= torch.sqrt(self.expanded_lengths) return vqt From 09ffe6c3b6e959bf8ec99f4829a9539298a90992 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 12 Jun 2024 16:15:50 +0000 Subject: [PATCH 15/44] CQT implementation. 
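The constant-Q transform is exactly the gamma = 0 special case of the VQT, so the new
class only wraps an identically configured VQT. A usage sketch under that assumption:

    import torch
    from torchaudio import transforms

    waveform = torch.randn(1, 16000)
    cqt = transforms.CQT(sample_rate=16000)(waveform)            # (..., n_bins, time)
    vqt = transforms.VQT(sample_rate=16000, gamma=0.)(waveform)
    assert torch.allclose(cqt, vqt)  # identical by construction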
--- src/torchaudio/transforms/_transforms.py | 88 +++++++++++++++++++++--- 1 file changed, 77 insertions(+), 11 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index c1bb39848c..0d30361460 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -622,7 +622,7 @@ def forward(self, waveform: Tensor) -> Tensor: class VQT(torch.nn.Module): - r"""Create variable Q-transform for a raw audio signal. + r"""Create the variable Q-transform for a raw audio signal. .. devices:: CPU CUDA @@ -701,7 +701,7 @@ def __init__( # Pre-compute wavelet filter bases self.forward_params = [] - temp_sr, temp_hop = sample_rate, hop_length + temp_sr, temp_hop = float(sample_rate), hop_length register_index = 0 for oct_index in range(n_octaves - 1, -1, -1): @@ -758,7 +758,7 @@ def compute_alpha(self, freqs: Tensor, bins_per_octave: int, n_bins: int) -> Ten return alpha - def wavelet_lengths(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: + def wavelet_lengths(self, freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: r"""Length of each filter in a wavelet basis. Sources: @@ -781,7 +781,7 @@ def wavelet_lengths(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float) - return lengths, cutoff_freq - def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Optional[Tensor], Tensor]: + def wavelet(self, freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Tensor, Tensor]: """Wavelet filterbank constructed from set of center frequencies.""" # First get filter lengths lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha, gamma=gamma) @@ -789,9 +789,9 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float, window_fn # Next power of 2 pad_to_size = 1<<(int(max(lengths))-1).bit_length() - filters = None + filters: Tensor - for ilen, freq in zip(lengths, freqs): + for index, (ilen, freq) in enumerate(zip(lengths, freqs)): # Build filter with length ceil(ilen) # Use float32 in order to output complex(float) numbers later t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * torch.pi * freq / sr @@ -810,15 +810,14 @@ def wavelet(self, freqs: Tensor, sr: int, alpha: Tensor, gamma: float, window_fn sig = torch.nn.functional.pad(sig, (l_pad, r_pad), mode='constant', value=0.) sig = sig.unsqueeze(0) - if filters is None: + if index == 0: filters = sig - else: filters = torch.cat([filters, sig], dim=0) return filters, lengths - def forward(self, waveform: Tensor) -> Optional[Tensor]: + def forward(self, waveform: Tensor) -> Tensor: r""" Args: waveform (Tensor): Tensor of audio of dimension (..., time). @@ -826,7 +825,8 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: Returns: Tensor: VQT spectrogram of size (..., ``n_bins``, time). 
""" - vqt = None + # Mypy type + vqt: torch.Tensor # Iterate down the octaves for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): @@ -843,7 +843,7 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: # Compute octave vqt temp_vqt = torch.einsum('ij,...jk->...ik', getattr(self, f"fft_basis_{register_index}"), dft) - if vqt is None: + if register_index == 0: vqt = temp_vqt else: vqt = torch.cat([temp_vqt, vqt], dim=-2) @@ -858,6 +858,72 @@ def forward(self, waveform: Tensor) -> Optional[Tensor]: return vqt +class CQT(torch.nn.Module): + r"""Create the constant Q-transform for a raw audio signal. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Sources + * https://librosa.org/doc/main/_modules/librosa/core/constantq.html + * https://www.aes.org/e-lib/online/browse.cfm?elib=17112 + * https://newt.phys.unsw.edu.au/jw/notes.html + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + hop_length (int, optional): Length of hop between CQT windows. (Default: ``400``) + f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz) + n_bins (int, optional): Number of CQT frequency bins, starting at ``f_min``. (Default: ``84``) + bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + resampling_method (str, optional): The resampling method to use. + Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> transform = transforms.CQT(sample_rate) + >>> cqt = transform(waveform) # (..., n_bins, time) + """ + __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "bins_per_octave", "window_fn", "resampling_method"] + + def __init__( + self, + sample_rate: int = 16000, + hop_length: int = 400, + f_min: float = 32.703, + n_bins: int = 84, + bins_per_octave: int = 12, + window_fn: Callable[..., Tensor] = torch.hann_window, + resampling_method: str = "sinc_interp_hann", + ) -> None: + super(CQT, self).__init__() + torch._C._log_api_usage_once("torchaudio.transforms.CQT") + + # CQT corresponds to a VQT with gamma set to 0 + self.transform = VQT( + sample_rate=sample_rate, + hop_length=hop_length, + f_min=f_min, + n_bins=n_bins, + gamma=0., + bins_per_octave=bins_per_octave, + window_fn=window_fn, + resampling_method=resampling_method, + ) + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: CQT spectrogram of size (..., ``n_bins``, time). + """ + return self.transform(waveform) + + class MFCC(torch.nn.Module): r"""Create the Mel-frequency cepstrum coefficients from an audio signal. From 1529f0a89d53f0805584776cf23a873ff7459d19 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 12 Jun 2024 23:51:44 +0000 Subject: [PATCH 16/44] Splitting functions from classes to be used by iCQT. 
--- src/torchaudio/functional/__init__.py | 8 ++ src/torchaudio/functional/functional.py | 142 ++++++++++++++++++++++- src/torchaudio/transforms/__init__.py | 4 + src/torchaudio/transforms/_transforms.py | 128 ++++---------------- 4 files changed, 173 insertions(+), 109 deletions(-) diff --git a/src/torchaudio/functional/__init__.py b/src/torchaudio/functional/__init__.py index b866977c67..6c37ed4e41 100644 --- a/src/torchaudio/functional/__init__.py +++ b/src/torchaudio/functional/__init__.py @@ -37,6 +37,7 @@ edit_distance, fftconvolve, frechet_distance, + frequency_set, griffinlim, inverse_spectrogram, linear_fbanks, @@ -52,6 +53,7 @@ pitch_shift, preemphasis, psd, + relative_bandwidths, resample, rnnt_loss, rtf_evd, @@ -60,6 +62,8 @@ spectral_centroid, spectrogram, speed, + wavelet_fbank, + wavelet_lengths, ) __all__ = [ @@ -124,4 +128,8 @@ "preemphasis", "deemphasis", "frechet_distance", + "frequency_set", + "relative_bandwidths", + "wavelet_lengths", + "wavelet_fbank", ] diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index af34e707e5..772e20496c 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -4,7 +4,7 @@ import tempfile import warnings from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch import torchaudio @@ -51,6 +51,10 @@ "speed", "preemphasis", "deemphasis", + "frequency_set", + "relative_bandwidths", + "wavelet_lengths", + "wavelet_fbank", ] @@ -2533,3 +2537,139 @@ def frechet_distance(mu_x, sigma_x, mu_y, sigma_y): b = sigma_x.trace() + sigma_y.trace() c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum() return a + b - 2 * c + + +def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tensor, int]: + r"""Return a set of frequencies that assumes an equal temperament tuning system. + + Adapted from librosa: https://librosa.org/doc/main/generated/librosa.interval_frequencies.html + + Args: + f_min (float): minimum frequency in Hz. + n_bins (int): number of frequency bins. + bins_per_octave (int): number of bins per octave. + + Returns: + torch.Tensor: frequencies. + int: number of octaves + """ + n_octaves = math.ceil(n_bins / bins_per_octave) + ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=float) / bins_per_octave) + return f_min * ratios[:n_bins], n_octaves + + +def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Tensor: + r"""Compute relative bandwidths for specified frequencies. + + Adapted from librosa: https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html + + Args: + freqs (Tensor): set of frequencies. + n_bins (int): number of frequency bins. + bins_per_octave (int): number of bins per octave. + + Returns: + torch.Tensor: relative bandwidths for set of frequencies. + """ + if n_bins > 1: + # Approximate local octave resolution around each frequency + bandpass_octave = torch.empty_like(freqs) + log_freqs = torch.log2(freqs) + + # Reflect at the lowest and highest frequencies + bandpass_octave[0] = 1 / (log_freqs[1] - log_freqs[0]) + bandpass_octave[-1] = 1 / (log_freqs[-1] - log_freqs[-2]) + + # Centered difference + bandpass_octave[1:-1] = 2 / (log_freqs[2:] - log_freqs[:-2]) + + # Relative bandwidths + alpha = (2. ** (2 / bandpass_octave) - 1) / (2. ** (2 / bandpass_octave) + 1) + else: + # Special case when single basis frequency is used + rel_band_coeff = 2. ** (1. 
/ bins_per_octave) + alpha = torch.atleast_1d((rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)) + + return alpha + + +def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: + r"""Length of each filter in a wavelet basis. + + Source: + * https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html + + Args: + freqs (Tensor): set of frequencies. + sr (float): sample rate. + alpha (Tensor): relative bandwidths for set of frequencies. + gamma (float): bandwidth offset for filter length computation. + + Returns: + Tensor: filter lengths. + float: cutoff frequency of highest bin. + """ + # We assume filter_scale (librosa param) is 1 + Q = 1. / alpha + + # Output upper bound cutoff frequency + # 3.0 > all common window function bandwidths + # https://librosa.org/doc/main/_modules/librosa/filters.html + cutoff_freq = max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * gamma) + + # Convert frequencies to filter lengths + lengths = Q * sr / (freqs + gamma / alpha) + + return lengths, cutoff_freq + + +def wavelet_fbank(freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Tensor, Tensor]: + r"""Wavelet filterbank constructed from set of center frequencies. + + Source: + * https://librosa.org/doc/main/generated/librosa.filters.wavelet.html + + Args: + freqs (Tensor): set of frequencies. + sr (float): sample rate. + alpha (Tensor): relative bandwidths for set of frequencies. + gamma (float): bandwidth offset for filter length computation. + window_fn (Callable[..., Tensor]): a function to create a window tensor. + + Returns: + Tensor: wavelet filters. + Tensor: wavelet filter lengths. + """ + # First get filter lengths + lengths, _ = wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha, gamma=gamma) + + # Next power of 2 + pad_to_size = 1<<(int(max(lengths))-1).bit_length() + + filters: Tensor + + for index, (ilen, freq) in enumerate(zip(lengths, freqs)): + # Build filter with length ceil(ilen) + # Use float32 in order to output complex(float) numbers later + t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * torch.pi * freq / sr + sig = torch.cos(t) + 1j * torch.sin(t) + + # Multiply with window + sig_len = len(sig) + sig = sig * window_fn(sig_len) + + # L1 normalize + sig = torch.nn.functional.normalize(sig, p=1., dim=0) + + # Pad signal left and right to correct size + l_pad = math.floor((pad_to_size - sig_len) / 2) + r_pad = math.ceil((pad_to_size - sig_len) / 2) + sig = torch.nn.functional.pad(sig, (l_pad, r_pad), mode='constant', value=0.) + sig = sig.unsqueeze(0) + + if index == 0: + filters = sig + else: + filters = torch.cat([filters, sig], dim=0) + + return filters, lengths diff --git a/src/torchaudio/transforms/__init__.py b/src/torchaudio/transforms/__init__.py index 1fe77865a9..af47595fb3 100644 --- a/src/torchaudio/transforms/__init__.py +++ b/src/torchaudio/transforms/__init__.py @@ -4,6 +4,7 @@ AmplitudeToDB, ComputeDeltas, Convolve, + CQT, Deemphasis, Fade, FFTConvolve, @@ -32,6 +33,7 @@ TimeStretch, Vad, Vol, + VQT, ) @@ -40,6 +42,7 @@ "AmplitudeToDB", "ComputeDeltas", "Convolve", + "CQT", "Deemphasis", "Fade", "FFTConvolve", @@ -72,4 +75,5 @@ "TimeStretch", "Vad", "Vol", + "VQT", ] diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 0d30361460..9bfc714555 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -629,7 +629,7 @@ class VQT(torch.nn.Module): .. 
properties:: Autograd TorchScript Sources - * https://librosa.org/doc/main/_modules/librosa/core/constantq.html + * https://librosa.org/doc/main/generated/librosa.vqt.html * https://www.aes.org/e-lib/online/browse.cfm?elib=17112 * https://newt.phys.unsw.edu.au/jw/notes.html @@ -667,13 +667,18 @@ def __init__( super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") - n_octaves = math.ceil(n_bins / bins_per_octave) n_filters = min(bins_per_octave, n_bins) + frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave) + alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave) + freq_lengths, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, gamma) - frequencies = self.get_frequencies(bins_per_octave, n_octaves, f_min, n_bins) + self.resample = Resample(2, 1, resampling_method) + self.register_buffer("expanded_lengths", freq_lengths.unsqueeze(0).unsqueeze(-1)) + self.ones = lambda x: torch.ones(x, device=self.expanded_lengths.device) - alpha = self.compute_alpha(frequencies, bins_per_octave, n_bins) - wav_lengths, cutoff_freq = self.wavelet_lengths(frequencies, sample_rate, alpha, gamma) + # Generate errors or warnings if needed + # Number of divisions by 2 before number becomes odd + num_hop_downsamples = len(str(bin(hop_length)).split('1')[-1]) nyquist = sample_rate / 2 if cutoff_freq > nyquist: @@ -681,41 +686,38 @@ def __init__( f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the Nyquist frequency {nyquist}. " "Try to reduce the number of frequency bins." ) - - # Number of zeros after first 1 in binary gives number of divisions by 2 before number becomes odd - num_hop_downsamples = len(str(bin(hop_length)).split('1')[-1]) - if num_hop_downsamples > n_octaves: warnings.warn( f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming odd. " f"The VQT is however being computed for {n_octaves} octaves. Consider lowering the hop length or increasing the number of bins for more accurate results." ) - if nyquist / cutoff_freq > 4: warnings.warn( f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's cutoff frequency {cutoff_freq}. " "Consider resampling your signal to a lower sample rate or increasing the number of bins before VQT computation for more accurate results." ) - self.resample = Resample(2, 1, resampling_method) - - # Pre-compute wavelet filter bases + # Now pre-compute what's needed for forward loop self.forward_params = [] temp_sr, temp_hop = float(sample_rate), hop_length register_index = 0 for oct_index in range(n_octaves - 1, -1, -1): + # Slice out correct octave indices = slice(n_filters * oct_index, n_filters * (oct_index + 1)) octave_freqs = frequencies[indices] octave_alphas = alpha[indices] - basis, lengths = self.wavelet(octave_freqs, temp_sr, octave_alphas, gamma, window_fn) + # Compute wavelet filterbanks + basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, gamma, window_fn) n_fft = basis.shape[1] + # Normalize wrt FFT window length factors = lengths.unsqueeze(1) / float(n_fft) basis *= factors + # Wavelet basis FFT fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] fft_basis[:] *= math.sqrt(sample_rate / temp_sr) @@ -727,95 +729,6 @@ def __init__( if temp_hop % 2 == 0: temp_sr /= 2. 
temp_hop //= 2 - - self.register_buffer("expanded_lengths", wav_lengths.unsqueeze(0).unsqueeze(-1)) - self.ones = lambda x: torch.ones(x, device=self.expanded_lengths.device) - - def get_frequencies(self, bins_per_octave: int, n_octaves: int, f_min: float, n_bins: int) -> Tensor: - r"""Return a set of frequencies that assumes an equal temperament tuning system.""" - ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=float) / bins_per_octave) - return f_min * ratios[:n_bins] - - def compute_alpha(self, freqs: Tensor, bins_per_octave: int, n_bins: int) -> Tensor: - r"""Compute relative bandwidths for specified frequencies.""" - if n_bins > 1: - # Approximate local octave resolution around each frequency - bandpass_octave = torch.empty_like(freqs) - log_freqs = torch.log2(freqs) - - # Reflect at the lowest and highest frequencies - bandpass_octave[0] = 1 / (log_freqs[1] - log_freqs[0]) - bandpass_octave[-1] = 1 / (log_freqs[-1] - log_freqs[-2]) - - # Centered difference - bandpass_octave[1:-1] = 2 / (log_freqs[2:] - log_freqs[:-2]) - - alpha = (2. ** (2 / bandpass_octave) - 1) / (2. ** (2 / bandpass_octave) + 1) - else: - # Special case when single basis frequency is used - rel_band_coeff = 2. ** (1. / bins_per_octave) - alpha = torch.atleast_1d((rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)) - - return alpha - - def wavelet_lengths(self, freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: - r"""Length of each filter in a wavelet basis. - - Sources: - * https://librosa.org/doc/main/_modules/librosa/filters.html - - Returns: - Tensor: filter lengths. - float: cutoff frequency of highest bin. - """ - # We assume filter_scale (librosa param) is 1 - Q = 1. / alpha - - # Output upper bound cutoff frequency - # 3.0 > all common window function bandwidths - # https://librosa.org/doc/main/_modules/librosa/filters.html - cutoff_freq = max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * gamma) - - # Convert frequencies to filter lengths - lengths = Q * sr / (freqs + gamma / alpha) - - return lengths, cutoff_freq - - def wavelet(self, freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Tensor, Tensor]: - """Wavelet filterbank constructed from set of center frequencies.""" - # First get filter lengths - lengths, _ = self.wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha, gamma=gamma) - - # Next power of 2 - pad_to_size = 1<<(int(max(lengths))-1).bit_length() - - filters: Tensor - - for index, (ilen, freq) in enumerate(zip(lengths, freqs)): - # Build filter with length ceil(ilen) - # Use float32 in order to output complex(float) numbers later - t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * torch.pi * freq / sr - sig = torch.cos(t) + 1j * torch.sin(t) - - # Multiply with window - sig_len = len(sig) - sig = sig * window_fn(sig_len) - - # L1 normalize - sig = torch.nn.functional.normalize(sig, p=1., dim=0) - - # Pad signal left and right to correct size - l_pad = math.floor((pad_to_size - sig_len) / 2) - r_pad = math.ceil((pad_to_size - sig_len) / 2) - sig = torch.nn.functional.pad(sig, (l_pad, r_pad), mode='constant', value=0.) - sig = sig.unsqueeze(0) - - if index == 0: - filters = sig - else: - filters = torch.cat([filters, sig], dim=0) - - return filters, lengths def forward(self, waveform: Tensor) -> Tensor: r""" @@ -823,10 +736,9 @@ def forward(self, waveform: Tensor) -> Tensor: waveform (Tensor): Tensor of audio of dimension (..., time). 
Returns: - Tensor: VQT spectrogram of size (..., ``n_bins``, time). + Tensor: variable-Q transform of size (..., ``n_bins``, time). """ - # Mypy type - vqt: torch.Tensor + vqt: Tensor # Iterate down the octaves for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): @@ -866,7 +778,7 @@ class CQT(torch.nn.Module): .. properties:: Autograd TorchScript Sources - * https://librosa.org/doc/main/_modules/librosa/core/constantq.html + * https://librosa.org/doc/main/generated/librosa.cqt.html * https://www.aes.org/e-lib/online/browse.cfm?elib=17112 * https://newt.phys.unsw.edu.au/jw/notes.html @@ -919,7 +831,7 @@ def forward(self, waveform: Tensor) -> Tensor: waveform (Tensor): Tensor of audio of dimension (..., time). Returns: - Tensor: CQT spectrogram of size (..., ``n_bins``, time). + Tensor: constant-Q transform spectrogram of size (..., ``n_bins``, time). """ return self.transform(waveform) From b76ad5734d51a37f30224eb411f5f9a62d93fc25 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sat, 15 Jun 2024 16:18:43 +0000 Subject: [PATCH 17/44] iCQT outline and VQT batch computation. --- src/torchaudio/transforms/_transforms.py | 95 +++++++++++++++++++++--- 1 file changed, 84 insertions(+), 11 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 9bfc714555..39fb9b5f97 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -733,24 +733,46 @@ def __init__( def forward(self, waveform: Tensor) -> Tensor: r""" Args: - waveform (Tensor): Tensor of audio of dimension (..., time). + waveform (Tensor): Tensor of audio of dimension (..., channels, time). + 2D or 3D; batch dimension is optional. Returns: - Tensor: variable-Q transform of size (..., ``n_bins``, time). + Tensor: variable-Q transform of size (..., channels, ``n_bins``, time). """ vqt: Tensor # Iterate down the octaves - for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): + for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): # STFT matrix - dft = torch.stft( - waveform, - n_fft=n_fft, - hop_length=temp_hop, - window=self.ones(n_fft), - pad_mode='constant', - return_complex=True, - ) + if waveform.ndim == 3: + dft: Tensor + + # torch stft does not support 3D computation yet + # iterate through channels for stft computation + for channel in range(waveform.shape[1]): + channel_dft = torch.stft( + waveform[:, channel, :], + n_fft=n_fft, + hop_length=temp_hop, + window=self.ones(n_fft), + pad_mode='constant', + return_complex=True, + ) + + if channel == 0: + dft = channel_dft.unsqueeze(1) + else: + dft = torch.cat([dft, channel_dft.unsqueeze(1)], dim=1) + + else: + dft = torch.stft( + waveform, + n_fft=n_fft, + hop_length=temp_hop, + window=self.ones(n_fft), + pad_mode='constant', + return_complex=True, + ) # Compute octave vqt temp_vqt = torch.einsum('ij,...jk->...ik', getattr(self, f"fft_basis_{register_index}"), dft) @@ -834,6 +856,57 @@ def forward(self, waveform: Tensor) -> Tensor: Tensor: constant-Q transform spectrogram of size (..., ``n_bins``, time). """ return self.transform(waveform) + + +class InverseCQT(torch.nn.Module): + r"""Compute the inverse constant Q-transform. + + .. devices:: CPU CUDA + + .. 
properties:: Autograd TorchScript + + Sources + * https://librosa.org/doc/main/generated/librosa.icqt.html + * https://www.aes.org/e-lib/online/browse.cfm?elib=17112 + * https://newt.phys.unsw.edu.au/jw/notes.html + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``) + f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz) + bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + resampling_method (str, optional): The resampling method to use. + Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``) + + Example + >>> transform = transforms.InverseCQT() + >>> waveform = transform(cqt) # (..., time) + """ + __constants__ = ["sample_rate", "hop_length", "f_min", "bins_per_octave", "window_fn", "resampling_method"] + + def __init__( + self, + sample_rate: int = 16000, + hop_length: int = 400, + f_min: float = 32.703, + bins_per_octave: int = 12, + window_fn: Callable[..., Tensor] = torch.hann_window, + resampling_method: str = "sinc_interp_hann", + ) -> None: + super(InverseCQT, self).__init__() + torch._C._log_api_usage_once("torchaudio.transforms.InverseCQT") + + def forward(self, cqt: Tensor) -> Tensor: + r""" + Args: + cqt (Tensor): Constant-q trasnform tensor of dimension (..., ``n_bins``, time). + + Returns: + Tensor: waveform of size (..., time). + """ + pass class MFCC(torch.nn.Module): From 78a51698ca8fced8908aa17b90ad692278aba92b Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 08:10:35 +0000 Subject: [PATCH 18/44] iCQT algorithm start and outline. --- src/torchaudio/transforms/_transforms.py | 44 ++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 39fb9b5f97..1c618aeae7 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -874,6 +874,7 @@ class InverseCQT(torch.nn.Module): sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``) f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz) + n_bins (int, optional): Number of CQT frequency bins, starting at ``f_min``. (Default: ``84``) bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``) window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. 
(Default: ``torch.hann_window``) @@ -891,12 +892,55 @@ def __init__( sample_rate: int = 16000, hop_length: int = 400, f_min: float = 32.703, + n_bins: int = 84, bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, resampling_method: str = "sinc_interp_hann", ) -> None: super(InverseCQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.InverseCQT") + + n_filters = min(bins_per_octave, n_bins) + frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave) + alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave) + freq_lengths, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, 0.) + cqt_scale = torch.sqrt(freq_lengths) + + self.sample_rates = [] + self.hop_lengths = [] + temp_sr, temp_hop = float(sample_rate), hop_length + + for _ in range(n_octaves - 1, -1, -1): + self.sample_rates.append(temp_sr) + self.hop_lengths.append(temp_hop) + + if temp_hop % 2 == 0: + temp_sr /= 2. + temp_hop //= 2 + + self.sample_rates.reverse() + self.hop_lengths.reverse() + + for oct_index, (temp_sr, temp_hop) in enumerate(zip(self.sample_rates, self.hop_lengths)): + # Slice out correct octave + indices = slice(n_filters * oct_index, n_filters * (oct_index + 1)) + + octave_freqs = frequencies[indices] + octave_alphas = alpha[indices] + + # Compute wavelet filterbanks + basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, 0., window_fn) + n_fft = basis.shape[1] + + # Normalize wrt FFT window length + factors = lengths.unsqueeze(1) / float(n_fft) + basis *= factors + + # Wavelet basis FFT + fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1] + + # Transpose basis + basis_inverse = fft_basis.H def forward(self, cqt: Tensor) -> Tensor: r""" From e06ac1df0e9f03d24923ff1d6b1c0c992b47fe2f Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 09:13:32 +0000 Subject: [PATCH 19/44] Pre-computations done :) --- src/torchaudio/transforms/_transforms.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 1c618aeae7..242d1ce537 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -903,8 +903,11 @@ def __init__( n_filters = min(bins_per_octave, n_bins) frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave) alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave) - freq_lengths, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, 0.) - cqt_scale = torch.sqrt(freq_lengths) + freq_lengths, _ = F.wavelet_lengths(frequencies, sample_rate, alpha, 0.) + + self.resampling_method = resampling_method + self.register_buffer("c_scale", torch.sqrt(freq_lengths)) + self.ones = lambda x: torch.ones(x, device=self.c_scale.device) self.sample_rates = [] self.hop_lengths = [] @@ -941,6 +944,12 @@ def __init__( # Transpose basis basis_inverse = fft_basis.H + squared_mag = torch.abs(basis_inverse)**2 + frequency_pow = 1 / squared_mag.sum(dim=0) + frequency_pow *= n_fft / freq_lengths[indices] + + self.register_buffer(f"basis_inverse_{oct_index}", basis_inverse) + self.register_buffer(f"frequency_pow_{oct_index}", frequency_pow) def forward(self, cqt: Tensor) -> Tensor: r""" From 878eb26e6ffc0e116c22b70b366036eeeca24573 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 12:39:25 +0000 Subject: [PATCH 20/44] Make frequencies float32 to avoid icqt einsum issues. 
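Python's built-in `float` passed as a torch dtype silently selects float64, so the
frequency set came out float64 and the downstream wavelet bases complex128, while
the STFT of a float32 waveform is complex64. A rough sketch of the clash this
commit avoids (shapes are made up for illustration; depending on the PyTorch
version, mixing the two either raises or forces an unwanted up-cast):

    import torch

    basis = torch.randn(12, 129, dtype=torch.cdouble)  # basis built from float64 frequencies
    waveform = torch.randn(1, 4000)                    # float32 audio
    dft = torch.stft(
        waveform,
        n_fft=256,
        hop_length=100,
        window=torch.ones(256),
        return_complex=True,
    )                                                  # complex64 STFT
    try:
        torch.einsum('ij,...jk->...ik', basis, dft)    # same contraction as VQT.forward
    except RuntimeError as err:
        print(err)                                     # dtype mismatch between operands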
--- src/torchaudio/functional/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 772e20496c..98565e927c 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -2554,7 +2554,7 @@ def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tens int: number of octaves """ n_octaves = math.ceil(n_bins / bins_per_octave) - ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=float) / bins_per_octave) + ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=torch.float32) / bins_per_octave) return f_min * ratios[:n_bins], n_octaves From 53334af234589245c49aaab798f572e1da18da94 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 12:40:03 +0000 Subject: [PATCH 21/44] Basis projection. --- src/torchaudio/transforms/_transforms.py | 36 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 242d1ce537..49bf28c76c 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -742,7 +742,7 @@ def forward(self, waveform: Tensor) -> Tensor: vqt: Tensor # Iterate down the octaves - for register_index, (temp_hop, n_fft) in enumerate(self.forward_params): + for buffer_index, (temp_hop, n_fft) in enumerate(self.forward_params): # STFT matrix if waveform.ndim == 3: dft: Tensor @@ -775,9 +775,9 @@ def forward(self, waveform: Tensor) -> Tensor: ) # Compute octave vqt - temp_vqt = torch.einsum('ij,...jk->...ik', getattr(self, f"fft_basis_{register_index}"), dft) + temp_vqt = torch.einsum('ij,...jk->...ik', getattr(self, f"fft_basis_{buffer_index}"), dft) - if register_index == 0: + if buffer_index == 0: vqt = temp_vqt else: vqt = torch.cat([temp_vqt, vqt], dim=-2) @@ -909,22 +909,23 @@ def __init__( self.register_buffer("c_scale", torch.sqrt(freq_lengths)) self.ones = lambda x: torch.ones(x, device=self.c_scale.device) - self.sample_rates = [] - self.hop_lengths = [] + sample_rates = [] + hop_lengths = [] temp_sr, temp_hop = float(sample_rate), hop_length for _ in range(n_octaves - 1, -1, -1): - self.sample_rates.append(temp_sr) - self.hop_lengths.append(temp_hop) + sample_rates.append(temp_sr) + hop_lengths.append(temp_hop) if temp_hop % 2 == 0: temp_sr /= 2. temp_hop //= 2 - self.sample_rates.reverse() - self.hop_lengths.reverse() + sample_rates.reverse() + hop_lengths.reverse() + self.forward_params = [] - for oct_index, (temp_sr, temp_hop) in enumerate(zip(self.sample_rates, self.hop_lengths)): + for oct_index, (temp_sr, temp_hop) in enumerate(zip(sample_rates, hop_lengths)): # Slice out correct octave indices = slice(n_filters * oct_index, n_filters * (oct_index + 1)) @@ -950,6 +951,7 @@ def __init__( self.register_buffer(f"basis_inverse_{oct_index}", basis_inverse) self.register_buffer(f"frequency_pow_{oct_index}", frequency_pow) + self.forward_params.append((temp_sr, temp_hop, indices)) def forward(self, cqt: Tensor) -> Tensor: r""" @@ -959,7 +961,19 @@ def forward(self, cqt: Tensor) -> Tensor: Returns: Tensor: waveform of size (..., time). 
""" - pass + waveform: Tensor + + # Iterate down the octaves + for buffer_index, (sr, hop, indices) in enumerate(self.forward_params): + temp_proj = torch.einsum( + 'fc,c,c,...ct->...ft', + getattr(self, f"basis_inverse_{buffer_index}"), + self.c_scale[indices], + getattr(self, f"frequency_pow_{buffer_index}"), + cqt[..., indices, :], + ) + + return waveform class MFCC(torch.nn.Module): From da65ec3e38166942bff3eb5fe647e4c62aea15d5 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 13:41:01 +0000 Subject: [PATCH 22/44] iCQT for 2D tensors :) --- src/torchaudio/transforms/_transforms.py | 25 +++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 49bf28c76c..81a72ae15e 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -900,10 +900,11 @@ def __init__( super(InverseCQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.InverseCQT") + self.sample_rate = sample_rate n_filters = min(bins_per_octave, n_bins) frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave) alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave) - freq_lengths, _ = F.wavelet_lengths(frequencies, sample_rate, alpha, 0.) + freq_lengths, _ = F.wavelet_lengths(frequencies, self.sample_rate, alpha, 0.) self.resampling_method = resampling_method self.register_buffer("c_scale", torch.sqrt(freq_lengths)) @@ -911,7 +912,7 @@ def __init__( sample_rates = [] hop_lengths = [] - temp_sr, temp_hop = float(sample_rate), hop_length + temp_sr, temp_hop = float(self.sample_rate), hop_length for _ in range(n_octaves - 1, -1, -1): sample_rates.append(temp_sr) @@ -964,7 +965,7 @@ def forward(self, cqt: Tensor) -> Tensor: waveform: Tensor # Iterate down the octaves - for buffer_index, (sr, hop, indices) in enumerate(self.forward_params): + for buffer_index, (temp_sr, temp_hop, indices) in enumerate(self.forward_params): temp_proj = torch.einsum( 'fc,c,c,...ct->...ft', getattr(self, f"basis_inverse_{buffer_index}"), @@ -972,6 +973,24 @@ def forward(self, cqt: Tensor) -> Tensor: getattr(self, f"frequency_pow_{buffer_index}"), cqt[..., indices, :], ) + n_fft = 2 * (temp_proj.shape[-2] - 1) + temp_waveform = torch.istft( + temp_proj, + n_fft=n_fft, + hop_length=temp_hop, + window=self.ones(n_fft), + ) + temp_waveform = F.resample( + temp_waveform, + orig_freq=1, + new_freq=self.sample_rate//temp_sr, + resampling_method=self.resampling_method, + ) + + if buffer_index == 0: + waveform = temp_waveform + else: + waveform[..., :temp_waveform.shape[-1]] += temp_waveform return waveform From d146d1ac0e4000b9da538552d53d245a6f14fe9b Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 20:17:38 +0000 Subject: [PATCH 23/44] Comments on the iCQT and a few other spots. --- src/torchaudio/transforms/_transforms.py | 53 +++++++++++++++++++----- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 81a72ae15e..b372a4654b 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -850,10 +850,11 @@ def __init__( def forward(self, waveform: Tensor) -> Tensor: r""" Args: - waveform (Tensor): Tensor of audio of dimension (..., time). + waveform (Tensor): Tensor of audio of dimension (..., channels, time). + 2D or 3D; batch dimension is optional. 
Returns: - Tensor: constant-Q transform spectrogram of size (..., ``n_bins``, time). + Tensor: constant-Q transform spectrogram of size (..., channels, ``n_bins``, time). """ return self.transform(waveform) @@ -910,6 +911,7 @@ def __init__( self.register_buffer("c_scale", torch.sqrt(freq_lengths)) self.ones = lambda x: torch.ones(x, device=self.c_scale.device) + # Get sample rates and hop lengths used during CQT downsampling sample_rates = [] hop_lengths = [] temp_sr, temp_hop = float(self.sample_rate), hop_length @@ -924,6 +926,8 @@ def __init__( sample_rates.reverse() hop_lengths.reverse() + + # Now pre-compute what's needed for forward loop self.forward_params = [] for oct_index, (temp_sr, temp_hop) in enumerate(zip(sample_rates, hop_lengths)): @@ -946,8 +950,12 @@ def __init__( # Transpose basis basis_inverse = fft_basis.H + + # Compute filter power spectrum squared_mag = torch.abs(basis_inverse)**2 frequency_pow = 1 / squared_mag.sum(dim=0) + + # Adjust by normalizing with lengths frequency_pow *= n_fft / freq_lengths[indices] self.register_buffer(f"basis_inverse_{oct_index}", basis_inverse) @@ -957,15 +965,17 @@ def __init__( def forward(self, cqt: Tensor) -> Tensor: r""" Args: - cqt (Tensor): Constant-q trasnform tensor of dimension (..., ``n_bins``, time). + cqt (Tensor): Constant-q transform tensor of dimension (..., channels, ``n_bins``, time). + 3D or 4D; batch dimension is optional. Returns: - Tensor: waveform of size (..., time). + Tensor: waveform of size (..., channels, time). """ waveform: Tensor # Iterate down the octaves for buffer_index, (temp_sr, temp_hop, indices) in enumerate(self.forward_params): + # Inverse project the basis temp_proj = torch.einsum( 'fc,c,c,...ct->...ft', getattr(self, f"basis_inverse_{buffer_index}"), @@ -973,13 +983,36 @@ def forward(self, cqt: Tensor) -> Tensor: getattr(self, f"frequency_pow_{buffer_index}"), cqt[..., indices, :], ) + # Taken from librosa n_fft = 2 * (temp_proj.shape[-2] - 1) - temp_waveform = torch.istft( - temp_proj, - n_fft=n_fft, - hop_length=temp_hop, - window=self.ones(n_fft), - ) + + if temp_proj.ndim == 4: + temp_waveform: Tensor + + # torch istft does not support 4D computation yet + # iterate through channels for stft computation + for channel in range(temp_proj.shape[1]): + channel_waveform = torch.istft( + temp_proj[:, channel, :, :], + n_fft=n_fft, + hop_length=temp_hop, + window=self.ones(n_fft), + ) + + if channel == 0: + temp_waveform = channel_waveform.unsqueeze(1) + else: + temp_waveform = torch.cat([temp_waveform, channel_waveform.unsqueeze(1)], dim=1) + + else: + temp_waveform = torch.istft( + temp_proj, + n_fft=n_fft, + hop_length=temp_hop, + window=self.ones(n_fft), + ) + + # Resample to desired output shape temp_waveform = F.resample( temp_waveform, orig_freq=1, From 170e9faedacb9fc09c0be5d74264192f8d434170 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Sun, 16 Jun 2024 20:24:11 +0000 Subject: [PATCH 24/44] Proper iCQT import. 
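With the export in place, the transform is importable from `torchaudio.transforms`
like the others. A round-trip sketch of the intended usage (signal length is
arbitrary and the reconstruction is only an approximate inverse):

    import torch
    import torchaudio.transforms as T

    waveform = torch.randn(1, 16000)              # (channel, time)
    cqt = T.CQT(sample_rate=16000)(waveform)      # (channel, n_bins, time)
    recon = T.InverseCQT(sample_rate=16000)(cqt)  # (channel, time)
    print(cqt.shape, recon.shape)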
--- src/torchaudio/transforms/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchaudio/transforms/__init__.py b/src/torchaudio/transforms/__init__.py index af47595fb3..6417ec0fb7 100644 --- a/src/torchaudio/transforms/__init__.py +++ b/src/torchaudio/transforms/__init__.py @@ -10,6 +10,7 @@ FFTConvolve, FrequencyMasking, GriffinLim, + InverseCQT, InverseMelScale, InverseSpectrogram, LFCC, @@ -48,6 +49,7 @@ "FFTConvolve", "FrequencyMasking", "GriffinLim", + "InverseCQT", "InverseMelScale", "InverseSpectrogram", "LFCC", From fa3298f5413e35649a8736ca05061c99b0960d4f Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 19 Jun 2024 17:18:05 +0000 Subject: [PATCH 25/44] Code cleanup in functional. --- src/torchaudio/functional/functional.py | 26 +++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 98565e927c..61cf49c2bb 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -2542,6 +2542,8 @@ def frechet_distance(mu_x, sigma_x, mu_y, sigma_y): def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tensor, int]: r"""Return a set of frequencies that assumes an equal temperament tuning system. + .. devices:: CPU + Adapted from librosa: https://librosa.org/doc/main/generated/librosa.interval_frequencies.html Args: @@ -2553,6 +2555,9 @@ def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tens torch.Tensor: frequencies. int: number of octaves """ + if f_min < 0. or n_bins < 1 or bins_per_octave < 1: + raise ValueError("f_min must be positive. n_bins and bins_per_octave must be ints and superior to 1.") + n_octaves = math.ceil(n_bins / bins_per_octave) ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=torch.float32) / bins_per_octave) return f_min * ratios[:n_bins], n_octaves @@ -2561,6 +2566,8 @@ def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tens def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Tensor: r"""Compute relative bandwidths for specified frequencies. + .. devices:: CPU + Adapted from librosa: https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html Args: @@ -2571,6 +2578,9 @@ def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Ten Returns: torch.Tensor: relative bandwidths for set of frequencies. """ + if min(freqs) < 0. or n_bins < 1 or bins_per_octave < 1: + raise ValueError("freqs must be positive. n_bins and bins_per_octave must be ints and superior to 1.") + if n_bins > 1: # Approximate local octave resolution around each frequency bandpass_octave = torch.empty_like(freqs) @@ -2588,7 +2598,7 @@ def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Ten else: # Special case when single basis frequency is used rel_band_coeff = 2. ** (1. / bins_per_octave) - alpha = torch.atleast_1d((rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)) + alpha = torch.tensor([(rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)]) return alpha @@ -2596,6 +2606,8 @@ def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Ten def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]: r"""Length of each filter in a wavelet basis. + .. 
devices:: CPU + Source: * https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html @@ -2609,6 +2621,12 @@ def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tu Tensor: filter lengths. float: cutoff frequency of highest bin. """ + if gamma < 0. or sr < 0.: + raise ValueError("gamma and sr must be positive!") + + if min(freqs) < 0. or min(alpha) < 0.: + raise ValueError("freqs and alpha must be positive!") + # We assume filter_scale (librosa param) is 1 Q = 1. / alpha @@ -2626,6 +2644,8 @@ def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tu def wavelet_fbank(freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Tensor, Tensor]: r"""Wavelet filterbank constructed from set of center frequencies. + .. devices:: CPU + Source: * https://librosa.org/doc/main/generated/librosa.filters.wavelet.html @@ -2645,9 +2665,7 @@ def wavelet_fbank(freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_ # Next power of 2 pad_to_size = 1<<(int(max(lengths))-1).bit_length() - - filters: Tensor - + for index, (ilen, freq) in enumerate(zip(lengths, freqs)): # Build filter with length ceil(ilen) # Use float32 in order to output complex(float) numbers later From 9a824a3147c7c9d6f2d47f48639412a6f2e4836c Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 19 Jun 2024 17:19:02 +0000 Subject: [PATCH 26/44] Librosa compatibility functional tests. --- .../librosa_compatibility_test_impl.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py index 4e8d4d3d5f..78bcac7575 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py @@ -2,6 +2,7 @@ from distutils.version import StrictVersion import torch +import math import torchaudio.functional as F from parameterized import param from torchaudio._internal.module_utils import is_module_available @@ -117,6 +118,94 @@ def test_amplitude_to_DB(self): result = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) expected = librosa.core.amplitude_to_db(spec[0].cpu().numpy())[None, ...] 
self.assertEqual(result, torch.from_numpy(expected)) + + def test_frequency_set(self): + f_min = 32.703 + n_bins = 84 + bins_per_octave = 12 + + actual_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave) + expected_freqs = librosa.interval_frequencies( + n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True + ).astype(np.float32) + + self.assertEqual(actual_freqs, torch.from_numpy(expected_freqs)) + + def test_single_bin_relative_bandwidths(self): + f_min = 32.703 + n_bins = 1 + bins_per_octave = 12 + + torch_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave) + + # Compute expected_alpha + # __et_relative_bw: from https://librosa.org/doc/main/_modules/librosa/core/constantq.html + r = 2 ** (1 / bins_per_octave) + expected_alpha = np.atleast_1d((r**2 - 1) / (r**2 + 1)).astype(np.float32) + actual_alpha = F.relative_bandwidths(torch_freqs, n_bins, bins_per_octave) + + self.assertEqual(actual_alpha, torch.from_numpy(expected_alpha)) + + def test_multi_bin_relative_bandwidths(self): + f_min = 32.703 + n_bins = 84 + bins_per_octave = 12 + + np_freqs = librosa.interval_frequencies( + n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True + ).astype(np.float32) + torch_freqs = torch.from_numpy(np_freqs) + + expected_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) + actual_alpha = F.relative_bandwidths(torch_freqs, n_bins, bins_per_octave) + + self.assertEqual(actual_alpha, torch.from_numpy(expected_alpha)) + + def test_wavelet_lengths(self): + f_min = 32.703 + n_bins = 84 + bins_per_octave = 12 + sample_rate = 16000 + gamma = 0. + + np_freqs = librosa.interval_frequencies( + n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True + ).astype(np.float32) + np_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) + + torch_freqs = torch.from_numpy(np_freqs) + torch_alpha = torch.from_numpy(np_alpha) + + librosa_lengths, _ = librosa.filters.wavelet_lengths( + freqs=np_freqs, sr=sample_rate, window='hann', filter_scale=1, gamma=0, alpha=np_alpha + ) + torch_lengths, _ = F.wavelet_lengths(torch_freqs, sample_rate, torch_alpha, gamma) + + self.assertEqual(torch_lengths, torch.from_numpy(librosa_lengths)) + + def test_wavelet_fbank(self): + f_min = 32.703 + n_bins = 84 + bins_per_octave = 12 + sample_rate = 16000 + gamma = 0. + window_fn = torch.hann_window + + np_freqs = librosa.interval_frequencies( + n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True + ).astype(np.float32) + np_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) + + torch_freqs = torch.from_numpy(np_freqs) + torch_alpha = torch.from_numpy(np_alpha) + + librosa_filters, librosa_lengths = librosa.filters.wavelet( + freqs=np_freqs, sr=sample_rate, window='hann', filter_scale=1, pad_fft=True, gamma=gamma, alpha=np_alpha + ) + torch_filters, torch_lengths = F.wavelet_fbank(torch_freqs, sample_rate, torch_alpha, gamma, window_fn) + + self.assertEqual(torch_filters, torch.from_numpy(librosa_filters)) + self.assertEqual(torch_lengths, torch.from_numpy(librosa_lengths)) @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") From 2ca6bd9eacf5ee95fa3da9858b518b12a9b63198 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Wed, 19 Jun 2024 17:23:38 +0000 Subject: [PATCH 27/44] Small comment removed. 
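For reference, the `__et_relative_bw` formula exercised by the single-bin test
reduces to a closed form: with r = 2 ** (1 / bins_per_octave), the relative
bandwidth is (r**2 - 1) / (r**2 + 1). A standalone sanity check (not part of
the test suite):

    bins_per_octave = 12
    r = 2.0 ** (1.0 / bins_per_octave)
    alpha = (r**2 - 1) / (r**2 + 1)
    print(alpha)  # ~0.0577 for 12 bins per octave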
--- .../functional/librosa_compatibility_test_impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py index 78bcac7575..611ef13987 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py @@ -138,7 +138,6 @@ def test_single_bin_relative_bandwidths(self): torch_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave) - # Compute expected_alpha # __et_relative_bw: from https://librosa.org/doc/main/_modules/librosa/core/constantq.html r = 2 ** (1 / bins_per_octave) expected_alpha = np.atleast_1d((r**2 - 1) / (r**2 + 1)).astype(np.float32) From 9edda98c502716980160f1ae110a4a8c4065d28c Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 24 Jun 2024 02:19:34 +0000 Subject: [PATCH 28/44] Fixing functional tests with new dtype method. --- .../librosa_compatibility_test_impl.py | 62 ++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py index 611ef13987..747afa72bd 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py @@ -124,10 +124,15 @@ def test_frequency_set(self): n_bins = 84 bins_per_octave = 12 - actual_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave) + actual_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave, dtype=torch.double) expected_freqs = librosa.interval_frequencies( - n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True - ).astype(np.float32) + n_bins=n_bins, + fmin=f_min, + intervals="equal", + bins_per_octave=bins_per_octave, + tuning=0.0, + sort=True, + ) self.assertEqual(actual_freqs, torch.from_numpy(expected_freqs)) @@ -136,11 +141,11 @@ def test_single_bin_relative_bandwidths(self): n_bins = 1 bins_per_octave = 12 - torch_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave) + torch_freqs, _ = F.frequency_set(f_min, n_bins, bins_per_octave, dtype=torch.double) # __et_relative_bw: from https://librosa.org/doc/main/_modules/librosa/core/constantq.html r = 2 ** (1 / bins_per_octave) - expected_alpha = np.atleast_1d((r**2 - 1) / (r**2 + 1)).astype(np.float32) + expected_alpha = np.atleast_1d((r**2 - 1) / (r**2 + 1)) actual_alpha = F.relative_bandwidths(torch_freqs, n_bins, bins_per_octave) self.assertEqual(actual_alpha, torch.from_numpy(expected_alpha)) @@ -151,8 +156,13 @@ def test_multi_bin_relative_bandwidths(self): bins_per_octave = 12 np_freqs = librosa.interval_frequencies( - n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True - ).astype(np.float32) + n_bins=n_bins, + fmin=f_min, + intervals="equal", + bins_per_octave=bins_per_octave, + tuning=0.0, + sort=True, + ) torch_freqs = torch.from_numpy(np_freqs) expected_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) @@ -168,15 +178,25 @@ def test_wavelet_lengths(self): gamma = 0. 
np_freqs = librosa.interval_frequencies( - n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True - ).astype(np.float32) + n_bins=n_bins, + fmin=f_min, + intervals="equal", + bins_per_octave=bins_per_octave, + tuning=0.0, + sort=True, + ) np_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) torch_freqs = torch.from_numpy(np_freqs) torch_alpha = torch.from_numpy(np_alpha) librosa_lengths, _ = librosa.filters.wavelet_lengths( - freqs=np_freqs, sr=sample_rate, window='hann', filter_scale=1, gamma=0, alpha=np_alpha + freqs=np_freqs, + sr=sample_rate, + window='hann', + filter_scale=1, + gamma=0, + alpha=np_alpha, ) torch_lengths, _ = F.wavelet_lengths(torch_freqs, sample_rate, torch_alpha, gamma) @@ -191,17 +211,31 @@ def test_wavelet_fbank(self): window_fn = torch.hann_window np_freqs = librosa.interval_frequencies( - n_bins=n_bins, fmin=f_min, intervals="equal", bins_per_octave=bins_per_octave, tuning=0.0, sort=True - ).astype(np.float32) + n_bins=n_bins, + fmin=f_min, + intervals="equal", + bins_per_octave=bins_per_octave, + tuning=0.0, + sort=True, + ) np_alpha = librosa.filters._relative_bandwidth(freqs=np_freqs) torch_freqs = torch.from_numpy(np_freqs) torch_alpha = torch.from_numpy(np_alpha) librosa_filters, librosa_lengths = librosa.filters.wavelet( - freqs=np_freqs, sr=sample_rate, window='hann', filter_scale=1, pad_fft=True, gamma=gamma, alpha=np_alpha + freqs=np_freqs, + sr=sample_rate, + window='hann', + filter_scale=1, + pad_fft=True, + gamma=gamma, + alpha=np_alpha, + dtype=np.complex128, + ) + torch_filters, torch_lengths = F.wavelet_fbank( + torch_freqs, sample_rate, torch_alpha, gamma, window_fn, dtype=torch.double ) - torch_filters, torch_lengths = F.wavelet_fbank(torch_freqs, sample_rate, torch_alpha, gamma, window_fn) self.assertEqual(torch_filters, torch.from_numpy(librosa_filters)) self.assertEqual(torch_lengths, torch.from_numpy(librosa_lengths)) From 0eee7adf1deafb2d22701e591afc50ff0f0a66e8 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 24 Jun 2024 02:19:55 +0000 Subject: [PATCH 29/44] Proper transform tests. --- .../transforms/autograd_test_impl.py | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 9c321fe223..58510b8ac6 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -30,7 +30,11 @@ def assert_grad( nondet_tol: float = 0.0, enable_all_grad: bool = True, ): - transform = transform.to(dtype=torch.float64, device=self.device) + # VQT, CQT, and InverseCQT have complex register buffers + # dtype is defined upon object creation + if not isinstance(transform, (T.VQT, T.CQT, T.InverseCQT)): + transform = transform.to(dtype=torch.float64) + transform = transform.to(device=self.device) # gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or # `torch.cdouble`, when the default eps and tolerance values are used. @@ -91,6 +95,59 @@ def test_melspectrogram(self): transform = T.MelSpectrogram(sample_rate=sample_rate) waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + def test_vqt(self): + sample_rate = 8000 + hop_length=200 + n_bins=72 + gamma=3. 
+ + transform = T.VQT( + sample_rate=sample_rate, + hop_length=hop_length, + n_bins=n_bins, + gamma=gamma, + dtype=torch.double, + ) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + def test_cqt(self): + sample_rate = 8000 + hop_length=200 + n_bins=72 + + transform = T.CQT( + sample_rate=sample_rate, + hop_length=hop_length, + n_bins=n_bins, + dtype=torch.double, + ) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + def test_inverse_cqt(self): + sample_rate = 8000 + hop_length=200 + n_bins=72 + + transform = T.CQT( + sample_rate=sample_rate, + hop_length=hop_length, + n_bins=n_bins, + dtype=torch.double, + ) + inverse_transform = T.InverseCQT( + sample_rate=sample_rate, + hop_length=hop_length, + n_bins=n_bins, + dtype=torch.double, + ) + waveform = get_whitenoise( + sample_rate=sample_rate, duration=0.05, n_channels=2, dtype=torch.double + ) + cqt = transform(waveform) + self.assert_grad(inverse_transform, [cqt]) @nested_params( [0, 0.99], From 8a5cbe821ee4d1938f11edcbd56b59cd971cab61 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 24 Jun 2024 02:20:22 +0000 Subject: [PATCH 30/44] Updated src code for float and double transforms. --- src/torchaudio/functional/functional.py | 12 ++-- src/torchaudio/transforms/_transforms.py | 77 ++++++++++++++---------- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 61cf49c2bb..7f0f2a4abd 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -2539,7 +2539,7 @@ def frechet_distance(mu_x, sigma_x, mu_y, sigma_y): return a + b - 2 * c -def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tensor, int]: +def frequency_set(f_min: float, n_bins: int, bins_per_octave: int, dtype: torch.dtype) -> Tuple[Tensor, int]: r"""Return a set of frequencies that assumes an equal temperament tuning system. .. devices:: CPU @@ -2559,7 +2559,7 @@ def frequency_set(f_min: float, n_bins: int, bins_per_octave: int) -> Tuple[Tens raise ValueError("f_min must be positive. n_bins and bins_per_octave must be ints and superior to 1.") n_octaves = math.ceil(n_bins / bins_per_octave) - ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=torch.float32) / bins_per_octave) + ratios = 2.0 ** (torch.arange(0, bins_per_octave * n_octaves, dtype=dtype) / bins_per_octave) return f_min * ratios[:n_bins], n_octaves @@ -2598,7 +2598,7 @@ def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Ten else: # Special case when single basis frequency is used rel_band_coeff = 2. ** (1. / bins_per_octave) - alpha = torch.tensor([(rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)]) + alpha = torch.tensor([(rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)], dtype=freqs.dtype) return alpha @@ -2641,7 +2641,9 @@ def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tu return lengths, cutoff_freq -def wavelet_fbank(freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor]) -> Tuple[Tensor, Tensor]: +def wavelet_fbank( + freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor], dtype: torch.dtype, +) -> Tuple[Tensor, Tensor]: r"""Wavelet filterbank constructed from set of center frequencies. .. 
devices:: CPU

    Source:
        * https://librosa.org/doc/main/generated/librosa.filters.wavelet.html
@@ -2669,7 +2671,7 @@ def wavelet_fbank(freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_
     for index, (ilen, freq) in enumerate(zip(lengths, freqs)):
         # Build filter with length ceil(ilen)
         # Use float32 in order to output complex(float) numbers later
-        t = torch.arange(-ilen // 2, ilen // 2, dtype=torch.float32) * 2 * torch.pi * freq / sr
+        t = torch.arange(-ilen // 2, ilen // 2, dtype=dtype) * 2 * torch.pi * freq / sr
         sig = torch.cos(t) + 1j * torch.sin(t)
 
         # Multiply with window
diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index b372a4654b..63c8a737aa 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -636,7 +636,8 @@ class VQT(torch.nn.Module):
     Args:
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
         hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``)
-        f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz)
+        f_min (float, optional): Minimum frequency, which corresponds to first note.
+            (Default: ``32.703``, or the frequency of C1 in Hz)
         n_bins (int, optional): Number of VQT frequency bins, starting at ``f_min``. (Default: ``84``)
         gamma (float, optional): Offset that controls VQT filter lengths. Larger values
             increase the time resolution at lower frequencies. (Default: ``0.``)
         bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
+        dtype (torch.dtype, optional):
+            Determines the precision that kernels are pre-computed and cached in. Note that complex bases
+            are either cfloat or cdouble depending on provided precision.
+            Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``)
 
     Example
         >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.VQT(sample_rate)
         >>> vqt = transform(waveform)  # (..., n_bins, time)
     """
-    __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn", "resampling_method"]
+    __constants__ = [
+        "sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn", "resampling_method",
+    ]
 
     def __init__(
         self,
@@ -663,16 +670,17 @@ def __init__(
         bins_per_octave: int = 12,
         window_fn: Callable[..., Tensor] = torch.hann_window,
         resampling_method: str = "sinc_interp_hann",
+        dtype: Optional[torch.dtype] = torch.float,
     ) -> None:
         super(VQT, self).__init__()
         torch._C._log_api_usage_once("torchaudio.transforms.VQT")
 
         n_filters = min(bins_per_octave, n_bins)
-        frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave)
+        frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave, dtype)
         alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave)
         freq_lengths, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, gamma)
 
-        self.resample = Resample(2, 1, resampling_method)
+        self.resample = Resample(2, 1, resampling_method, dtype=dtype)
         self.register_buffer("expanded_lengths", freq_lengths.unsqueeze(0).unsqueeze(-1))
         self.ones = lambda x: torch.ones(x, device=self.expanded_lengths.device)
 
@@ -683,18 +691,20 @@ def __init__(
 
         if cutoff_freq > nyquist:
             raise ValueError(
-                f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the Nyquist frequency {nyquist}. "
-                "Try to reduce the number of frequency bins."
+                f"Maximum bin cutoff frequency is {cutoff_freq}, which exceeds the "
+                f"Nyquist frequency {nyquist}. Try to reduce the number of frequency bins."
             )
         if num_hop_downsamples > n_octaves:
             warnings.warn(
-                f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming odd. "
-                f"The VQT is however being computed for {n_octaves} octaves. Consider lowering the hop length or increasing the number of bins for more accurate results."
+                f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming "
+                f"odd. The VQT is however being computed for {n_octaves} octaves. Consider lowering "
+                "the hop length or increasing the number of bins for more accurate results."
             )
         if nyquist / cutoff_freq > 4:
             warnings.warn(
-                f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's cutoff frequency {cutoff_freq}. "
-                "Consider resampling your signal to a lower sample rate or increasing the number of bins before VQT computation for more accurate results."
+                f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's "
+                f"cutoff frequency {cutoff_freq}. Consider resampling your signal to a lower sample "
+                "rate or increasing the number of bins before VQT computation for more accurate results."
            )
 
         # Now pre-compute what's needed for forward loop
@@ -710,7 +720,7 @@ def __init__(
         octave_alphas = alpha[indices]
 
         # Compute wavelet filterbanks
-        basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, gamma, window_fn)
+        basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, gamma, window_fn, dtype)
         n_fft = basis.shape[1]
 
         # Normalize wrt FFT window length
@@ -719,7 +729,7 @@ def __init__(
 
         # Wavelet basis FFT
         fft_basis = torch.fft.fft(basis, n=n_fft, dim=1)[:, :(n_fft//2) + 1]
-        fft_basis[:] *= math.sqrt(sample_rate / temp_sr)
+        fft_basis *= math.sqrt(sample_rate / temp_sr)
 
         self.register_buffer(f"fft_basis_{register_index}", fft_basis)
         self.forward_params.append((temp_hop, n_fft))
@@ -739,14 +749,10 @@ def forward(self, waveform: Tensor) -> Tensor:
         Returns:
             Tensor: variable-Q transform of size (..., channels, ``n_bins``, time).
         """
-        vqt: Tensor
-
         # Iterate down the octaves
         for buffer_index, (temp_hop, n_fft) in enumerate(self.forward_params):
             # STFT matrix
             if waveform.ndim == 3:
-                dft: Tensor
-
                 # torch stft does not support 3D computation yet
                 # iterate through channels for stft computation
                 for channel in range(waveform.shape[1]):
@@ -807,20 +813,27 @@ class CQT(torch.nn.Module):
     Args:
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
         hop_length (int, optional): Length of hop between CQT windows. (Default: ``400``)
-        f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz)
+        f_min (float, optional): Minimum frequency, which corresponds to first note.
+            (Default: ``32.703``, or the frequency of C1 in Hz)
         n_bins (int, optional): Number of CQT frequency bins, starting at ``f_min``. (Default: ``84``)
         bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
+        dtype (torch.dtype, optional):
+            Determines the precision that kernels are pre-computed and cached in. Note that complex bases
+            are either cfloat or cdouble depending on provided precision.
+            Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``)
 
     Example
         >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.CQT(sample_rate)
         >>> cqt = transform(waveform)  # (..., n_bins, time)
     """
-    __constants__ = ["sample_rate", "hop_length", "f_min", "n_bins", "bins_per_octave", "window_fn", "resampling_method"]
+    __constants__ = [
+        "sample_rate", "hop_length", "f_min", "n_bins", "bins_per_octave", "window_fn", "resampling_method",
+    ]
 
     def __init__(
         self,
@@ -831,6 +844,7 @@ def __init__(
         bins_per_octave: int = 12,
         window_fn: Callable[..., Tensor] = torch.hann_window,
         resampling_method: str = "sinc_interp_hann",
+        dtype: Optional[torch.dtype] = torch.float,
     ) -> None:
         super(CQT, self).__init__()
         torch._C._log_api_usage_once("torchaudio.transforms.CQT")
@@ -845,6 +859,7 @@ def __init__(
             bins_per_octave=bins_per_octave,
             window_fn=window_fn,
             resampling_method=resampling_method,
+            dtype=dtype,
         )
 
     def forward(self, waveform: Tensor) -> Tensor:
@@ -874,19 +889,26 @@ class InverseCQT(torch.nn.Module):
     Args:
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``)
-        f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz)
+        f_min (float, optional): Minimum frequency, which corresponds to first note.
+            (Default: ``32.703``, or the frequency of C1 in Hz)
         n_bins (int, optional): Number of CQT frequency bins, starting at ``f_min``. (Default: ``84``)
         bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
+        dtype (torch.dtype, optional):
+            Determines the precision that kernels are pre-computed and cached in.
+            Note that complex bases are either cfloat or cdouble depending on provided precision.
+            Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``)
 
     Example
         >>> transform = transforms.InverseCQT()
         >>> waveform = transform(cqt)  # (..., time)
     """
-    __constants__ = ["sample_rate", "hop_length", "f_min", "bins_per_octave", "window_fn", "resampling_method"]
+    __constants__ = [
+        "sample_rate", "hop_length", "f_min", "bins_per_octave", "window_fn", "resampling_method",
+    ]
 
     def __init__(
         self,
@@ -897,13 +919,14 @@ def __init__(
         bins_per_octave: int = 12,
         window_fn: Callable[..., Tensor] = torch.hann_window,
         resampling_method: str = "sinc_interp_hann",
+        dtype: Optional[torch.dtype] = torch.float,
     ) -> None:
         super(InverseCQT, self).__init__()
         torch._C._log_api_usage_once("torchaudio.transforms.InverseCQT")
 
         self.sample_rate = sample_rate
         n_filters = min(bins_per_octave, n_bins)
-        frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave)
+        frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave, dtype=dtype)
         alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave)
         freq_lengths, _ = F.wavelet_lengths(frequencies, self.sample_rate, alpha, 0.)
 
@@ -938,7 +961,7 @@ def __init__(
         octave_alphas = alpha[indices]
 
         # Compute wavelet filterbanks
-        basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, 0., window_fn)
+        basis, lengths = F.wavelet_fbank(octave_freqs, temp_sr, octave_alphas, 0., window_fn, dtype=dtype)
         n_fft = basis.shape[1]
 
         # Normalize wrt FFT window length
@@ -971,8 +994,6 @@ def forward(self, cqt: Tensor) -> Tensor:
 
         Returns:
             Tensor: waveform of size (..., channels, time).
""" - waveform: Tensor - # Iterate down the octaves for buffer_index, (temp_sr, temp_hop, indices) in enumerate(self.forward_params): # Inverse project the basis @@ -983,12 +1004,9 @@ def forward(self, cqt: Tensor) -> Tensor: getattr(self, f"frequency_pow_{buffer_index}"), cqt[..., indices, :], ) - # Taken from librosa n_fft = 2 * (temp_proj.shape[-2] - 1) if temp_proj.ndim == 4: - temp_waveform: Tensor - # torch istft does not support 4D computation yet # iterate through channels for stft computation for channel in range(temp_proj.shape[1]): @@ -1006,10 +1024,7 @@ def forward(self, cqt: Tensor) -> Tensor: else: temp_waveform = torch.istft( - temp_proj, - n_fft=n_fft, - hop_length=temp_hop, - window=self.ones(n_fft), + temp_proj, n_fft=n_fft, hop_length=temp_hop, window=self.ones(n_fft), ) # Resample to desired output shape From a9ec66b5dfd7492d8be6e507186202dc0dcd8847 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 24 Jun 2024 15:21:13 +0000 Subject: [PATCH 31/44] Batch consistency tests for transforms. --- .../transforms/batch_consistency_test.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index 843bd2e4ee..ee1ba750b8 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -105,6 +105,26 @@ def test_batch_melspectrogram(self): transform = T.MelSpectrogram() self.assert_batch_consistency(transform, waveform) + + def test_batch_vqt(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.VQT(sample_rate=8000, hop_length=200, n_bins=72) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_cqt(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.CQT(sample_rate=8000, hop_length=200, n_bins=72) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_inverse_cqt(self): + cqt = torch.randn(3, 2, 72, 41, dtype=torch.cdouble) + transform = T.InverseCQT(sample_rate=8000, hop_length=200, n_bins=72, dtype=torch.double) + + self.assert_batch_consistency(transform, cqt) def test_batch_mfcc(self): waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) From bcff9648c121612168061a680eee7136ec3f5d65 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 24 Jun 2024 21:53:18 +0000 Subject: [PATCH 32/44] Librosa VQT compatibility test. 
---
 .../librosa_compatibility_test_impl.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
index 6d1d79fae1..a06d3e5801 100644
--- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
+++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
@@ -92,6 +92,44 @@ def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
             mel_scale=mel_scale,
         ).to(self.device, self.dtype)(waveform)[0]
         self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
+
+    @nested_params(
+        [
+            param(sample_rate=2000, hop_length=200, n_bins=36, bins_per_octave=12, gamma=2., atol=0.7, rtol=0.7),
+            param(sample_rate=2000, hop_length=200, n_bins=3, bins_per_octave=1, gamma=4., atol=0.7, rtol=0.7),
+            param(sample_rate=1000, hop_length=100, n_bins=16, bins_per_octave=8, gamma=6., atol=0.35, rtol=0.35),
+            param(sample_rate=500, hop_length=50, n_bins=4, bins_per_octave=4, gamma=8., atol=1e-7, rtol=1e-7),
+        ],
+    )
+    def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol, rtol):
+        """
+        Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa
+        lead to diverging VQTs. This is likely as close as it can get.
+        """
+        f_min = 32.703
+        waveform = get_whitenoise(sample_rate=sample_rate, dtype=self.dtype).to(self.device)
+
+        expected = librosa.core.constantq.vqt(
+            y=waveform[0].cpu().numpy(),
+            sr=sample_rate,
+            hop_length=hop_length,
+            fmin=f_min,
+            n_bins=n_bins,
+            gamma=gamma,
+            bins_per_octave=bins_per_octave,
+            sparsity=0.,  # torchaudio VQT implemented with sparsity 0
+            res_type="sinc_best",  # torchaudio resampling roughly equivalent to sinc_best
+        )
+        result = T.VQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            gamma=gamma,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)[0]
+        self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)

     def test_magnitude_to_db(self):
         spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
         result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0]

From 4209aaaa4011c296266bf9d78bdb1035ae129b7c Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Tue, 25 Jun 2024 00:36:29 +0000
Subject: [PATCH 33/44] Typing change.
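
``None`` was never an accepted value for ``dtype`` -- the parameter always
carries a concrete default -- so the ``Optional`` wrapper only weakened the
annotation. Call sites are unaffected; both precisions remain valid, e.g.
(a sketch, not part of the diff):

    import torch
    import torchaudio.transforms as T

    # Kernels cached in single precision, complex bases as cfloat:
    vqt_f32 = T.VQT(sample_rate=16000, dtype=torch.float)
    # Kernels cached in double precision, complex bases as cdouble:
    vqt_f64 = T.VQT(sample_rate=16000, dtype=torch.double)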
--- src/torchaudio/transforms/_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 63c8a737aa..0f84cddc38 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -670,7 +670,7 @@ def __init__( bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, resampling_method: str = "sinc_interp_hann", - dtype: Optional[torch.dtype] = torch.float, + dtype: torch.dtype = torch.float, ) -> None: super(VQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.VQT") @@ -844,7 +844,7 @@ def __init__( bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, resampling_method: str = "sinc_interp_hann", - dtype: Optional[torch.dtype] = torch.float, + dtype: torch.dtype = torch.float, ) -> None: super(CQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.CQT") @@ -919,7 +919,7 @@ def __init__( bins_per_octave: int = 12, window_fn: Callable[..., Tensor] = torch.hann_window, resampling_method: str = "sinc_interp_hann", - dtype: Optional[torch.dtype] = torch.float, + dtype: torch.dtype = torch.float, ) -> None: super(InverseCQT, self).__init__() torch._C._log_api_usage_once("torchaudio.transforms.InverseCQT") From e764efe01e48cd4f930ca389d6d1f2e1dd5851c1 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Tue, 25 Jun 2024 00:37:08 +0000 Subject: [PATCH 34/44] CQT librosa tests. --- .../librosa_compatibility_test_impl.py | 112 ++++++++++++------ 1 file changed, 74 insertions(+), 38 deletions(-) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py index a06d3e5801..92c2c93c01 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py @@ -93,44 +93,6 @@ def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale): ).to(self.device, self.dtype)(waveform)[0] self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) - @nested_params( - [ - param(sample_rate=2000, hop_length=200, n_bins=36, bins_per_octave=12, gamma=2., atol=0.7, rtol=0.7), - param(sample_rate=2000, hop_length=200, n_bins=3, bins_per_octave=1, gamma=4., atol=0.7, rtol=0.7), - param(sample_rate=1000, hop_length=100, n_bins=16, bins_per_octave=8, gamma=6., atol=0.35, rtol=0.35), - param(sample_rate=500, hop_length=50, n_bins=4, bins_per_octave=4, gamma=8., atol=1e-7, rtol=1e-7), - ], - ) - def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol, rtol): - """ - Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa - lead to diverging VQTs. This is likely as close as it can get. 
-        """
-        f_min = 32.703
-        waveform = get_whitenoise(sample_rate=sample_rate, dtype=self.dtype).to(self.device)
-
-        expected = librosa.core.constantq.vqt(
-            y=waveform[0].cpu().numpy(),
-            sr=sample_rate,
-            hop_length=hop_length,
-            fmin=f_min,
-            n_bins=n_bins,
-            gamma=gamma,
-            bins_per_octave=bins_per_octave,
-            sparsity=0.,  # torchaudio VQT implemented with sparsity 0
-            res_type="sinc_best",  # torchaudio resampling roughly equivalent to sinc_best
-        )
-        result = T.VQT(
-            sample_rate=sample_rate,
-            hop_length=hop_length,
-            f_min=f_min,
-            n_bins=n_bins,
-            gamma=gamma,
-            bins_per_octave=bins_per_octave,
-            dtype=self.dtype,
-        ).to(self.device)(waveform)[0]
-        self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)
-
     def test_magnitude_to_db(self):
         spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
         result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0]
@@ -194,3 +156,77 @@ def test_spectral_centroid(self, n_fft, hop_length):
         y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length, pad_mode="reflect"
     )
     self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
+
+    @nested_params(
+        [
+            param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, gamma=2., atol=0.3, rtol=0.3),
+            param(sample_rate=1000, hop_length=100, n_bins=3, bins_per_octave=1, gamma=4., atol=0.3, rtol=0.3),
+            param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, gamma=6., atol=0.2, rtol=0.2),
+            param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, gamma=8., atol=1e-7, rtol=1e-7),
+        ],
+    )
+    def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol, rtol):
+        """
+        Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa
+        lead to diverging VQTs. This is likely as close as it can get.
+        """
+        f_min = 32.703
+        waveform = get_whitenoise(sample_rate=sample_rate, dtype=self.dtype).to(self.device)
+
+        expected = librosa.core.constantq.vqt(
+            y=waveform[0].cpu().numpy(),
+            sr=sample_rate,
+            hop_length=hop_length,
+            fmin=f_min,
+            n_bins=n_bins,
+            gamma=gamma,
+            bins_per_octave=bins_per_octave,
+            sparsity=0.,  # torchaudio VQT implemented with sparsity 0
+            res_type="sinc_best",  # torchaudio resampling roughly equivalent to sinc_best
+        )
+        result = T.VQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            gamma=gamma,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)[0]
+        self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)
+
+    @nested_params(
+        [
+            param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, atol=0.3, rtol=0.3),
+            param(sample_rate=1000, hop_length=100, n_bins=3, bins_per_octave=1, atol=0.3, rtol=0.3),
+            param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, atol=0.2, rtol=0.2),
+            param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, atol=1e-7, rtol=1e-7),
+        ],
+    )
+    def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol):
+        """
+        Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa
+        lead to diverging CQTs. This is likely as close as it can get.
+        """
+        f_min = 32.703
+        waveform = get_whitenoise(sample_rate=sample_rate, duration=2, dtype=self.dtype).to(self.device)
+
+        expected = librosa.cqt(
+            y=waveform[0].cpu().numpy(),
+            sr=sample_rate,
+            hop_length=hop_length,
+            fmin=f_min,
+            n_bins=n_bins,
+            bins_per_octave=bins_per_octave,
+            sparsity=0.,  # torchaudio CQT implemented with sparsity 0
+            res_type="sinc_best",  # torchaudio resampling roughly equivalent to sinc_best
+        )
+        result = T.CQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)[0]
+        self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)

From 0680828dd88c98474a1b62b24ab515a7ced61099 Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Tue, 25 Jun 2024 16:38:05 +0000
Subject: [PATCH 35/44] Inverse CQT librosa compatibility tests.

---
 .../librosa_compatibility_test_impl.py | 48 ++++++++++++++++++++-
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
index 92c2c93c01..fbc691c47f 100644
--- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
+++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py
@@ -160,7 +160,7 @@ def test_spectral_centroid(self, n_fft, hop_length):
     @nested_params(
         [
             param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, gamma=2., atol=0.3, rtol=0.3),
-            param(sample_rate=1000, hop_length=100, n_bins=3, bins_per_octave=1, gamma=4., atol=0.3, rtol=0.3),
+            param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, gamma=4., atol=0.2, rtol=0.2),
             param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, gamma=6., atol=0.2, rtol=0.2),
             param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, gamma=8., atol=1e-7, rtol=1e-7),
         ],
     )
@@ -198,7 +198,7 @@ def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol
     @nested_params(
         [
             param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, atol=0.3, rtol=0.3),
-            param(sample_rate=1000, hop_length=100, n_bins=3, bins_per_octave=1, atol=0.3, rtol=0.3),
+            param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, atol=0.2, rtol=0.2),
             param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, atol=0.2, rtol=0.2),
             param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, atol=1e-7, rtol=1e-7),
         ],
     )
@@ -230,3 +230,47 @@ def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol)
         dtype=self.dtype,
     ).to(self.device)(waveform)[0]
     self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)
+
+    @nested_params(
+        [
+            param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, atol=0.02, rtol=0.02),
+            param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, atol=0.01, rtol=0.01),
+            param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, atol=0.01, rtol=0.01),
+            param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, atol=1e-7, rtol=1e-7),
+        ],
+    )
+    def test_InverseCQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol):
+        """
+        Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa
+        lead to diverging CQTs. This is likely as close as it can get.
+        """
+        f_min = 32.703
+        waveform = get_whitenoise(sample_rate=sample_rate, duration=4, dtype=self.dtype).to(self.device)
+
+        cqt = T.CQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)
+
+        expected = librosa.core.icqt(
+            C=cqt[0].cpu().numpy(),
+            sr=sample_rate,
+            hop_length=hop_length,
+            fmin=f_min,
+            bins_per_octave=bins_per_octave,
+            sparsity=0.,  # torchaudio iCQT implemented with sparsity 0
+            res_type="sinc_best",  # torchaudio resampling roughly equivalent to sinc_best
+        )
+        result = T.InverseCQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(cqt)[0]
+        self.assertEqual(result, torch.from_numpy(expected), atol=atol, rtol=rtol)

From c17e90de280a893dfb48737fcb7a71fb6810ba55 Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Wed, 26 Jun 2024 16:19:09 +0000
Subject: [PATCH 36/44] Make sure CQT is VQT with gamma set to 0 test.

---
 .../transforms/transforms_test_impl.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py
index 2e70ab4ad3..5d7e6252bd 100644
--- a/test/torchaudio_unittest/transforms/transforms_test_impl.py
+++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -109,6 +109,40 @@ def test_roundtrip_spectrogram(self, **args):
         restored = inv_s.forward(transformed, length=waveform.shape[-1])
         self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6)

+    @parameterized.expand(
+        [
+            param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12),
+            param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1),
+            param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8),
+            param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4),
+        ],
+    )
+    def test_CQT_VQT_match(self, sample_rate, hop_length, n_bins, bins_per_octave):
+        """Make sure that the CQT is the VQT with gamma set to 0."""
+        f_min = 32.703
+        waveform = get_whitenoise(sample_rate=sample_rate, dtype=self.dtype).to(self.device)
+
+        cqt = T.CQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)
+
+        vqt = T.VQT(
+            sample_rate=sample_rate,
+            hop_length=hop_length,
+            f_min=f_min,
+            n_bins=n_bins,
+            gamma=0.,
+            bins_per_octave=bins_per_octave,
+            dtype=self.dtype,
+        ).to(self.device)(waveform)
+
+        self.assertEqual(cqt, vqt)
+
     @parameterized.expand(
         [
             param(0.5, 1, True, False),

From ede80a20f73a8d02ee2a4d93e6352c1fb5c36fb0 Mon Sep 17 00:00:00 2001
From: Dorian Desblancs
Date: Wed, 26 Jun 2024 16:32:18 +0000
Subject: [PATCH 37/44] Typo.
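
The docstring corrected below belongs to ``test_InverseCQT``, which
exercises the CQT -> iCQT round trip against librosa, so "diverging
iCQTs" is what was meant. In plain form the round trip looks roughly like
this (a sketch; parameters mirror the test's first case):

    import torch
    import torchaudio.transforms as T

    sr, hop = 1000, 100
    x = torch.randn(4 * sr, dtype=torch.double)

    cqt = T.CQT(sample_rate=sr, hop_length=hop, n_bins=36,
                dtype=torch.double)(x[None])
    x_hat = T.InverseCQT(sample_rate=sr, hop_length=hop, n_bins=36,
                         dtype=torch.double)(cqt)[0]
    # x_hat approximates x up to band-limiting and windowing loss.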
--- .../transforms/librosa_compatibility_test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py index fbc691c47f..ebcb9bee08 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py @@ -242,7 +242,7 @@ def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol) def test_InverseCQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol): """ Differences in resampling, which occurs n_bins/bins_per_octave - 1 times, between torch and librosa - lead to diverging CQTs. This is likely as close as it can get. + lead to diverging iCQTs. This is likely as close as it can get. """ f_min = 32.703 waveform = get_whitenoise(sample_rate=sample_rate, duration=4, dtype=self.dtype).to(self.device) From 8e3c9efa24683bb9ad27df908c7ced7fbf807ac1 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 27 Jun 2024 01:56:58 +0000 Subject: [PATCH 38/44] Bug fixes and top notch librosa matching. --- src/torchaudio/functional/functional.py | 4 +- src/torchaudio/transforms/_transforms.py | 69 ++++++++++-------------- 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 7f0f2a4abd..91ecaf7ce6 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -2579,7 +2579,7 @@ def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Ten torch.Tensor: relative bandwidths for set of frequencies. """ if min(freqs) < 0. or n_bins < 1 or bins_per_octave < 1: - raise ValueError("freqs must be positive. n_bins and bins_per_octave must be ints and superior to 1.") + raise ValueError("freqs must be positive. n_bins and bins_per_octave must be positive ints.") if n_bins > 1: # Approximate local octave resolution around each frequency @@ -2667,7 +2667,7 @@ def wavelet_fbank( # Next power of 2 pad_to_size = 1<<(int(max(lengths))-1).bit_length() - + for index, (ilen, freq) in enumerate(zip(lengths, freqs)): # Build filter with length ceil(ilen) # Use float32 in order to output complex(float) numbers later diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 0f84cddc38..fabcd13890 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -626,7 +626,7 @@ class VQT(torch.nn.Module): .. devices:: CPU CUDA - .. properties:: Autograd TorchScript + .. properties:: Autograd Sources * https://librosa.org/doc/main/generated/librosa.vqt.html @@ -636,10 +636,10 @@ class VQT(torch.nn.Module): Args: sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``) - f_min (float, optional): Minimum frequency, which corresponds to first note. + f_min (float, optional): Minimum frequency, which corresponds to first note. (Default: ``32.703``, or the frequency of C1 in Hz) n_bins (int, optional): Number of VQT frequency bins, starting at ``f_min``. (Default: ``84``) - gamma (float, optional): Offset that controls VQT filter lengths. Larger values + gamma (float, optional): Offset that controls VQT filter lengths. Larger values increase the time resolution at lower frequencies. 
(Default: ``0.``)
         bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
         dtype (torch.dtype, optional):
-            Determines the precision that kernels are pre-computed and cached in. Note that complex bases
-            are either cfloat or cdouble depending on provided precision.
+            Determines the precision that kernels are pre-computed and cached in. Note that complex
+            bases are either cfloat or cdouble depending on provided precision.
             Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``)

     Example
         >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.VQT(sample_rate)
         >>> vqt = transform(waveform)  # (..., n_bins, time)
     """
-    __constants__ = [
-        "sample_rate", "hop_length", "f_min", "n_bins", "gamma", "bins_per_octave", "window_fn", "resampling_method",
-    ]
+    __constants__ = ["resample", "forward_params"]

     def __init__(
         self,
@@ -678,11 +676,9 @@ def __init__(
         n_filters = min(bins_per_octave, n_bins)
         frequencies, n_octaves = F.frequency_set(f_min, n_bins, bins_per_octave, dtype)
         alpha = F.relative_bandwidths(frequencies, n_bins, bins_per_octave)
-        freq_lengths, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, gamma)
+        _, cutoff_freq = F.wavelet_lengths(frequencies, sample_rate, alpha, gamma)

         self.resample = Resample(2, 1, resampling_method, dtype=dtype)
-        self.register_buffer("expanded_lengths", freq_lengths.unsqueeze(0).unsqueeze(-1))
-        self.ones = lambda x: torch.ones(x, device=self.expanded_lengths.device)

         # Generate errors or warnings if needed
         # Number of divisions by 2 before number becomes odd
@@ -691,18 +687,18 @@ def __init__(

         if cutoff_freq > nyquist:
             raise ValueError(
-                f"Maximum bin cutoff frequency is {cutoff_freq} and superior to the "
+                f"Maximum bin cutoff frequency is approximately {cutoff_freq} and greater than the "
                 f"Nyquist frequency {nyquist}. Try to reduce the number of frequency bins."
             )
-        if num_hop_downsamples > n_octaves:
+        if n_octaves - 1 > num_hop_downsamples:
             warnings.warn(
                 f"Hop length can be divided {num_hop_downsamples} times by 2 before becoming "
+                f"odd. The VQT is however being computed for {n_octaves} octaves. Consider setting "
+                "the hop length to a number divisible by 2 more times for more accurate results."
             )
         if nyquist / cutoff_freq > 4:
             warnings.warn(
-                f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's "
+                f"The Nyquist frequency {nyquist} is significantly higher than the highest filter's approximate "
                 f"cutoff frequency {cutoff_freq}. Consider resampling your signal to a lower sample "
                 "rate or increasing the number of bins before VQT computation for more accurate results."
)

@@ -760,7 +756,7 @@ def forward(self, waveform: Tensor) -> Tensor:
                         waveform[:, channel, :],
                         n_fft=n_fft,
                         hop_length=temp_hop,
-                        window=self.ones(n_fft),
+                        window=torch.ones(n_fft),
                         pad_mode='constant',
                         return_complex=True,
                     )
@@ -775,7 +771,7 @@ def forward(self, waveform: Tensor) -> Tensor:
                     waveform,
                     n_fft=n_fft,
                     hop_length=temp_hop,
-                    window=self.ones(n_fft),
+                    window=torch.ones(n_fft),
                     pad_mode='constant',
                     return_complex=True,
                 )
@@ -788,12 +784,9 @@ def forward(self, waveform: Tensor) -> Tensor:
             else:
                 vqt = torch.cat([temp_vqt, vqt], dim=-2)

-            # Resampling
             if temp_hop % 2 == 0:
                 waveform = self.resample(waveform)
-
-        # Scale VQT by square-root of the length of each channel's filter
-        vqt /= torch.sqrt(self.expanded_lengths)
+                waveform /= math.sqrt(0.5)

         return vqt

@@ -803,7 +796,7 @@ class CQT(torch.nn.Module):

     .. devices:: CPU CUDA

-    .. properties:: Autograd TorchScript
+    .. properties:: Autograd

     Sources
         * https://librosa.org/doc/main/generated/librosa.cqt.html
@@ -822,8 +815,8 @@ class CQT(torch.nn.Module):
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
         dtype (torch.dtype, optional):
-            Determines the precision that kernels are pre-computed and cached in. Note that complex bases
-            are either cfloat or cdouble depending on provided precision.
+            Determines the precision that kernels are pre-computed and cached in. Note that complex
+            bases are either cfloat or cdouble depending on provided precision.
             Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``)

     Example
         >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.CQT(sample_rate)
         >>> cqt = transform(waveform)  # (..., n_bins, time)
     """
-    __constants__ = [
-        "sample_rate", "hop_length", "f_min", "n_bins", "bins_per_octave", "window_fn", "resampling_method",
-    ]
+    __constants__ = ["transform"]

     def __init__(
         self,
@@ -879,7 +870,7 @@ class InverseCQT(torch.nn.Module):

     .. devices:: CPU CUDA

-    .. properties:: Autograd TorchScript
+    .. properties:: Autograd

     Sources
         * https://librosa.org/doc/main/generated/librosa.icqt.html

     Args:
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
-        hop_length (int, optional): Length of hop between VQT windows. (Default: ``400``)
+        hop_length (int, optional): Length of hop between CQT windows. (Default: ``400``)
         f_min (float, optional): Minimum frequency, which corresponds to first note.
             (Default: ``32.703``, or the frequency of C1 in Hz)
         n_bins (int, optional): Number of CQT frequency bins, starting at ``f_min``. (Default: ``84``)
         bins_per_octave (int, optional): Number of bins per octave. (Default: ``12``)
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         resampling_method (str, optional): The resampling method to use.
             Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
         dtype (torch.dtype, optional):
-            Determines the precision that kernels are pre-computed and cached in.
-            Note that complex bases are either cfloat or cdouble depending on provided precision.
+            Determines the precision that kernels are pre-computed and cached in. Note that complex
+            bases are either cfloat or cdouble depending on provided precision.
Options: [``torch.float``, ``torch.double``] (Default: ``torch.float``) Example >>> transform = transforms.InverseCQT() >>> waveform = transform(cqt) # (..., time) """ - __constants__ = [ - "sample_rate", "hop_length", "f_min", "bins_per_octave", "window_fn", "resampling_method", - ] + __constants__ = ["sample_rate", "resampling_method", "forward_params"] def __init__( self, @@ -931,8 +920,6 @@ def __init__( freq_lengths, _ = F.wavelet_lengths(frequencies, self.sample_rate, alpha, 0.) self.resampling_method = resampling_method - self.register_buffer("c_scale", torch.sqrt(freq_lengths)) - self.ones = lambda x: torch.ones(x, device=self.c_scale.device) # Get sample rates and hop lengths used during CQT downsampling sample_rates = [] @@ -998,9 +985,8 @@ def forward(self, cqt: Tensor) -> Tensor: for buffer_index, (temp_sr, temp_hop, indices) in enumerate(self.forward_params): # Inverse project the basis temp_proj = torch.einsum( - 'fc,c,c,...ct->...ft', + 'fc,c,...ct->...ft', getattr(self, f"basis_inverse_{buffer_index}"), - self.c_scale[indices], getattr(self, f"frequency_pow_{buffer_index}"), cqt[..., indices, :], ) @@ -1014,7 +1000,7 @@ def forward(self, cqt: Tensor) -> Tensor: temp_proj[:, channel, :, :], n_fft=n_fft, hop_length=temp_hop, - window=self.ones(n_fft), + window=torch.ones(n_fft), ) if channel == 0: @@ -1024,10 +1010,9 @@ def forward(self, cqt: Tensor) -> Tensor: else: temp_waveform = torch.istft( - temp_proj, n_fft=n_fft, hop_length=temp_hop, window=self.ones(n_fft), + temp_proj, n_fft=n_fft, hop_length=temp_hop, window=torch.ones(n_fft), ) - # Resample to desired output shape temp_waveform = F.resample( temp_waveform, orig_freq=1, From 3217b194bb5cb1a1c39b1ba7cc64d8fb231155ce Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 27 Jun 2024 15:17:26 +0000 Subject: [PATCH 39/44] Higher frequency librosa q-transform tests. --- .../librosa_compatibility_test_impl.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py index ebcb9bee08..0e9ac7b3d0 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py @@ -159,10 +159,10 @@ def test_spectral_centroid(self, n_fft, hop_length): @nested_params( [ - param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, gamma=2., atol=0.3, rtol=0.3), - param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, gamma=4., atol=0.2, rtol=0.2), - param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, gamma=6., atol=0.2, rtol=0.2), - param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, gamma=8., atol=1e-7, rtol=1e-7), + param(sample_rate=16384, hop_length=256, n_bins=84, bins_per_octave=12, gamma=0., atol=5e-1, rtol=5e-1), + param(sample_rate=4096, hop_length=128, n_bins=40, bins_per_octave=8, gamma=2., atol=2e-1, rtol=2e-1), + param(sample_rate=1024, hop_length=64, n_bins=12, bins_per_octave=4, gamma=4., atol=1e-1, rtol=1e-1), + param(sample_rate=512, hop_length=32, n_bins=12, bins_per_octave=12, gamma=8., atol=1e-6, rtol=1e-6), ], ) def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol, rtol): @@ -171,7 +171,7 @@ def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol lead to diverging VQTs. This is likely as close as it can get. 
""" f_min = 32.703 - waveform = get_whitenoise(sample_rate=sample_rate, dtype=self.dtype).to(self.device) + waveform = get_whitenoise(sample_rate=sample_rate, duration=2, dtype=self.dtype).to(self.device) expected = librosa.core.constantq.vqt( y=waveform[0].cpu().numpy(), @@ -183,6 +183,7 @@ def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol bins_per_octave=bins_per_octave, sparsity=0., # torchaudio VQT implemeted with sparsity 0 res_type="sinc_best", # torchaudio resampling roughly equivalent to sinc_best + scale=False, ) result = T.VQT( sample_rate=sample_rate, @@ -197,10 +198,10 @@ def test_VQT(self, sample_rate, hop_length, n_bins, bins_per_octave, gamma, atol @nested_params( [ - param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, atol=0.3, rtol=0.3), - param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, atol=0.2, rtol=0.2), - param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, atol=0.2, rtol=0.2), - param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, atol=1e-7, rtol=1e-7), + param(sample_rate=16384, hop_length=256, n_bins=84, bins_per_octave=12, atol=5e-1, rtol=5e-1), + param(sample_rate=4096, hop_length=128, n_bins=40, bins_per_octave=8, atol=2e-1, rtol=2e-1), + param(sample_rate=1024, hop_length=64, n_bins=12, bins_per_octave=4, atol=1e-1, rtol=1e-1), + param(sample_rate=512, hop_length=32, n_bins=12, bins_per_octave=12, atol=1e-6, rtol=1e-6), ], ) def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol): @@ -220,6 +221,7 @@ def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol) bins_per_octave=bins_per_octave, sparsity=0., # torchaudio CQT implemeted with sparsity 0 res_type="sinc_best", # torchaudio resampling roughly equivalent to sinc_best + scale=False, ) result = T.CQT( sample_rate=sample_rate, @@ -233,10 +235,10 @@ def test_CQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol) @nested_params( [ - param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12, atol=0.02, rtol=0.02), - param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1, atol=0.01, rtol=0.01), - param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8, atol=0.01, rtol=0.01), - param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4, atol=1e-7, rtol=1e-7), + param(sample_rate=16384, hop_length=256, n_bins=84, bins_per_octave=12, atol=5e-1, rtol=5e-1), + param(sample_rate=4096, hop_length=128, n_bins=40, bins_per_octave=8, atol=2e-1, rtol=2e-1), + param(sample_rate=1024, hop_length=64, n_bins=12, bins_per_octave=4, atol=1e-1, rtol=1e-1), + param(sample_rate=512, hop_length=32, n_bins=12, bins_per_octave=12, atol=1e-6, rtol=1e-6), ], ) def test_InverseCQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol, rtol): @@ -245,7 +247,7 @@ def test_InverseCQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol lead to diverging iCQTs. This is likely as close as it can get. 
""" f_min = 32.703 - waveform = get_whitenoise(sample_rate=sample_rate, duration=4, dtype=self.dtype).to(self.device) + waveform = get_whitenoise(sample_rate=sample_rate, duration=2, dtype=self.dtype).to(self.device) cqt = T.CQT( sample_rate=sample_rate, @@ -264,6 +266,7 @@ def test_InverseCQT(self, sample_rate, hop_length, n_bins, bins_per_octave, atol bins_per_octave=bins_per_octave, sparsity=0., # torchaudio iCQT implemeted with sparsity 0 res_type="sinc_best", # torchaudio resampling roughly equivalent to sinc_best + scale=False, ) result = T.InverseCQT( sample_rate=sample_rate, From f2632a0425dafbf1184530601a0bd479566e091e Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 27 Jun 2024 16:06:59 +0000 Subject: [PATCH 40/44] Updated VQT and CQT params in tests. --- .../transforms/transforms_test_impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py index 5d7e6252bd..c81c0f742f 100644 --- a/test/torchaudio_unittest/transforms/transforms_test_impl.py +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -111,10 +111,10 @@ def test_roundtrip_spectrogram(self, **args): @parameterized.expand( [ - param(sample_rate=1000, hop_length=100, n_bins=36, bins_per_octave=12), - param(sample_rate=1000, hop_length=10, n_bins=3, bins_per_octave=1), - param(sample_rate=500, hop_length=50, n_bins=16, bins_per_octave=8), - param(sample_rate=250, hop_length=25, n_bins=4, bins_per_octave=4), + param(sample_rate=16384, hop_length=256, n_bins=84, bins_per_octave=12), + param(sample_rate=4096, hop_length=128, n_bins=40, bins_per_octave=8), + param(sample_rate=1024, hop_length=64, n_bins=12, bins_per_octave=4), + param(sample_rate=512, hop_length=32, n_bins=12, bins_per_octave=12), ], ) def test_CQT_VQT_match(self, sample_rate, hop_length, n_bins, bins_per_octave): From da23687eaa24c0654eb1058283d65fe8cebc41ef Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Thu, 27 Jun 2024 16:39:34 +0000 Subject: [PATCH 41/44] Removing useless white space change. --- .../transforms/librosa_compatibility_test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py index 0e9ac7b3d0..3bfc1a135c 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py @@ -92,7 +92,7 @@ def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale): mel_scale=mel_scale, ).to(self.device, self.dtype)(waveform)[0] self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) - + def test_magnitude_to_db(self): spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype) result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0] From 6b510f913a7dbbb1f36ce911c347a122cbdfdf28 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 1 Jul 2024 02:15:10 +0000 Subject: [PATCH 42/44] Small changes. 
--- src/torchaudio/functional/functional.py | 1 - src/torchaudio/transforms/_transforms.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 91ecaf7ce6..4e0944e0b7 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -2670,7 +2670,6 @@ def wavelet_fbank( for index, (ilen, freq) in enumerate(zip(lengths, freqs)): # Build filter with length ceil(ilen) - # Use float32 in order to output complex(float) numbers later t = torch.arange(-ilen // 2, ilen // 2, dtype=dtype) * 2 * torch.pi * freq / sr sig = torch.cos(t) + 1j * torch.sin(t) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index fabcd13890..160a78b4f4 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -661,7 +661,7 @@ class VQT(torch.nn.Module): def __init__( self, sample_rate: int = 16000, - hop_length: int = 400, + hop_length: int = 256, f_min: float = 32.703, n_bins: int = 84, gamma: float = 0., @@ -829,7 +829,7 @@ class CQT(torch.nn.Module): def __init__( self, sample_rate: int = 16000, - hop_length: int = 400, + hop_length: int = 256, f_min: float = 32.703, n_bins: int = 84, bins_per_octave: int = 12, @@ -902,7 +902,7 @@ class InverseCQT(torch.nn.Module): def __init__( self, sample_rate: int = 16000, - hop_length: int = 400, + hop_length: int = 256, f_min: float = 32.703, n_bins: int = 84, bins_per_octave: int = 12, From 6b3f20f94a824a5b969a3eafe0966650972e5d7b Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 1 Jul 2024 10:40:18 +0000 Subject: [PATCH 43/44] Creating ones on correct device. --- src/torchaudio/transforms/_transforms.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 160a78b4f4..6a59ecfbd2 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -736,6 +736,9 @@ def __init__( temp_sr /= 2. 
temp_hop //= 2 + # Create ones on the correct device in the forward pass + self.ones = lambda x: torch.ones(x, device=self.fft_basis_0.device) + def forward(self, waveform: Tensor) -> Tensor: r""" Args: @@ -756,7 +759,7 @@ def forward(self, waveform: Tensor) -> Tensor: waveform[:, channel, :], n_fft=n_fft, hop_length=temp_hop, - window=torch.ones(n_fft), + window=self.ones(n_fft), pad_mode='constant', return_complex=True, ) @@ -771,7 +774,7 @@ def forward(self, waveform: Tensor) -> Tensor: waveform, n_fft=n_fft, hop_length=temp_hop, - window=torch.ones(n_fft), + window=self.ones(n_fft), pad_mode='constant', return_complex=True, ) @@ -972,6 +975,9 @@ def __init__( self.register_buffer(f"frequency_pow_{oct_index}", frequency_pow) self.forward_params.append((temp_sr, temp_hop, indices)) + # Create ones on the correct device in the forward pass + self.ones = lambda x: torch.ones(x, device=self.basis_inverse_0.device) + def forward(self, cqt: Tensor) -> Tensor: r""" Args: @@ -1000,7 +1006,7 @@ def forward(self, cqt: Tensor) -> Tensor: temp_proj[:, channel, :, :], n_fft=n_fft, hop_length=temp_hop, - window=torch.ones(n_fft), + window=self.ones(n_fft), ) if channel == 0: @@ -1010,7 +1016,7 @@ def forward(self, cqt: Tensor) -> Tensor: else: temp_waveform = torch.istft( - temp_proj, n_fft=n_fft, hop_length=temp_hop, window=torch.ones(n_fft), + temp_proj, n_fft=n_fft, hop_length=temp_hop, window=self.ones(n_fft), ) temp_waveform = F.resample( From 71778c5ebac0f51d125995a4dd3bed3f3d6675e8 Mon Sep 17 00:00:00 2001 From: Dorian Desblancs Date: Mon, 1 Jul 2024 13:07:41 +0000 Subject: [PATCH 44/44] Ones dtype. --- src/torchaudio/transforms/_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 6a59ecfbd2..1017adac75 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -737,7 +737,7 @@ def __init__( temp_hop //= 2 # Create ones on the correct device in the forward pass - self.ones = lambda x: torch.ones(x, device=self.fft_basis_0.device) + self.ones = lambda x: torch.ones(x, dtype=dtype, device=self.fft_basis_0.device) def forward(self, waveform: Tensor) -> Tensor: r""" @@ -976,7 +976,7 @@ def __init__( self.forward_params.append((temp_sr, temp_hop, indices)) # Create ones on the correct device in the forward pass - self.ones = lambda x: torch.ones(x, device=self.basis_inverse_0.device) + self.ones = lambda x: torch.ones(x, dtype=dtype, device=self.basis_inverse_0.device) def forward(self, cqt: Tensor) -> Tensor: r"""