Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support compression level in i/o dispatcher backend #3662

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/torchaudio/_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import BinaryIO, Optional, Tuple, Union

from torch import Tensor
from torchaudio.io import CodecConfig

from .common import AudioMetaData

Expand Down Expand Up @@ -37,6 +38,7 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
raise NotImplementedError

Expand Down
9 changes: 9 additions & 0 deletions src/torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def save_audio(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
Expand All @@ -250,6 +251,7 @@ def save_audio(
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
Expand Down Expand Up @@ -304,7 +306,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
Expand All @@ -314,6 +322,7 @@ def save(
encoding,
bits_per_sample,
buffer_size,
compression,
)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions src/torchaudio/_backend/soundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import BinaryIO, Optional, Tuple, Union

import torch
from torchaudio.io import CodecConfig

from . import soundfile_backend
from .backend import Backend
Expand Down Expand Up @@ -35,7 +36,11 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")

soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
)
Expand Down
8 changes: 7 additions & 1 deletion src/torchaudio/_backend/sox.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (float, int, type(None))):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float` or `int`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"):
raise ValueError(
"SoX backend does not support writing to file-like objects. ",
Expand All @@ -68,7 +74,7 @@ def save(
src,
sample_rate,
channels_first,
None,
compression,
format,
encoding,
bits_per_sample,
Expand Down
28 changes: 27 additions & 1 deletion src/torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from torchaudio.io import CodecConfig

from . import soundfile_backend

Expand Down Expand Up @@ -229,6 +230,7 @@ def save(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float, int]] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -283,8 +285,32 @@ def save(

.. seealso::
:ref:`backend`

compression (CodecConfig, float, int, or None, optional):
Compression configuration to apply.

If the selected backend is FFmpeg, an instance of :py:class:`CodecConfig` must be provided.

Otherwise, if the selected backend is SoX, a float or int value corresponding to option ``-C`` of the
``sox`` command line interface must be provided. For instance:

``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.

``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.

``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.

Refer to http://sox.sourceforge.net/soxformat.html for more details.

"""
backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size)
return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)

return save
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_save(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None)

@parameterized.expand(
[
Expand All @@ -126,4 +126,4 @@ def test_save_fileobj(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None)
52 changes: 47 additions & 5 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import subprocess
import sys
from functools import partial
from typing import Optional

import torch
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io import CodecConfig

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand Down Expand Up @@ -45,6 +47,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: Optional[CodecConfig] = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -104,14 +107,23 @@ def assert_save_consistency(
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -123,6 +135,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand Down Expand Up @@ -198,11 +211,27 @@ def test_save_wav_dtype(self, test_mode, params):
# NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24],
[16, 24],
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample):
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
# -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode)
codec_config = CodecConfig(
compression_level=compression_level,
)
self.assert_save_consistency(
"flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand All @@ -212,12 +241,25 @@ def test_save_flac(self, test_mode, bits_per_sample):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)

@nested_params(
[
None,
-1,
0,
1,
2,
3,
5,
10,
],
["path", "fileobj", "bytesio"],
)
def test_save_vorbis(self, test_mode):
def test_save_vorbis(self, quality_level, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode)
codec_config = CodecConfig(
qscale=quality_level,
)
self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand Down
63 changes: 55 additions & 8 deletions test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -101,13 +102,16 @@ def assert_save_consistency(
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -118,6 +122,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -134,7 +139,9 @@ def assert_save_consistency(

# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
Expand Down Expand Up @@ -175,15 +182,42 @@ def test_save_wav_dtype(self, params):

@nested_params(
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path"
)

def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1)

def test_save_vorbis(self):
self.assert_save_consistency("vorbis", test_mode="path")
@nested_params(
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")

@nested_params(
[
Expand Down Expand Up @@ -254,9 +288,22 @@ def test_save_amb(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")

@nested_params(
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
@skipIfNoSoxEncoder("amr-nb")
def test_save_amr_nb(self):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path")
def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")

def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
Expand Down
Loading