Skip to content

Commit

Permalink
Support compression level in i/o dispatcher backend (pytorch#3662)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#3662

Differential Revision: D50367721
  • Loading branch information
hwangjeff authored and facebook-github-bot committed Oct 18, 2023
1 parent 671261c commit 0883323
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/torchaudio/_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import BinaryIO, Optional, Tuple, Union

from torch import Tensor
from torchaudio.io import CodecConfig

from .common import AudioMetaData

Expand Down Expand Up @@ -37,6 +38,7 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float]] = None,
) -> None:
raise NotImplementedError

Expand Down
9 changes: 9 additions & 0 deletions src/torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def save_audio(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
Expand All @@ -250,6 +251,7 @@ def save_audio(
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
Expand Down Expand Up @@ -304,7 +306,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, None)):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
Expand All @@ -314,6 +322,7 @@ def save(
encoding,
bits_per_sample,
buffer_size,
compression,
)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions src/torchaudio/_backend/soundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import BinaryIO, Optional, Tuple, Union

import torch
from torchaudio.io import CodecConfig

from . import soundfile_backend
from .backend import Backend
Expand Down Expand Up @@ -35,7 +36,11 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float]] = None,
) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")

soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
)
Expand Down
8 changes: 7 additions & 1 deletion src/torchaudio/_backend/sox.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def save(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float]] = None,
) -> None:
if not isinstance(compression, (float, None)):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"):
raise ValueError(
"SoX backend does not support writing to file-like objects. ",
Expand All @@ -68,7 +74,7 @@ def save(
src,
sample_rate,
channels_first,
None,
compression,
format,
encoding,
bits_per_sample,
Expand Down
9 changes: 8 additions & 1 deletion src/torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from torchaudio.io import CodecConfig

from . import soundfile_backend

Expand Down Expand Up @@ -229,6 +230,7 @@ def save(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float]] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -283,8 +285,13 @@ def save(
.. seealso::
:ref:`backend`
compression (CodecConfig, float, or None, optional):
To fill in.
"""
backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size)
return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)

return save

0 comments on commit 0883323

Please sign in to comment.