Skip to content

Commit

Permalink
Support compression level in i/o dispatcher backend
Browse files Browse the repository at this point in the history
Differential Revision: D50367721

Pull Request resolved: pytorch#3662
  • Loading branch information
hwangjeff authored and mthrok committed Oct 19, 2023
1 parent e331c1a commit 1db768f
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_save(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None)

@parameterized.expand(
[
Expand All @@ -126,4 +126,4 @@ def test_save_fileobj(self, available_backends, expected_backend):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None)
52 changes: 47 additions & 5 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import subprocess
import sys
from functools import partial
from typing import Optional

import torch
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io import CodecConfig

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand Down Expand Up @@ -45,6 +47,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: Optional[CodecConfig] = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -104,14 +107,23 @@ def assert_save_consistency(
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -123,6 +135,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand Down Expand Up @@ -198,11 +211,27 @@ def test_save_wav_dtype(self, test_mode, params):
# NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24],
[16, 24],
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample):
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
# -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode)
codec_config = CodecConfig(
compression_level=compression_level,
)
self.assert_save_consistency(
"flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand All @@ -212,12 +241,25 @@ def test_save_flac(self, test_mode, bits_per_sample):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)

@nested_params(
[
None,
-1,
0,
1,
2,
3,
5,
10,
],
["path", "fileobj", "bytesio"],
)
def test_save_vorbis(self, test_mode):
def test_save_vorbis(self, quality_level, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode)
codec_config = CodecConfig(
qscale=quality_level,
)
self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)

# @nested_params(
# ["path", "fileobj", "bytesio"],
Expand Down
63 changes: 55 additions & 8 deletions test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
Expand Down Expand Up @@ -101,13 +102,16 @@ def assert_save_consistency(
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -118,6 +122,7 @@ def assert_save_consistency(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
Expand All @@ -134,7 +139,9 @@ def assert_save_consistency(

# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
Expand Down Expand Up @@ -175,15 +182,42 @@ def test_save_wav_dtype(self, params):

@nested_params(
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path"
)

def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1)

def test_save_vorbis(self):
self.assert_save_consistency("vorbis", test_mode="path")
@nested_params(
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")

@nested_params(
[
Expand Down Expand Up @@ -254,9 +288,22 @@ def test_save_amb(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")

@nested_params(
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
@skipIfNoSoxEncoder("amr-nb")
def test_save_amr_nb(self):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path")
def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")

def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import BinaryIO, Optional, Tuple, Union

from torch import Tensor
from torchaudio.io import CodecConfig

from .common import AudioMetaData

Expand Down Expand Up @@ -37,6 +38,7 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
raise NotImplementedError

Expand Down
9 changes: 9 additions & 0 deletions torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def save_audio(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
Expand All @@ -275,6 +276,7 @@ def save_audio(
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
Expand Down Expand Up @@ -343,7 +345,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
Expand All @@ -353,6 +361,7 @@ def save(
encoding,
bits_per_sample,
buffer_size,
compression,
)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions torchaudio/_backend/soundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import BinaryIO, Optional, Tuple, Union

import torch
from torchaudio.io import CodecConfig

from . import soundfile_backend
from .backend import Backend
Expand Down Expand Up @@ -35,7 +36,11 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")

soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
)
Expand Down
8 changes: 7 additions & 1 deletion torchaudio/_backend/sox.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (float, int, type(None))):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float` or `int`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"):
raise ValueError(
"SoX backend does not support writing to file-like objects. ",
Expand All @@ -67,7 +73,7 @@ def save(
src,
sample_rate,
channels_first,
None,
compression,
format,
encoding,
bits_per_sample,
Expand Down
28 changes: 27 additions & 1 deletion torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchaudio._extension import _FFMPEG_EXT, _SOX_INITIALIZED
from torchaudio.io import CodecConfig

from . import soundfile_backend

Expand Down Expand Up @@ -229,6 +230,7 @@ def save(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float, int]] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -283,8 +285,32 @@ def save(
.. seealso::
:ref:`backend`
compression (CodecConfig, float, int, or None, optional):
Compression configuration to apply.
If the selected backend is FFmpeg, a non-``None`` value must be an instance of :py:class:`CodecConfig`.
The soundfile backend does not support this argument.
If the selected backend is SoX, a non-``None`` value must be a float or int corresponding to option ``-C``
of the ``sox`` command line interface. For instance:
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
Refer to http://sox.sourceforge.net/soxformat.html for more details.
"""
backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size)
return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)

return save

0 comments on commit 1db768f

Please sign in to comment.