Skip to content

Commit

Permalink
fix: resolve panic when parsing corrupt messages with `lazyproto.Deco…
Browse files Browse the repository at this point in the history
…de()`

added length validation in `DecodeBytes()`, `DecodeNested()`, and `Skip()` on `Decoder`
added tests
  • Loading branch information
Dylan Bourque authored and dylan-bourque committed Aug 15, 2024
1 parent 7f0a4a2 commit 9373f47
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 13 deletions.
44 changes: 32 additions & 12 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ var (
ErrInvalidVarintData = errors.New("unable to read protobuf varint value")
// ErrValueOverflow is returned by DecodeUInt32() or DecodeInt32() when the decoded value is too large for a 32-bit value.
ErrValueOverflow = errors.New("value overflow trying to read protobuf varint value")
// ErrLenOverflow is returned when the LEN portion of a length-delimited field is larger than 2GB
ErrLenOverflow = errors.New("field length cannot be more than 2GB")
// ErrInvalidZigZagData is returned by the decoder when it fails to read a zigzag-encoded value.
ErrInvalidZigZagData = errors.New("unable to read protobuf zigzag value")
// ErrInvalidFixed32Data is returned by the decoder when it fails to read a fixed-size 32-bit value.
Expand All @@ -29,6 +31,9 @@ var (
// MaxTagValue is the largest supported protobuf field tag, which is 2^29 - 1 (or 536,870,911)
const MaxTagValue = 536870911

// length-delimited fields cannot contain more than 2GB
const maxFieldLen = math.MaxInt32

// DecoderMode defines the behavior of the decoder (safe vs fastest).
type DecoderMode int

Expand Down Expand Up @@ -183,17 +188,20 @@ func (d *Decoder) DecodeBytes() ([]byte, error) {
if d.offset >= len(d.p) {
return nil, io.ErrUnexpectedEOF
}

l, n, err := DecodeVarint(d.p[d.offset:])
if err != nil {
switch {
case err != nil:
return nil, fmt.Errorf("invalid data at byte %d: %w", d.offset, err)
}
if n == 0 {
case n == 0:
return nil, fmt.Errorf("invalid data at byte %d: %w", d.offset, ErrInvalidVarintData)
case l > maxFieldLen:
return nil, fmt.Errorf("invalid length (%d) for length-delimited field at byte %d: %w", l, d.offset, ErrLenOverflow)
default:
// length is good
}

nb := int(l)
if nb < 0 {
return nil, fmt.Errorf("csproto: bad byte length %d", nb)
}
if d.offset+n+nb > len(d.p) {
return nil, io.ErrUnexpectedEOF
}
Expand Down Expand Up @@ -865,13 +873,19 @@ func (d *Decoder) DecodeNested(m interface{}) error {
if d.offset >= len(d.p) {
return io.ErrUnexpectedEOF
}

l, n, err := DecodeVarint(d.p[d.offset:])
if err != nil {
switch {
case err != nil:
return fmt.Errorf("invalid data at byte %d: %w", d.offset, err)
}
if n == 0 {
case n == 0:
return fmt.Errorf("invalid data at byte %d: %w", d.offset, ErrInvalidVarintData)
case l > maxFieldLen:
return fmt.Errorf("invalid length (%d) for length-delimited field at byte %d: %w", l, d.offset, ErrLenOverflow)
default:
// length is good
}

nb := int(l)
if nb < 0 {
return fmt.Errorf("csproto: bad byte length %d at byte %d", nb, d.offset)
Expand Down Expand Up @@ -941,13 +955,19 @@ func (d *Decoder) Skip(tag int, wt WireType) ([]byte, error) {
skipped = 8
case WireTypeLengthDelimited:
l, n, err := DecodeVarint(d.p[d.offset:])
if err != nil {
switch {
case err != nil:
return nil, fmt.Errorf("invalid data at byte %d: %w", d.offset, err)
}
if n == 0 {
case n == 0:
return nil, fmt.Errorf("invalid data at byte %d: %w", d.offset, ErrInvalidVarintData)
case l > maxFieldLen:
return nil, fmt.Errorf("invalid length (%d) for length-delimited field at byte %d: %w", l, d.offset, ErrLenOverflow)
default:
// length is good
}

skipped = n + int(l)

case WireTypeFixed32:
skipped = 4
default:
Expand Down
44 changes: 43 additions & 1 deletion decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestDecoderSeek(t *testing.T) {
startPos := int64(dec.Offset())
pos, err := dec.Seek(0, io.SeekCurrent)
assert.NoError(t, err)
assert.Equal(t, int64(startPos), pos)
assert.Equal(t, startPos, pos)
})
t.Run("invalid positive seek", func(t *testing.T) {
dec.Reset()
Expand Down Expand Up @@ -290,6 +290,15 @@ func TestDecodeBytes(t *testing.T) {
data := []byte{0xCE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x3, 0x42, 0x11, 0x38}
dec := csproto.NewDecoder(data)

got, err := dec.DecodeBytes()
assert.Error(t, err)
assert.Nil(t, got)
})
t.Run("length overflow", func(t *testing.T) {
// field length greater than 2GB
data := []byte{0x80, 0x80, 0x80, 0x80, 0x08}
dec := csproto.NewDecoder(data)

got, err := dec.DecodeBytes()
assert.Error(t, err)
assert.Nil(t, got)
Expand Down Expand Up @@ -1195,6 +1204,39 @@ func TestDecoderInvalidSkip(t *testing.T) {
_, err := dec.Skip(1, wt)
assert.ErrorAs(t, err, &skipErr)
}

t.Run("corrupt messages", func(t *testing.T) {
t.Run("truncated data", func(t *testing.T) {
// length-delimited field value with a length of 3 but only 2 bytes
data := []byte{0x0A, 0x3, 0x42, 0x11}
dec := csproto.NewDecoder(data)

tag, wt, _ := dec.DecodeTag()
got, err := dec.Skip(tag, wt)
assert.Error(t, err)
assert.Nil(t, got)
})
t.Run("negative length", func(t *testing.T) {
// length-delimited field value with a length of -50
data := []byte{0x0A, 0xCE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x3, 0x42, 0x11, 0x38}
dec := csproto.NewDecoder(data)

tag, wt, _ := dec.DecodeTag()
got, err := dec.Skip(tag, wt)
assert.Error(t, err)
assert.Nil(t, got)
})
t.Run("length overflow", func(t *testing.T) {
// length-delimited field with length greater than 2GB
data := []byte{0x0A, 0x80, 0x80, 0x80, 0x80, 0x08}
dec := csproto.NewDecoder(data)

tag, wt, _ := dec.DecodeTag()
got, err := dec.Skip(tag, wt)
assert.Error(t, err)
assert.Nil(t, got)
})
})
}

func TestDecodePastEndOfBuffer(t *testing.T) {
Expand Down
87 changes: 87 additions & 0 deletions lazyproto/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/CrowdStrike/csproto"
"github.com/CrowdStrike/csproto/prototest"
)

func ExampleDecodeResult_FieldData() {
Expand Down Expand Up @@ -2000,3 +2002,88 @@ func TestFloat64FieldData(t *testing.T) {
assert.ErrorAs(t, err, &expectedErr, "should return a WireTypeMismatchError error")
})
}

func Test_Issue158(t *testing.T) {
// this test verified the bug, and subsequent fix, for Issue #159 (https://github.com/CrowdStrike/csproto/issues/158)

// This is the encoded content of a representative message that led to the panic noted in the
// linked issue. Roughly, there is an outer "envelope" message with a type discriminant where
// the "payload" data is defined by a proto2 extension field. In this case, the payload contains
// a bytes field that is itself an encoded proto message.
//
// In broad terms, the messages look like this:
// enum MessageType {
// ...
// }
// message Envelope {
// extensions 100 to max;
// required MessageType messageType = 1;
// ...
// }
// message WrappedMessagePayload {
// extend Envelope {
// optional WrappedMessagePayload payload = 100;
// }
// required uint32 eventTypeID = 1;
// required bytes dataBytes = 2;
// ...
// }
//
// And the wrapped message looks like:
// message InnerMessage {
// ...
// optional uint64 processID = 6;
// ...
// optional uint32 patternID = 74;
// ...
// optional string metadata = 503;
// ...
// optional uint32 templateID = 744;
// ...
// }
//
// Specifically for this bug, the upstream system generated a corrupted message where the length of
// the InnerMessage.metadata field was very large, which resulted in a negative value when converted
// from uint64 to int. That negative value was then used inside of Skip() to update the decoder's
// read offset and, subseqeuently, to extract the sub-slice containing the skipped field.
const data = `; envelope
08 ; tag=1 (messageType), varint
64 ; value=100 (payload type = WrappedMessagePayload)
A2 06 ; tag=100 (WrappedMessagePayload extension), length-delimited
1B ; len=27
08 ; tag=1 (eventTypeId), varint
01 ; value=1
12 ; tag=2 (dataBytes), length-delimited
17 ; len=23
; InnerMessage
30 ; tag=6 (processID), varint
01 ; value=1
D0 04 ; tag=74 (patternID), varint
01 ; value=1
BA 1F ; tag=503 (metadata), length-delimited
; * CORRUPT VALUE *
; 11,686,238,624,781,661,536 is a valid varint value but overflows the range of int and
; becomes negative when converted from uint64
E0 EA BD B4 CE 83 F7 96 A2 01 ; len=11x10^18
66 6F 6F ; "foo"
C0 2E ; tag=744 (templateID), varint
01 ; value=1`

bb, _ := prototest.ParseAnnotatedHex(data)
def := NewDef()
def.NestedTag(100, 2)
res, err := Decode(bb, def)
require.NoError(t, err, "error from first Decode()")

fd, err := res.FieldData(100, 2)
require.NoError(t, err, "error extracting field data from extension")

evt, err := fd.BytesValue()
require.NoError(t, err, "error extracting bytes from field data")

// these Decode() calls should fail due to the corrupt field length
_, err = Decode(evt, NewDef(74))
assert.Error(t, err, "expected error from Decode() when data is corrupted")
_, err = Decode(evt, NewDef(744))
assert.Error(t, err, "expected error from Decode() when data is corrupted")
}

0 comments on commit 9373f47

Please sign in to comment.