From c470588a36e48e5feda600b382a0934d7c01aada Mon Sep 17 00:00:00 2001 From: Mark Kremer Date: Sat, 3 Aug 2024 19:45:26 +0200 Subject: [PATCH] Rewrite flac.Stream() to load the samples directly from the frame --- flac/decode.go | 133 ++++++++++++++++++++++---------------------- flac/decode_test.go | 24 +++++++- 2 files changed, 88 insertions(+), 69 deletions(-) diff --git a/flac/decode.go b/flac/decode.go index f65aeea..2c82619 100644 --- a/flac/decode.go +++ b/flac/decode.go @@ -5,6 +5,7 @@ import ( "io" "github.com/mewkiz/flac" + "github.com/mewkiz/flac/frame" "github.com/pkg/errors" "github.com/gopxl/beep/v2" @@ -32,10 +33,16 @@ func Decode(r io.Reader) (s beep.StreamSeekCloser, format beep.Format, err error } else { d.stream, err = flac.New(r) } + if err != nil { + return nil, beep.Format{}, errors.Wrap(err, "flac") + } + // Read the first frame + d.frame, err = d.stream.ParseNext() if err != nil { return nil, beep.Format{}, errors.Wrap(err, "flac") } + format = beep.Format{ SampleRate: beep.SampleRate(d.stream.Info.SampleRate), NumChannels: int(d.stream.Info.NChannels), @@ -47,96 +54,86 @@ func Decode(r io.Reader) (s beep.StreamSeekCloser, format beep.Format, err error type decoder struct { r io.Reader stream *flac.Stream - buf [][2]float64 - pos int + frame *frame.Frame + posInFrame int err error seekEnabled bool } func (d *decoder) Stream(samples [][2]float64) (n int, ok bool) { - if d.err != nil { + if d.err != nil || d.frame == nil { return 0, false } - // Copy samples from buffer. - j := 0 - for i := range samples { - if j >= len(d.buf) { - // refill buffer. 
- if err := d.refill(); err != nil { - d.pos += n + + for len(samples) > 0 { + samplesLeft := int(d.frame.BlockSize) - d.posInFrame + if samplesLeft <= 0 { + // Read next frame + var err error + d.frame, err = d.stream.ParseNext() + if err != nil { + d.frame = nil if err == io.EOF { return n, n > 0 } - d.err = err + d.err = errors.Wrap(err, "flac") return 0, false } - j = 0 + d.posInFrame = 0 + continue } - samples[i] = d.buf[j] - j++ - n++ + + toFill := min(samplesLeft, len(samples)) + d.decodeFrameRangeInto(d.frame, d.posInFrame, toFill, samples) + d.posInFrame += toFill + n += toFill + samples = samples[toFill:] } - d.buf = d.buf[j:] - d.pos += n + return n, true } -// refill decodes audio samples to fill the decode buffer. -func (d *decoder) refill() error { - // Empty buffer. - d.buf = d.buf[:0] - // Parse audio frame. - frame, err := d.stream.ParseNext() - if err != nil { - return err - } - // Expand buffer size if needed. - n := len(frame.Subframes[0].Samples) - if cap(d.buf) < n { - d.buf = make([][2]float64, n) - } else { - d.buf = d.buf[:n] - } - // Decode audio samples. +// decodeFrameRangeInto decodes the frame's samples from position `start` up to (but not including) `start + num` +// and stores them in Beep's format into the provided slice `into`. 
+func (d *decoder) decodeFrameRangeInto(frame *frame.Frame, start, num int, into [][2]float64) { bps := d.stream.Info.BitsPerSample - nchannels := d.stream.Info.NChannels + numChannels := d.stream.Info.NChannels s := 1 << (bps - 1) q := 1 / float64(s) switch { - case bps == 8 && nchannels == 1: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(int8(frame.Subframes[0].Samples[i])) * q - d.buf[i][1] = float64(int8(frame.Subframes[0].Samples[i])) * q + case bps == 8 && numChannels == 1: + for i := 0; i < num; i++ { + into[i][0] = float64(int8(frame.Subframes[0].Samples[start+i])) * q + into[i][1] = into[i][0] } - case bps == 16 && nchannels == 1: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(int16(frame.Subframes[0].Samples[i])) * q - d.buf[i][1] = float64(int16(frame.Subframes[0].Samples[i])) * q + case bps == 16 && numChannels == 1: + for i := 0; i < num; i++ { + into[i][0] = float64(int16(frame.Subframes[0].Samples[start+i])) * q + into[i][1] = into[i][0] } - case bps == 24 && nchannels == 1: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(int32(frame.Subframes[0].Samples[i])) * q - d.buf[i][1] = float64(int32(frame.Subframes[0].Samples[i])) * q + case bps == 24 && numChannels == 1: + for i := 0; i < num; i++ { + into[i][0] = float64(int32(frame.Subframes[0].Samples[start+i])) * q + into[i][1] = into[i][0] } - case bps == 8 && nchannels >= 2: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(int8(frame.Subframes[0].Samples[i])) * q - d.buf[i][1] = float64(int8(frame.Subframes[1].Samples[i])) * q + case bps == 8 && numChannels >= 2: + for i := 0; i < num; i++ { + into[i][0] = float64(int8(frame.Subframes[0].Samples[start+i])) * q + into[i][1] = float64(int8(frame.Subframes[1].Samples[start+i])) * q } - case bps == 16 && nchannels >= 2: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(int16(frame.Subframes[0].Samples[i])) * q - d.buf[i][1] = float64(int16(frame.Subframes[1].Samples[i])) * q + case bps == 16 && numChannels >= 2: + for i := 0; i < num; i++ 
{ + into[i][0] = float64(int16(frame.Subframes[0].Samples[start+i])) * q + into[i][1] = float64(int16(frame.Subframes[1].Samples[start+i])) * q } - case bps == 24 && nchannels >= 2: - for i := 0; i < n; i++ { - d.buf[i][0] = float64(frame.Subframes[0].Samples[i]) * q - d.buf[i][1] = float64(frame.Subframes[1].Samples[i]) * q + case bps == 24 && numChannels >= 2: + for i := 0; i < num; i++ { + into[i][0] = float64(frame.Subframes[0].Samples[start+i]) * q + into[i][1] = float64(frame.Subframes[1].Samples[start+i]) * q } default: - panic(fmt.Errorf("support for %d bits-per-sample and %d channels combination not yet implemented", bps, nchannels)) + panic(fmt.Errorf("flac: support for %d bits-per-sample and %d channels combination not yet implemented", bps, numChannels)) } - return nil } func (d *decoder) Err() error { @@ -148,7 +145,7 @@ func (d *decoder) Len() int { } func (d *decoder) Position() int { - return d.pos + return int(d.frame.SampleNumber()) + d.posInFrame } func (d *decoder) Seek(p int) error { @@ -156,19 +153,19 @@ func (d *decoder) Seek(p int) error { return errors.New("flac.decoder.Seek: not enabled") } + // d.stream.Seek() doesn't seek to the exact position p, instead + // it seeks to the start of the frame p is in. The frame position + // is returned and stored in pos. 
pos, err := d.stream.Seek(uint64(p)) if err != nil { return errors.Wrap(err, "flac") } + d.posInFrame = p - int(pos) - toDiscard := p - int(pos) - err = d.refill() + d.frame, err = d.stream.ParseNext() if err != nil { - return err + return errors.Wrap(err, "flac") } - d.buf = d.buf[toDiscard:] - - d.pos = p return err } diff --git a/flac/decode_test.go b/flac/decode_test.go index 16a1699..c98c16f 100644 --- a/flac/decode_test.go +++ b/flac/decode_test.go @@ -1,6 +1,7 @@ package flac_test import ( + "bytes" "io" "log" "os" @@ -44,7 +45,8 @@ func TestDecoder_Stream(t *testing.T) { wavStream, _, err := wav.Decode(wavFile) assert.NoError(t, err) - assert.Equal(t, wavStream.Len(), flacStream.Len()) + assert.Equal(t, 22050, wavStream.Len()) + assert.Equal(t, 22050, flacStream.Len()) wavSamples := testtools.Collect(wavStream) flacSamples := testtools.Collect(flacStream) @@ -139,3 +141,23 @@ func getFlacFrameStartPositions(r io.Reader) ([]uint64, error) { return frameStarts, nil } + +func BenchmarkDecoder_Stream(b *testing.B) { + // Load the file into memory, so the disk performance doesn't impact the benchmark. + data, err := os.ReadFile(testtools.TestFilePath("valid_44100hz_22050_samples_ffmpeg.flac")) + assert.NoError(b, err) + + r := bytes.NewReader(data) + + b.Run("test", func(b *testing.B) { + s, _, err := flac.Decode(r) + assert.NoError(b, err) + + samples := testtools.Collect(s) + assert.Equal(b, 22050, len(samples)) + + // Reset for next run. + _, err = r.Seek(0, io.SeekStart) + assert.NoError(b, err) + }) +}