Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CQT, iCQT, and VQT implementations and testing #3804

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
f4dc893
VQT outline with base args.
d-dawg78 Jan 28, 2024
bd75d4d
Equal temperament frequencies set.
d-dawg78 May 29, 2024
73f54b7
Merge branch 'pytorch:main' into main
d-dawg78 May 29, 2024
40d6581
Raise error if max frequency is superior to Nyquist.
d-dawg78 May 29, 2024
3f2689e
VQT wavelet filter creation.
d-dawg78 Jun 6, 2024
f13b46d
Top bin filter cutoff frequencies.
d-dawg78 Jun 6, 2024
d42cd12
Merge branch 'pytorch:main' into main
d-dawg78 Jun 6, 2024
bfcca8a
Warnings for hop length and sample rate values.
d-dawg78 Jun 7, 2024
213e2d8
Forward loop outline.
d-dawg78 Jun 7, 2024
517e6c0
Wavelet basis function implemented.
d-dawg78 Jun 7, 2024
1ed2271
First shot at entire VQT done.
d-dawg78 Jun 8, 2024
eeba783
Sparsified rows.
d-dawg78 Jun 9, 2024
76afee0
Removed sparsity and matched stft to librosa vqt implementation.
d-dawg78 Jun 9, 2024
86d462c
Fixing dot product operation.
d-dawg78 Jun 9, 2024
7aa9b43
Fixed resampling.
d-dawg78 Jun 9, 2024
b6f8b3c
Object-oriented optimizations!
d-dawg78 Jun 11, 2024
09ffe6c
CQT implementation.
d-dawg78 Jun 12, 2024
1529f0a
Splitting functions from classes to be used by iCQT.
d-dawg78 Jun 12, 2024
b76ad57
iCQT outline and VQT batch computation.
d-dawg78 Jun 15, 2024
78a5169
iCQT algorithm start and outline.
d-dawg78 Jun 16, 2024
e06ac1d
Pre-computations done :)
d-dawg78 Jun 16, 2024
878eb26
Make frequencies float32 to avoid icqt einsum issues.
d-dawg78 Jun 16, 2024
53334af
Basis projection.
d-dawg78 Jun 16, 2024
da65ec3
iCQT for 2D tensors :)
d-dawg78 Jun 16, 2024
d146d1a
Comments on the iCQT and a few other spots.
d-dawg78 Jun 16, 2024
170e9fa
Proper iCQT import.
d-dawg78 Jun 16, 2024
fa3298f
Code cleanup in functional.
d-dawg78 Jun 19, 2024
9a824a3
Librosa compatibility functional tests.
d-dawg78 Jun 19, 2024
2ca6bd9
Small comment removed.
d-dawg78 Jun 19, 2024
9edda98
Fixing functional tests with new dtype method.
d-dawg78 Jun 24, 2024
0eee7ad
Proper transform tests.
d-dawg78 Jun 24, 2024
8a5cbe8
Updated src code for float and double transforms.
d-dawg78 Jun 24, 2024
a9ec66b
Batch consistency tests for transforms.
d-dawg78 Jun 24, 2024
bcff964
Librosa VQT compatibility test.
d-dawg78 Jun 24, 2024
4209aaa
Typing change.
d-dawg78 Jun 25, 2024
e764efe
CQT librosa tests.
d-dawg78 Jun 25, 2024
0680828
Inverse CQT librosa compatibility tests.
d-dawg78 Jun 25, 2024
0d58ebd
Merge branch 'pytorch:main' into main
d-dawg78 Jun 25, 2024
c17e90d
Make sure CQT is VQT with gamma set to 0 test.
d-dawg78 Jun 26, 2024
ede80a2
Typo.
d-dawg78 Jun 26, 2024
8e3c9ef
Bug fixes and top notch librosa matching.
d-dawg78 Jun 27, 2024
3217b19
Higher frequency librosa q-transform tests.
d-dawg78 Jun 27, 2024
f2632a0
Updated VQT and CQT params in tests.
d-dawg78 Jun 27, 2024
da23687
Removing useless white space change.
d-dawg78 Jun 27, 2024
6b510f9
Small changes.
d-dawg78 Jul 1, 2024
6b3f20f
Creating ones on correct device.
d-dawg78 Jul 1, 2024
71778c5
Ones dtype.
d-dawg78 Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
edit_distance,
fftconvolve,
frechet_distance,
frequency_set,
griffinlim,
inverse_spectrogram,
linear_fbanks,
Expand All @@ -52,6 +53,7 @@
pitch_shift,
preemphasis,
psd,
relative_bandwidths,
resample,
rnnt_loss,
rtf_evd,
Expand All @@ -60,6 +62,8 @@
spectral_centroid,
spectrogram,
speed,
wavelet_fbank,
wavelet_lengths,
)

__all__ = [
Expand Down Expand Up @@ -124,4 +128,8 @@
"preemphasis",
"deemphasis",
"frechet_distance",
"frequency_set",
"relative_bandwidths",
"wavelet_lengths",
"wavelet_fbank",
]
161 changes: 160 additions & 1 deletion src/torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Callable

import torch
import torchaudio
Expand Down Expand Up @@ -51,6 +51,10 @@
"speed",
"preemphasis",
"deemphasis",
"frequency_set",
"relative_bandwidths",
"wavelet_lengths",
"wavelet_fbank",
]


Expand Down Expand Up @@ -2533,3 +2537,158 @@ def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
b = sigma_x.trace() + sigma_y.trace()
c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum()
return a + b - 2 * c


def frequency_set(f_min: float, n_bins: int, bins_per_octave: int, dtype: torch.dtype) -> Tuple[Tensor, int]:
    r"""Return a set of frequencies that assumes an equal temperament tuning system.

    .. devices:: CPU

    Adapted from librosa: https://librosa.org/doc/main/generated/librosa.interval_frequencies.html

    Args:
        f_min (float): minimum frequency in Hz. Must be strictly positive.
        n_bins (int): number of frequency bins. Must be at least 1.
        bins_per_octave (int): number of bins per octave. Must be at least 1.
        dtype (torch.dtype): dtype used to build the bin indices from which
            the frequency ratios are computed.

    Returns:
        torch.Tensor: ``n_bins`` frequencies, ``f_min * 2 ** (k / bins_per_octave)`` for ``k = 0..n_bins - 1``.
        int: number of octaves, i.e. ``ceil(n_bins / bins_per_octave)``.

    Raises:
        ValueError: if ``f_min`` is not strictly positive, or if ``n_bins`` or
            ``bins_per_octave`` is smaller than 1.
    """
    # A zero minimum frequency would produce an all-zero frequency set, which is
    # meaningless on a logarithmic axis, so it is rejected along with negatives.
    if f_min <= 0.0 or n_bins < 1 or bins_per_octave < 1:
        raise ValueError("f_min must be positive. n_bins and bins_per_octave must be positive ints.")

    n_octaves = math.ceil(n_bins / bins_per_octave)

    # Only the first n_bins ratios are ever returned, so build exactly those.
    ratios = 2.0 ** (torch.arange(0, n_bins, dtype=dtype) / bins_per_octave)
    return f_min * ratios, n_octaves


def relative_bandwidths(freqs: Tensor, n_bins: int, bins_per_octave: int) -> Tensor:
    r"""Compute relative bandwidths for specified frequencies.

    .. devices:: CPU

    Adapted from librosa: https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html

    Args:
        freqs (Tensor): set of frequencies.
        n_bins (int): number of frequency bins.
        bins_per_octave (int): number of bins per octave. Only used when
            ``n_bins`` is 1.

    Returns:
        torch.Tensor: relative bandwidths for set of frequencies.

    Raises:
        ValueError: if ``freqs`` is empty or contains a negative value, or if
            ``n_bins`` or ``bins_per_octave`` is smaller than 1.
    """
    if freqs.numel() == 0:
        raise ValueError("freqs must not be empty.")

    # Vectorized check instead of iterating over the tensor with builtin min().
    if bool((freqs < 0.0).any()) or n_bins < 1 or bins_per_octave < 1:
        raise ValueError("freqs must be positive. n_bins and bins_per_octave must be positive ints.")

    if n_bins > 1:
        # Approximate local octave resolution around each frequency
        log_freqs = torch.log2(freqs)
        bandpass_octave = torch.empty_like(freqs)

        # Reflect at the lowest and highest frequencies
        bandpass_octave[0] = 1 / (log_freqs[1] - log_freqs[0])
        bandpass_octave[-1] = 1 / (log_freqs[-1] - log_freqs[-2])

        # Centered difference
        bandpass_octave[1:-1] = 2 / (log_freqs[2:] - log_freqs[:-2])

        # Relative bandwidths; the power term is shared by numerator and denominator
        spread = 2.0 ** (2 / bandpass_octave)
        return (spread - 1) / (spread + 1)

    # Special case when single basis frequency is used: no neighbours to
    # difference against, so fall back to the nominal bins_per_octave spacing.
    rel_band_coeff = 2.0 ** (1.0 / bins_per_octave)
    return torch.tensor(
        [(rel_band_coeff**2 - 1) / (rel_band_coeff**2 + 1)],
        dtype=freqs.dtype,
        device=freqs.device,  # keep the result next to the input frequencies
    )


def wavelet_lengths(freqs: Tensor, sr: float, alpha: Tensor, gamma: float) -> Tuple[Tensor, float]:
    r"""Length of each filter in a wavelet basis.

    .. devices:: CPU

    Source:
        * https://librosa.org/doc/main/generated/librosa.filters.wavelet_lengths.html

    Args:
        freqs (Tensor): set of frequencies.
        sr (float): sample rate.
        alpha (Tensor): relative bandwidths for set of frequencies.
        gamma (float): bandwidth offset for filter length computation.

    Returns:
        Tensor: filter lengths (fractional, in samples), one per frequency.
        float: cutoff frequency of highest bin.

    Raises:
        ValueError: if ``gamma`` or ``sr`` is negative, if ``freqs`` or
            ``alpha`` is empty, or if either contains a negative value.
    """
    if gamma < 0.0 or sr < 0.0:
        raise ValueError("gamma and sr must be positive!")

    if freqs.numel() == 0 or alpha.numel() == 0:
        raise ValueError("freqs and alpha must not be empty!")

    # Vectorized checks instead of iterating over the tensors with builtin min().
    if bool((freqs < 0.0).any()) or bool((alpha < 0.0).any()):
        raise ValueError("freqs and alpha must be positive!")

    # We assume filter_scale (librosa param) is 1
    Q = 1.0 / alpha

    # Output upper bound cutoff frequency
    # 3.0 > all common window function bandwidths
    # https://librosa.org/doc/main/_modules/librosa/filters.html
    # Converted to a Python float so the return value matches the annotation.
    cutoff_freq = float(torch.max(freqs * (1 + 0.5 * 3.0 / Q) + 0.5 * gamma))

    # Convert frequencies to filter lengths
    lengths = Q * sr / (freqs + gamma / alpha)

    return lengths, cutoff_freq


def wavelet_fbank(
    freqs: Tensor, sr: float, alpha: Tensor, gamma: float, window_fn: Callable[..., Tensor], dtype: torch.dtype,
) -> Tuple[Tensor, Tensor]:
    r"""Wavelet filterbank constructed from set of center frequencies.

    .. devices:: CPU

    Source:
        * https://librosa.org/doc/main/generated/librosa.filters.wavelet.html

    Args:
        freqs (Tensor): set of frequencies.
        sr (float): sample rate.
        alpha (Tensor): relative bandwidths for set of frequencies.
        gamma (float): bandwidth offset for filter length computation.
        window_fn (Callable[..., Tensor]): a function to create a window tensor.
            It is called with the filter length as its only argument.
        dtype (torch.dtype): dtype of the time axis from which each complex
            filter is generated.

    Returns:
        Tensor: wavelet filters, one row per frequency, each zero-padded to a
            common power-of-2 length.
        Tensor: wavelet filter lengths.
    """
    # First get filter lengths
    lengths, _ = wavelet_lengths(freqs=freqs, sr=sr, alpha=alpha, gamma=gamma)

    # Next power of 2 that can hold the longest filter. The fractional length
    # is rounded UP first: truncating it could pick a size smaller than the
    # filter actually built, which would turn the padding below into cropping.
    longest = math.ceil(float(lengths.max()))
    pad_to_size = 1 << (longest - 1).bit_length()

    rows = []
    for ilen, freq in zip(lengths, freqs):
        # Complex exponential at the centre frequency, sampled on a window
        # centred on zero (at most ceil(ilen) samples)
        t = torch.arange(-ilen // 2, ilen // 2, dtype=dtype) * 2 * torch.pi * freq / sr
        sig = torch.cos(t) + 1j * torch.sin(t)

        # Multiply with window
        sig_len = len(sig)
        sig = sig * window_fn(sig_len)

        # L1 normalize
        sig = torch.nn.functional.normalize(sig, p=1.0, dim=0)

        # Pad signal left and right to correct size; any odd leftover sample
        # goes to the right-hand side
        slack = pad_to_size - sig_len
        l_pad = slack // 2
        r_pad = slack - l_pad
        rows.append(torch.nn.functional.pad(sig, (l_pad, r_pad), mode="constant", value=0.0))

    # Assemble once instead of re-concatenating the growing bank per filter
    return torch.stack(rows, dim=0), lengths
6 changes: 6 additions & 0 deletions src/torchaudio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
AmplitudeToDB,
ComputeDeltas,
Convolve,
CQT,
Deemphasis,
Fade,
FFTConvolve,
FrequencyMasking,
GriffinLim,
InverseCQT,
InverseMelScale,
InverseSpectrogram,
LFCC,
Expand All @@ -32,6 +34,7 @@
TimeStretch,
Vad,
Vol,
VQT,
)


Expand All @@ -40,11 +43,13 @@
"AmplitudeToDB",
"ComputeDeltas",
"Convolve",
"CQT",
"Deemphasis",
"Fade",
"FFTConvolve",
"FrequencyMasking",
"GriffinLim",
"InverseCQT",
"InverseMelScale",
"InverseSpectrogram",
"LFCC",
Expand Down Expand Up @@ -72,4 +77,5 @@
"TimeStretch",
"Vad",
"Vol",
"VQT",
]
Loading