Skip to content

Commit

Permalink
Merge pull request #97 from DataDog/evan.jones/pointer-to-go-pointer
Browse files Browse the repository at this point in the history
Fix "Go pointer to Go pointer" panics
  • Loading branch information
Viq111 authored Mar 8, 2021
2 parents 12a1eb7 + ee09518 commit e292af4
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 26 deletions.
28 changes: 18 additions & 10 deletions zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,26 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(src) > 0 {
srcPtr = &src[0]
// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
// This means we need to special case empty input. See:
// https://github.com/golang/go/issues/14210#issuecomment-346402945
var cWritten C.size_t
if len(src) == 0 {
cWritten = C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(nil),
C.size_t(0),
C.int(level))
} else {
cWritten = C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
C.int(level))
}

cWritten := C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(srcPtr),
C.size_t(len(src)),
C.int(level))

written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
Expand Down
31 changes: 20 additions & 11 deletions zstd_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,28 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(src) > 0 {
srcPtr = &src[0]
// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
// This means we need to special case empty input. See:
// https://github.com/golang/go/issues/14210#issuecomment-346402945
var cWritten C.size_t
if len(src) == 0 {
cWritten = C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(nil),
C.size_t(0),
C.int(level))
} else {
cWritten = C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
C.int(level))
}

cWritten := C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(srcPtr),
C.size_t(len(src)),
C.int(level))

written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
Expand Down
7 changes: 7 additions & 0 deletions zstd_ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func TestCtxCompressLevel(t *testing.T) {
}
}

func TestCtxCompressLevelNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
cctx := NewCtx()
return cctx.CompressLevel(nil, input, BestSpeed)
})
}

func TestCtxEmptySliceCompress(t *testing.T) {
ctx := NewCtx()

Expand Down
10 changes: 5 additions & 5 deletions zstd_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,17 @@ func (w *Writer) Write(p []byte) (int, error) {
srcData = w.srcBuffer
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(srcData) > 0 {
srcPtr = &srcData[0]
if len(srcData) == 0 {
// this is technically unnecessary: srcData is p or w.srcBuffer, and len() > 0 checked above
// but this ensures the code can change without dereferencing an srcData[0]
return 0, nil
}

C.ZSTD_compressStream2_wrapper(
w.resultBuffer,
w.ctx,
unsafe.Pointer(&w.dstBuffer[0]),
C.size_t(len(w.dstBuffer)),
unsafe.Pointer(srcPtr),
unsafe.Pointer(&srcData[0]),
C.size_t(len(srcData)),
)
ret := int(w.resultBuffer.return_code)
Expand Down
16 changes: 16 additions & 0 deletions zstd_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,22 @@ func TestStreamDecompressionChunks(t *testing.T) {
}
}

func TestStreamWriteNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
buf := &bytes.Buffer{}
zw := NewWriter(buf)
_, err := zw.Write(input)
if err != nil {
return nil, err
}
err = zw.Close()
if err != nil {
return nil, err
}
return buf.Bytes(), nil
})
}

func BenchmarkStreamCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)
Expand Down
37 changes: 37 additions & 0 deletions zstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,43 @@ func TestCompressLevel(t *testing.T) {
}
}

// structWithGoPointers contains a byte buffer and a pointer to Go objects (slice). This means
// Cgo checks can fail when passing a pointer to buffer:
// "panic: runtime error: cgo argument has Go pointer to Go pointer"
// https://github.com/golang/go/issues/14210#issuecomment-346402945
type structWithGoPointers struct {
buffer [1]byte
slice []byte
}

// testCompressDecompressByte ensures that functions use the correct unsafe.Pointer assignment
// to avoid "Go pointer to Go pointer" panics.
func testCompressNoGoPointers(t *testing.T, compressFunc func(input []byte) ([]byte, error)) {
t.Helper()

s := structWithGoPointers{}
s.buffer[0] = 0x42
s.slice = s.buffer[:1]

compressed, err := compressFunc(s.slice)
if err != nil {
t.Fatal(err)
}
decompressed, err := Decompress(nil, compressed)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decompressed, s.slice) {
t.Errorf("decompressed=%#v input=%#v", decompressed, s.slice)
}
}

func TestCompressLevelNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
return CompressLevel(nil, input, BestSpeed)
})
}

func doCompressLevel(payload []byte, out []byte) error {
out, err := CompressLevel(out, payload, DefaultCompression)
if err != nil {
Expand Down

0 comments on commit e292af4

Please sign in to comment.