diff --git a/snappy.go b/snappy.go index b8f8b51..de69eac 100644 --- a/snappy.go +++ b/snappy.go @@ -3,11 +3,22 @@ package snappy import ( "bytes" "encoding/binary" + "errors" master "github.com/golang/snappy" ) -var xerialHeader = []byte{130, 83, 78, 65, 80, 80, 89, 0} +const ( + sizeOffset = 16 + sizeBytes = 4 +) + +var ( + xerialHeader = []byte{130, 83, 78, 65, 80, 80, 89, 0} + // ErrMalformed is returned by the decoder when the xerial framing + // is malformed + ErrMalformed = errors.New("malformed xerial framing") +) // Encode encodes data as snappy with no framing header. func Encode(src []byte) []byte { @@ -17,26 +28,43 @@ func Encode(src []byte) []byte { // Decode decodes snappy data whether it is traditional unframed // or includes the xerial framing format. func Decode(src []byte) ([]byte, error) { + var max = len(src) + if max < len(xerialHeader) { + return nil, ErrMalformed + } + if !bytes.Equal(src[:8], xerialHeader) { return master.Decode(nil, src) } + if max < sizeOffset+sizeBytes { + return nil, ErrMalformed + } + var ( - pos = uint32(16) - max = uint32(len(src)) + pos = sizeOffset dst = make([]byte, 0, len(src)) chunk []byte err error ) - for pos < max { - size := binary.BigEndian.Uint32(src[pos : pos+4]) - pos += 4 - chunk, err = master.Decode(chunk, src[pos:pos+size]) + for pos+sizeBytes <= max { + size := int(binary.BigEndian.Uint32(src[pos : pos+sizeBytes])) + pos += sizeBytes + + nextPos := pos + size + // On architectures where int is 32-bytes wide size + pos could + // overflow so we need to check the low bound as well as the + // high + if nextPos < pos || nextPos > max { + return nil, ErrMalformed + } + + chunk, err = master.Decode(chunk, src[pos:nextPos]) if err != nil { return nil, err } - pos += size + pos = nextPos dst = append(dst, chunk...) } return dst, nil diff --git a/snappy_test.go b/snappy_test.go index e94f635..3e7eddb 100644 --- a/snappy_test.go +++ b/snappy_test.go @@ -47,3 +47,49 @@ func TestSnappyDecodeStreams(t *testing.T) { } } } + +func TestSnappyDecodeMalformedTruncatedHeader(t *testing.T) { + // Truncated headers should not cause a panic. + for i := 0; i < len(xerialHeader); i++ { + buf := make([]byte, i) + copy(buf, xerialHeader[:i]) + if _, err := Decode(buf); err != ErrMalformed { + t.Errorf("expected ErrMalformed got %v", err) + } + } +} + +func TestSnappyDecodeMalformedTruncatedSize(t *testing.T) { + // Inputs with valid Xerial header but truncated "size" field + sizes := []int{sizeOffset + 1, sizeOffset + 2, sizeOffset + 3} + for _, size := range sizes { + buf := make([]byte, size) + copy(buf, xerialHeader) + if _, err := Decode(buf); err != ErrMalformed { + t.Errorf("expected ErrMalformed got %v", err) + } + } +} + +func TestSnappyDecodeMalformedBNoData(t *testing.T) { + // No data after the size field + buf := make([]byte, 20) + copy(buf, xerialHeader) + // indicate that there's one byte of data to be read + buf[len(buf)-1] = 1 + if _, err := Decode(buf); err != ErrMalformed { + t.Errorf("expected ErrMalformed got %v", err) + } +} + +func TestSnappyMasterDecodeFailed(t *testing.T) { + buf := make([]byte, 21) + copy(buf, xerialHeader) + // indicate that there's one byte of data to be read + buf[len(buf)-2] = 1 + // A payload which will not decode + buf[len(buf)-1] = 1 + if _, err := Decode(buf); err == ErrMalformed || err == nil { + t.Errorf("unexpected err: %v", err) + } +}