Skip to content

Commit

Permalink
Rewrite flac.Stream() to load the samples directly from the frame
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkKremer committed Aug 3, 2024
1 parent cb5e727 commit c470588
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 69 deletions.
133 changes: 65 additions & 68 deletions flac/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"

"github.com/mewkiz/flac"
"github.com/mewkiz/flac/frame"
"github.com/pkg/errors"

"github.com/gopxl/beep/v2"
Expand Down Expand Up @@ -32,10 +33,16 @@ func Decode(r io.Reader) (s beep.StreamSeekCloser, format beep.Format, err error
} else {
d.stream, err = flac.New(r)
}
if err != nil {
return nil, beep.Format{}, errors.Wrap(err, "flac")
}

// Read the first frame
d.frame, err = d.stream.ParseNext()
if err != nil {
return nil, beep.Format{}, errors.Wrap(err, "flac")
}

format = beep.Format{
SampleRate: beep.SampleRate(d.stream.Info.SampleRate),
NumChannels: int(d.stream.Info.NChannels),
Expand All @@ -47,96 +54,86 @@ func Decode(r io.Reader) (s beep.StreamSeekCloser, format beep.Format, err error
type decoder struct {
r io.Reader
stream *flac.Stream
buf [][2]float64
pos int
frame *frame.Frame
posInFrame int
err error
seekEnabled bool
}

func (d *decoder) Stream(samples [][2]float64) (n int, ok bool) {
if d.err != nil {
if d.err != nil || d.frame == nil {
return 0, false
}
// Copy samples from buffer.
j := 0
for i := range samples {
if j >= len(d.buf) {
// refill buffer.
if err := d.refill(); err != nil {
d.pos += n

for len(samples) > 0 {
samplesLeft := int(d.frame.BlockSize) - d.posInFrame
if samplesLeft <= 0 {
// Read next frame
var err error
d.frame, err = d.stream.ParseNext()
if err != nil {
d.frame = nil
if err == io.EOF {
return n, n > 0
}
d.err = err
d.err = errors.Wrap(err, "flac")
return 0, false
}
j = 0
d.posInFrame = 0
continue
}
samples[i] = d.buf[j]
j++
n++

toFill := min(samplesLeft, len(samples))
d.decodeFrameRangeInto(d.frame, d.posInFrame, toFill, samples)
d.posInFrame += toFill
n += toFill
samples = samples[toFill:]
}
d.buf = d.buf[j:]
d.pos += n

return n, true
}

// refill decodes audio samples to fill the decode buffer.
func (d *decoder) refill() error {
// Empty buffer.
d.buf = d.buf[:0]
// Parse audio frame.
frame, err := d.stream.ParseNext()
if err != nil {
return err
}
// Expand buffer size if needed.
n := len(frame.Subframes[0].Samples)
if cap(d.buf) < n {
d.buf = make([][2]float64, n)
} else {
d.buf = d.buf[:n]
}
// Decode audio samples.
// decodeFrameRangeInto decodes the samples of frame in the range [start, start+num)
// and stores them in Beep's format in the provided slice `into`.
func (d *decoder) decodeFrameRangeInto(frame *frame.Frame, start, num int, into [][2]float64) {
bps := d.stream.Info.BitsPerSample
nchannels := d.stream.Info.NChannels
numChannels := d.stream.Info.NChannels
s := 1 << (bps - 1)
q := 1 / float64(s)
switch {
case bps == 8 && nchannels == 1:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(int8(frame.Subframes[0].Samples[i])) * q
d.buf[i][1] = float64(int8(frame.Subframes[0].Samples[i])) * q
case bps == 8 && numChannels == 1:
for i := 0; i < num; i++ {
into[i][0] = float64(int8(frame.Subframes[0].Samples[start+i])) * q
into[i][1] = into[i][0]
}
case bps == 16 && nchannels == 1:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(int16(frame.Subframes[0].Samples[i])) * q
d.buf[i][1] = float64(int16(frame.Subframes[0].Samples[i])) * q
case bps == 16 && numChannels == 1:
for i := 0; i < num; i++ {
into[i][0] = float64(int16(frame.Subframes[0].Samples[start+i])) * q
into[i][1] = into[i][0]
}
case bps == 24 && nchannels == 1:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(int32(frame.Subframes[0].Samples[i])) * q
d.buf[i][1] = float64(int32(frame.Subframes[0].Samples[i])) * q
case bps == 24 && numChannels == 1:
for i := 0; i < num; i++ {
into[i][0] = float64(int32(frame.Subframes[0].Samples[start+i])) * q
into[i][1] = into[i][0]
}
case bps == 8 && nchannels >= 2:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(int8(frame.Subframes[0].Samples[i])) * q
d.buf[i][1] = float64(int8(frame.Subframes[1].Samples[i])) * q
case bps == 8 && numChannels >= 2:
for i := 0; i < num; i++ {
into[i][0] = float64(int8(frame.Subframes[0].Samples[start+i])) * q
into[i][1] = float64(int8(frame.Subframes[1].Samples[start+i])) * q
}
case bps == 16 && nchannels >= 2:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(int16(frame.Subframes[0].Samples[i])) * q
d.buf[i][1] = float64(int16(frame.Subframes[1].Samples[i])) * q
case bps == 16 && numChannels >= 2:
for i := 0; i < num; i++ {
into[i][0] = float64(int16(frame.Subframes[0].Samples[start+i])) * q
into[i][1] = float64(int16(frame.Subframes[1].Samples[start+i])) * q
}
case bps == 24 && nchannels >= 2:
for i := 0; i < n; i++ {
d.buf[i][0] = float64(frame.Subframes[0].Samples[i]) * q
d.buf[i][1] = float64(frame.Subframes[1].Samples[i]) * q
case bps == 24 && numChannels >= 2:
for i := 0; i < num; i++ {
into[i][0] = float64(frame.Subframes[0].Samples[start+i]) * q
into[i][1] = float64(frame.Subframes[1].Samples[start+i]) * q
}
default:
panic(fmt.Errorf("support for %d bits-per-sample and %d channels combination not yet implemented", bps, nchannels))
panic(fmt.Errorf("flac: support for %d bits-per-sample and %d channels combination not yet implemented", bps, numChannels))
}
return nil
}

func (d *decoder) Err() error {
Expand All @@ -148,27 +145,27 @@ func (d *decoder) Len() int {
}

func (d *decoder) Position() int {
return d.pos
return int(d.frame.SampleNumber()) + d.posInFrame
}

func (d *decoder) Seek(p int) error {
if !d.seekEnabled {
return errors.New("flac.decoder.Seek: not enabled")
}

// d.stream.Seek() doesn't seek to the exact position p; instead,
// it seeks to the start of the frame p is in. The position of
// that frame is returned and stored in pos.
pos, err := d.stream.Seek(uint64(p))
if err != nil {
return errors.Wrap(err, "flac")
}
d.posInFrame = p - int(pos)

toDiscard := p - int(pos)
err = d.refill()
d.frame, err = d.stream.ParseNext()
if err != nil {
return err
return errors.Wrap(err, "flac")
}
d.buf = d.buf[toDiscard:]

d.pos = p

return err
}
Expand Down
24 changes: 23 additions & 1 deletion flac/decode_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package flac_test

import (
"bytes"
"io"
"log"
"os"
Expand Down Expand Up @@ -44,7 +45,8 @@ func TestDecoder_Stream(t *testing.T) {
wavStream, _, err := wav.Decode(wavFile)
assert.NoError(t, err)

assert.Equal(t, wavStream.Len(), flacStream.Len())
assert.Equal(t, 22050, wavStream.Len())
assert.Equal(t, 22050, flacStream.Len())

wavSamples := testtools.Collect(wavStream)
flacSamples := testtools.Collect(flacStream)
Expand Down Expand Up @@ -139,3 +141,23 @@ func getFlacFrameStartPositions(r io.Reader) ([]uint64, error) {

return frameStarts, nil
}

// BenchmarkDecoder_Stream measures how long it takes to decode a complete
// FLAC file and collect all of its samples.
func BenchmarkDecoder_Stream(b *testing.B) {
	// Load the file into memory, so the disk performance doesn't impact the benchmark.
	data, err := os.ReadFile(testtools.TestFilePath("valid_44100hz_22050_samples_ffmpeg.flac"))
	if err != nil {
		b.Fatalf("reading test file: %v", err)
	}

	r := bytes.NewReader(data)

	b.Run("test", func(b *testing.B) {
		b.ReportAllocs()

		// The body has to run b.N times; otherwise the amount of work is
		// independent of b.N and the reported ns/op is meaningless.
		for i := 0; i < b.N; i++ {
			// Rewind the shared reader so every iteration decodes from the start.
			if _, err := r.Seek(0, io.SeekStart); err != nil {
				b.Fatalf("rewinding reader: %v", err)
			}

			s, _, err := flac.Decode(r)
			if err != nil {
				// Stop here: s is not usable when Decode fails.
				b.Fatalf("decoding: %v", err)
			}

			samples := testtools.Collect(s)
			assert.Equal(b, 22050, len(samples))
		}
	})
}

0 comments on commit c470588

Please sign in to comment.