diff --git a/zstd.go b/zstd.go index 164a923..634ed65 100644 --- a/zstd.go +++ b/zstd.go @@ -58,18 +58,26 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) { dst = make([]byte, bound) } - var srcPtr *byte // Do not point anywhere, if src is empty - if len(src) > 0 { - srcPtr = &src[0] + // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics. + // This means we need to special case empty input. See: + // https://github.com/golang/go/issues/14210#issuecomment-346402945 + var cWritten C.size_t + if len(src) == 0 { + cWritten = C.ZSTD_compress( + unsafe.Pointer(&dst[0]), + C.size_t(len(dst)), + unsafe.Pointer(nil), + C.size_t(0), + C.int(level)) + } else { + cWritten = C.ZSTD_compress( + unsafe.Pointer(&dst[0]), + C.size_t(len(dst)), + unsafe.Pointer(&src[0]), + C.size_t(len(src)), + C.int(level)) } - cWritten := C.ZSTD_compress( - unsafe.Pointer(&dst[0]), - C.size_t(len(dst)), - unsafe.Pointer(srcPtr), - C.size_t(len(src)), - C.int(level)) - written := int(cWritten) // Check if the return is an Error code if err := getError(written); err != nil { diff --git a/zstd_ctx.go b/zstd_ctx.go index 12e9539..6b98943 100644 --- a/zstd_ctx.go +++ b/zstd_ctx.go @@ -63,19 +63,28 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) { dst = make([]byte, bound) } - var srcPtr *byte // Do not point anywhere, if src is empty - if len(src) > 0 { - srcPtr = &src[0] + // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics. + // This means we need to special case empty input. See: + // https://github.com/golang/go/issues/14210#issuecomment-346402945 + var cWritten C.size_t + if len(src) == 0 { + cWritten = C.ZSTD_compressCCtx( + c.cctx, + unsafe.Pointer(&dst[0]), + C.size_t(len(dst)), + unsafe.Pointer(nil), + C.size_t(0), + C.int(level)) + } else { + cWritten = C.ZSTD_compressCCtx( + c.cctx, + unsafe.Pointer(&dst[0]), + C.size_t(len(dst)), + unsafe.Pointer(&src[0]), + C.size_t(len(src)), + C.int(level)) } - cWritten := C.ZSTD_compressCCtx( - c.cctx, - unsafe.Pointer(&dst[0]), - C.size_t(len(dst)), - unsafe.Pointer(srcPtr), - C.size_t(len(src)), - C.int(level)) - written := int(cWritten) // Check if the return is an Error code if err := getError(written); err != nil { diff --git a/zstd_ctx_test.go b/zstd_ctx_test.go index 831a21f..ac82091 100644 --- a/zstd_ctx_test.go +++ b/zstd_ctx_test.go @@ -65,6 +65,13 @@ func TestCtxCompressLevel(t *testing.T) { } } +func TestCtxCompressLevelNoGoPointers(t *testing.T) { + testCompressNoGoPointers(t, func(input []byte) ([]byte, error) { + cctx := NewCtx() + return cctx.CompressLevel(nil, input, BestSpeed) + }) +} + func TestCtxEmptySliceCompress(t *testing.T) { ctx := NewCtx() diff --git a/zstd_stream.go b/zstd_stream.go index 1ed0e98..f9eb2de 100644 --- a/zstd_stream.go +++ b/zstd_stream.go @@ -168,17 +168,17 @@ func (w *Writer) Write(p []byte) (int, error) { srcData = w.srcBuffer } - var srcPtr *byte // Do not point anywhere, if src is empty - if len(srcData) > 0 { - srcPtr = &srcData[0] + if len(srcData) == 0 { + // this is technically unnecessary: srcData is p or w.srcBuffer, and len() > 0 checked above + // but this ensures the code can change without dereferencing an srcData[0] + return 0, nil } - C.ZSTD_compressStream2_wrapper( w.resultBuffer, w.ctx, unsafe.Pointer(&w.dstBuffer[0]), C.size_t(len(w.dstBuffer)), - unsafe.Pointer(srcPtr), + unsafe.Pointer(&srcData[0]), C.size_t(len(srcData)), ) ret := int(w.resultBuffer.return_code) diff --git a/zstd_stream_test.go b/zstd_stream_test.go index 06acece..79f412e 100644 --- a/zstd_stream_test.go +++ b/zstd_stream_test.go @@ -375,6 +375,22 @@ func TestStreamDecompressionChunks(t *testing.T) { } } +func TestStreamWriteNoGoPointers(t *testing.T) { + testCompressNoGoPointers(t, func(input []byte) ([]byte, error) { + buf := &bytes.Buffer{} + zw := NewWriter(buf) + _, err := zw.Write(input) + if err != nil { + return nil, err + } + err = zw.Close() + if err != nil { + return nil, err + } + return buf.Bytes(), nil + }) +} + func BenchmarkStreamCompression(b *testing.B) { if raw == nil { b.Fatal(ErrNoPayloadEnv) diff --git a/zstd_test.go b/zstd_test.go index e4d90c8..e5bb2d2 100644 --- a/zstd_test.go +++ b/zstd_test.go @@ -109,6 +109,43 @@ func TestCompressLevel(t *testing.T) { } } +// structWithGoPointers contains a byte buffer and a pointer to Go objects (slice). This means +// Cgo checks can fail when passing a pointer to buffer: +// "panic: runtime error: cgo argument has Go pointer to Go pointer" +// https://github.com/golang/go/issues/14210#issuecomment-346402945 +type structWithGoPointers struct { + buffer [1]byte + slice []byte +} + +// testCompressDecompressByte ensures that functions use the correct unsafe.Pointer assignment +// to avoid "Go pointer to Go pointer" panics. +func testCompressNoGoPointers(t *testing.T, compressFunc func(input []byte) ([]byte, error)) { + t.Helper() + + s := structWithGoPointers{} + s.buffer[0] = 0x42 + s.slice = s.buffer[:1] + + compressed, err := compressFunc(s.slice) + if err != nil { + t.Fatal(err) + } + decompressed, err := Decompress(nil, compressed) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decompressed, s.slice) { + t.Errorf("decompressed=%#v input=%#v", decompressed, s.slice) + } +} + +func TestCompressLevelNoGoPointers(t *testing.T) { + testCompressNoGoPointers(t, func(input []byte) ([]byte, error) { + return CompressLevel(nil, input, BestSpeed) + }) +} + func doCompressLevel(payload []byte, out []byte) error { out, err := CompressLevel(out, payload, DefaultCompression) if err != nil {