diff --git a/src/torchaudio/_backend/backend.py b/src/torchaudio/_backend/backend.py index da8d1bb7bc..579340962c 100644 --- a/src/torchaudio/_backend/backend.py +++ b/src/torchaudio/_backend/backend.py @@ -3,6 +3,7 @@ from typing import BinaryIO, Optional, Tuple, Union from torch import Tensor +from torchaudio.io import CodecConfig from .common import AudioMetaData @@ -37,6 +38,7 @@ def save( encoding: Optional[str] = None, bits_per_sample: Optional[int] = None, buffer_size: int = 4096, + compression: Optional[Union[CodecConfig, float, int]] = None, ) -> None: raise NotImplementedError diff --git a/src/torchaudio/_backend/ffmpeg.py b/src/torchaudio/_backend/ffmpeg.py index 0bcf6ee4bd..ca8374ea07 100644 --- a/src/torchaudio/_backend/ffmpeg.py +++ b/src/torchaudio/_backend/ffmpeg.py @@ -228,6 +228,7 @@ def save_audio( encoding: Optional[str] = None, bits_per_sample: Optional[int] = None, buffer_size: int = 4096, + compression: Optional[torchaudio.io.CodecConfig] = None, ) -> None: ext = None if hasattr(uri, "write"): @@ -250,6 +251,7 @@ def save_audio( format=_get_sample_format(src.dtype), encoder=encoder, encoder_format=enc_fmt, + codec_config=compression, ) with s.open(): s.write_audio_chunk(0, src) @@ -304,7 +306,13 @@ def save( encoding: Optional[str] = None, bits_per_sample: Optional[int] = None, buffer_size: int = 4096, + compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None, ) -> None: + if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))): + raise ValueError( + "FFmpeg backend expects non-`None` value for argument `compression` to be of ", + f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}", + ) save_audio( uri, src, @@ -314,6 +322,7 @@ def save( encoding, bits_per_sample, buffer_size, + compression, ) @staticmethod diff --git a/src/torchaudio/_backend/soundfile.py b/src/torchaudio/_backend/soundfile.py index 701e471c98..f4be1f7099 100644 --- a/src/torchaudio/_backend/soundfile.py +++ b/src/torchaudio/_backend/soundfile.py @@ -2,6 +2,7 @@ from typing import BinaryIO, Optional, Tuple, Union import torch +from torchaudio.io import CodecConfig from . import soundfile_backend from .backend import Backend @@ -35,7 +36,11 @@ def save( encoding: Optional[str] = None, bits_per_sample: Optional[int] = None, buffer_size: int = 4096, + compression: Optional[Union[CodecConfig, float, int]] = None, ) -> None: + if compression: + raise ValueError("soundfile backend does not support argument `compression`.") + soundfile_backend.save( uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample ) diff --git a/src/torchaudio/_backend/sox.py b/src/torchaudio/_backend/sox.py index 592890c95c..bfcd8a4f8b 100644 --- a/src/torchaudio/_backend/sox.py +++ b/src/torchaudio/_backend/sox.py @@ -56,7 +56,13 @@ def save( encoding: Optional[str] = None, bits_per_sample: Optional[int] = None, buffer_size: int = 4096, + compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None, ) -> None: + if not isinstance(compression, (float, int, type(None))): + raise ValueError( + "SoX backend expects non-`None` value for argument `compression` to be of ", + f"type `float` or `int`, but received value of type {type(compression)}", + ) if hasattr(uri, "write"): raise ValueError( "SoX backend does not support writing to file-like objects. ", @@ -68,7 +74,7 @@ def save( src, sample_rate, channels_first, - None, + compression, format, encoding, bits_per_sample, diff --git a/src/torchaudio/_backend/utils.py b/src/torchaudio/_backend/utils.py index 36cd5f11b1..96f40b6ba6 100644 --- a/src/torchaudio/_backend/utils.py +++ b/src/torchaudio/_backend/utils.py @@ -5,6 +5,7 @@ import torch from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext +from torchaudio.io import CodecConfig from . import soundfile_backend @@ -229,6 +230,7 @@ def save( bits_per_sample: Optional[int] = None, buffer_size: int = 4096, backend: Optional[str] = None, + compression: Optional[Union[CodecConfig, float, int]] = None, ): """Save audio data to file. @@ -283,8 +285,32 @@ def save( .. seealso:: :ref:`backend` + + compression (CodecConfig, float, int, or None, optional): + Compression configuration to apply. + + If the selected backend is FFmpeg, an instance of :py:class:`CodecConfig` must be provided. + + Otherwise, if the selected backend is SoX, a float or int value corresponding to option ``-C`` of the + ``sox`` command line interface must be provided. For instance: + + ``"mp3"`` + Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or + VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. + + ``"flac"`` + Whole number from ``0`` to ``8``. ``8`` is default and highest compression. + + ``"ogg"``, ``"vorbis"`` + Number from ``-1`` to ``10``; ``-1`` is the highest compression + and lowest quality. Default: ``3``. + + Refer to http://sox.sourceforge.net/soxformat.html for more details. + """ backend = dispatcher(uri, format, backend) - return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size) + return backend.save( + uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression + ) return save diff --git a/test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py b/test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py index 331cc856ad..8c473a0b1e 100644 --- a/test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py @@ -105,7 +105,7 @@ def test_save(self, available_backends, expected_backend): f"torchaudio._backend.utils.{expected_backend.__name__}.save" ) as mock_save: get_save_func()(filename, src, sample_rate, format=format) - mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096) + mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None) @parameterized.expand( [ @@ -126,4 +126,4 @@ def test_save_fileobj(self, available_backends, expected_backend): f"torchaudio._backend.utils.{expected_backend.__name__}.save" ) as mock_save: get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size) - mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size) + mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None) diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py index 4e18911eeb..3fd9b70319 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py @@ -4,11 +4,13 @@ import subprocess import sys from functools import partial +from typing import Optional import torch from parameterized import parameterized from torchaudio._backend.ffmpeg import _parse_save_args from torchaudio._backend.utils import get_save_func +from torchaudio.io import CodecConfig from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func from torchaudio_unittest.common_utils import ( @@ -45,6 +47,7 @@ def assert_save_consistency( self, format: str, *, + compression: Optional[CodecConfig] = None, encoding: str = None, bits_per_sample: int = None, sample_rate: float = 8000, @@ -104,7 +107,15 @@ def assert_save_consistency( data = load_wav(src_path, normalize=False)[0] if test_mode == "path": ext = format - self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample) + self._save( + tgt_path, + data, + sample_rate, + compression=compression, + format=format, + encoding=encoding, + bits_per_sample=bits_per_sample, + ) elif test_mode == "fileobj": ext = None with open(tgt_path, "bw") as file_: @@ -112,6 +123,7 @@ def assert_save_consistency( file_, data, sample_rate, + compression=compression, format=format, encoding=encoding, bits_per_sample=bits_per_sample, @@ -123,6 +135,7 @@ def assert_save_consistency( file_, data, sample_rate, + compression=compression, format=format, encoding=encoding, bits_per_sample=bits_per_sample, @@ -198,11 +211,27 @@ def test_save_wav_dtype(self, test_mode, params): # NOTE: Supported sample formats: s16 s32 (24 bits) # [8, 16, 24], [16, 24], + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ], ) - def test_save_flac(self, test_mode, bits_per_sample): + def test_save_flac(self, test_mode, bits_per_sample, compression_level): # -acodec flac -sample_fmt s16 # 24 bits needs to be mapped to s32 - self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode) + codec_config = CodecConfig( + compression_level=compression_level, + ) + self.assert_save_consistency( + "flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode + ) # @nested_params( # ["path", "fileobj", "bytesio"], @@ -212,12 +241,25 @@ def test_save_flac(self, test_mode, bits_per_sample): # self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1) @nested_params( + [ + None, + -1, + 0, + 1, + 2, + 3, + 5, + 10, + ], ["path", "fileobj", "bytesio"], ) - def test_save_vorbis(self, test_mode): + def test_save_vorbis(self, quality_level, test_mode): # NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg" # self.assert_save_consistency("vorbis", test_mode=test_mode) - self.assert_save_consistency("ogg", test_mode=test_mode) + codec_config = CodecConfig( + qscale=quality_level, + ) + self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode) # @nested_params( # ["path", "fileobj", "bytesio"], diff --git a/test/torchaudio_unittest/backend/dispatcher/sox/save_test.py b/test/torchaudio_unittest/backend/dispatcher/sox/save_test.py index c907541034..ec52e6eda3 100644 --- a/test/torchaudio_unittest/backend/dispatcher/sox/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/sox/save_test.py @@ -40,6 +40,7 @@ def assert_save_consistency( self, format: str, *, + compression: float = None, encoding: str = None, bits_per_sample: int = None, sample_rate: float = 8000, @@ -101,13 +102,16 @@ def assert_save_consistency( # 2.1. Convert the original wav to target format with torchaudio data = load_wav(src_path, normalize=False)[0] if test_mode == "path": - self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample) + self._save( + tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample + ) elif test_mode == "fileobj": with open(tgt_path, "bw") as file_: self._save( file_, data, sample_rate, + compression=compression, format=format, encoding=encoding, bits_per_sample=bits_per_sample, @@ -118,6 +122,7 @@ def assert_save_consistency( file_, data, sample_rate, + compression=compression, format=format, encoding=encoding, bits_per_sample=bits_per_sample, @@ -134,7 +139,9 @@ def assert_save_consistency( # 3.1. Convert the original wav to target format with sox sox_encoding = _get_sox_encoding(encoding) - sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample) + sox_utils.convert_audio_file( + src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample + ) # 3.2. Convert the target format to wav with sox sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) # 3.3. Load with SciPy @@ -175,15 +182,42 @@ def test_save_wav_dtype(self, params): @nested_params( [8, 16, 24], + [ + None, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ], ) - def test_save_flac(self, bits_per_sample): - self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path") + def test_save_flac(self, bits_per_sample, compression_level): + self.assert_save_consistency( + "flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path" + ) def test_save_htk(self): self.assert_save_consistency("htk", test_mode="path", num_channels=1) - def test_save_vorbis(self): - self.assert_save_consistency("vorbis", test_mode="path") + @nested_params( + [ + None, + -1, + 0, + 1, + 2, + 3, + 3.6, + 5, + 10, + ], + ) + def test_save_vorbis(self, quality_level): + self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path") @nested_params( [ @@ -254,9 +288,22 @@ def test_save_amb(self, enc_params): encoding, bits_per_sample = enc_params self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path") + @nested_params( + [ + None, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ], + ) @skipIfNoSoxEncoder("amr-nb") - def test_save_amr_nb(self): - self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path") + def test_save_amr_nb(self, bit_rate): + self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path") def test_save_gsm(self): self.assert_save_consistency("gsm", num_channels=1, test_mode="path")