From 598261505033d0255c37dc06b9b6c1112818a1be Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Mon, 24 Aug 2020 12:56:12 -0600 Subject: [PATCH] kmsg: breaking API in Read{RecordBatches,V{0,1}Messages} As pointed out in #8, I had no crc validation on decode. The most obvious place to add that validation (and length validation) is in kmsg, which also makes this validation more broadly applicable to those that do not want to use the kgo package. This has been tested with the chaining tests, and v0 / v1 message sets have been tested with kcl. --- pkg/kgo/source.go | 27 +++++++++++++++---- pkg/kmsg/interface.go | 60 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 38311088..3324c104 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -668,11 +668,22 @@ func (o *seqOffsetFrom) processRespPartition( switch version { case 0, 1: - o.processV0Messages(topic, &fetchPart, kmsg.ReadV0Messages(rPartition.RecordBatches), decompressor) + msgs, err := kmsg.ReadV0Messages(rPartition.RecordBatches) + if err != nil { + fetchPart.Err = err + } + o.processV0Messages(topic, &fetchPart, msgs, decompressor) case 2, 3: - o.processV1Messages(topic, &fetchPart, kmsg.ReadV1Messages(rPartition.RecordBatches), decompressor) + msgs, err := kmsg.ReadV1Messages(rPartition.RecordBatches) + if err != nil { + fetchPart.Err = err + } + o.processV1Messages(topic, &fetchPart, msgs, decompressor) default: - batches := kmsg.ReadRecordBatches(rPartition.RecordBatches) + batches, err := kmsg.ReadRecordBatches(rPartition.RecordBatches) + if err != nil { + fetchPart.Err = err + } var numPartitionRecords int for i := range batches { numPartitionRecords += int(batches[i].NumRecords) @@ -823,7 +834,10 @@ func (o *seqOffset) processV1Messages( fetchPart.Err = fmt.Errorf("unable to decompress messages: %v", err) return } - innerMessages := kmsg.ReadV1Messages(rawMessages) + innerMessages, err := kmsg.ReadV1Messages(rawMessages) + if err != nil { + fetchPart.Err = err + } if len(innerMessages) == 0 { return } @@ -877,7 +891,10 @@ func (o *seqOffset) processV0Messages( fetchPart.Err = fmt.Errorf("unable to decompress messages: %v", err) return } - innerMessages := kmsg.ReadV0Messages(rawMessages) + innerMessages, err := kmsg.ReadV0Messages(rawMessages) + if err != nil { + fetchPart.Err = err + } if len(innerMessages) == 0 { return } diff --git a/pkg/kmsg/interface.go b/pkg/kmsg/interface.go index 4412bfd5..d184ad97 100644 --- a/pkg/kmsg/interface.go +++ b/pkg/kmsg/interface.go @@ -21,6 +21,8 @@ package kmsg import ( "encoding/binary" + "errors" + "hash/crc32" "github.com/twmb/kafka-go/pkg/kbin" ) @@ -224,70 +226,104 @@ func ReadRecords(n int, in []byte) ([]Record, error) { return rs, nil } +// ErrEncodedCRCMismatch is returned from reading record batches or message sets when +// any batch or set has an encoded crc that does not match a calculated crc. +var ErrEncodedCRCMismatch = errors.New("encoded crc does not match calculated crc") + +// ErrEncodedLengthMismatch is returned from reading record batches or message +// sets when any batch or set has an encoded length that does not match the +// earlier read length of the batch / set. +var ErrEncodedLengthMismatch = errors.New("encoded length does not match read length") + +var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table + // ReadRecordBatches reads as many record batches as possible from in, // discarding any final trailing record batch. This is intended to be used // for processing RecordBatches from a FetchResponse, where Kafka, as an // internal optimization, may include a partial final RecordBatch. -func ReadRecordBatches(in []byte) []RecordBatch { +func ReadRecordBatches(in []byte) ([]RecordBatch, error) { var bs []RecordBatch for len(in) > 12 { length := int32(binary.BigEndian.Uint32(in[8:])) length += 12 if len(in) < int(length) { - return bs + return bs, nil } + var b RecordBatch if err := b.ReadFrom(in[:length]); err != nil { - return bs + return bs, nil } + + if int32(len(in[12:length])) != b.Length { + return bs, ErrEncodedLengthMismatch + } + + // If we did not error, the length was at _least_ 21. + if int32(crc32.Checksum(in[21:length], crc32c)) != b.CRC { + return bs, ErrEncodedCRCMismatch + } + bs = append(bs, b) in = in[length:] } - return bs + return bs, nil } // ReadV1Messages reads as many v1 message sets as possible from // in, discarding any final trailing message set. This is intended to be used // for processing v1 MessageSets from a FetchResponse, where Kafka, as an // internal optimization, may include a partial final MessageSet. -func ReadV1Messages(in []byte) []MessageV1 { +func ReadV1Messages(in []byte) ([]MessageV1, error) { var ms []MessageV1 for len(in) > 12 { length := int32(binary.BigEndian.Uint32(in[8:])) length += 12 if len(in) < int(length) { - return ms + return ms, nil } var m MessageV1 if err := m.ReadFrom(in[:length]); err != nil { - return ms + return ms, nil + } + if int32(len(in[12:length])) != m.MessageSize { + return ms, ErrEncodedLengthMismatch + } + if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC { + return ms, ErrEncodedCRCMismatch } ms = append(ms, m) in = in[length:] } - return ms + return ms, nil } // ReadV0Messages reads as many v0 message sets as possible from // in, discarding any final trailing message set. This is intended to be used // for processing v0 MessageSets from a FetchResponse, where Kafka, as an // internal optimization, may include a partial final MessageSet. -func ReadV0Messages(in []byte) []MessageV0 { +func ReadV0Messages(in []byte) ([]MessageV0, error) { var ms []MessageV0 for len(in) > 12 { length := int32(binary.BigEndian.Uint32(in[8:])) length += 12 if len(in) < int(length) { - return ms + return ms, nil } var m MessageV0 if err := m.ReadFrom(in[:length]); err != nil { - return ms + return ms, nil + } + if int32(len(in[12:length])) != m.MessageSize { + return ms, ErrEncodedLengthMismatch + } + if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC { + return ms, ErrEncodedCRCMismatch } ms = append(ms, m) in = in[length:] } - return ms + return ms, nil } // ReadFrom provides decoding various versions of sticky member metadata. A key