diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go
index 38311088..3324c104 100644
--- a/pkg/kgo/source.go
+++ b/pkg/kgo/source.go
@@ -668,11 +668,22 @@ func (o *seqOffsetFrom) processRespPartition(
 	switch version {
 	case 0, 1:
-		o.processV0Messages(topic, &fetchPart, kmsg.ReadV0Messages(rPartition.RecordBatches), decompressor)
+		msgs, err := kmsg.ReadV0Messages(rPartition.RecordBatches)
+		if err != nil {
+			fetchPart.Err = err
+		}
+		o.processV0Messages(topic, &fetchPart, msgs, decompressor)
 	case 2, 3:
-		o.processV1Messages(topic, &fetchPart, kmsg.ReadV1Messages(rPartition.RecordBatches), decompressor)
+		msgs, err := kmsg.ReadV1Messages(rPartition.RecordBatches)
+		if err != nil {
+			fetchPart.Err = err
+		}
+		o.processV1Messages(topic, &fetchPart, msgs, decompressor)
 	default:
-		batches := kmsg.ReadRecordBatches(rPartition.RecordBatches)
+		batches, err := kmsg.ReadRecordBatches(rPartition.RecordBatches)
+		if err != nil {
+			fetchPart.Err = err
+		}
 		var numPartitionRecords int
 		for i := range batches {
 			numPartitionRecords += int(batches[i].NumRecords)
@@ -823,7 +834,10 @@ func (o *seqOffset) processV1Messages(
 		fetchPart.Err = fmt.Errorf("unable to decompress messages: %v", err)
 		return
 	}
-	innerMessages := kmsg.ReadV1Messages(rawMessages)
+	innerMessages, err := kmsg.ReadV1Messages(rawMessages)
+	if err != nil {
+		fetchPart.Err = err
+	}
 	if len(innerMessages) == 0 {
 		return
 	}
@@ -877,7 +891,10 @@ func (o *seqOffset) processV0Messages(
 		fetchPart.Err = fmt.Errorf("unable to decompress messages: %v", err)
 		return
 	}
-	innerMessages := kmsg.ReadV0Messages(rawMessages)
+	innerMessages, err := kmsg.ReadV0Messages(rawMessages)
+	if err != nil {
+		fetchPart.Err = err
+	}
 	if len(innerMessages) == 0 {
 		return
 	}
diff --git a/pkg/kmsg/interface.go b/pkg/kmsg/interface.go
index 4412bfd5..d184ad97 100644
--- a/pkg/kmsg/interface.go
+++ b/pkg/kmsg/interface.go
@@ -21,6 +21,8 @@ package kmsg
 
 import (
 	"encoding/binary"
+	"errors"
+	"hash/crc32"
 
 	"github.com/twmb/kafka-go/pkg/kbin"
 )
@@ -224,70 +226,104 @@ func ReadRecords(n int, in []byte) ([]Record, error) {
 	return rs, nil
 }
 
+// ErrEncodedCRCMismatch is returned from reading record batches or message
+// sets when any batch or set has an encoded crc that does not match a
+// calculated crc.
+var ErrEncodedCRCMismatch = errors.New("encoded crc does not match calculated crc")
+
+// ErrEncodedLengthMismatch is returned from reading record batches or message
+// sets when any batch or set has an encoded length that does not match the
+// earlier read length of the batch / set.
+var ErrEncodedLengthMismatch = errors.New("encoded length does not match read length")
+
+var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table
+
 // ReadRecordBatches reads as many record batches as possible from in,
 // discarding any final trailing record batch. This is intended to be used
 // for processing RecordBatches from a FetchResponse, where Kafka, as an
 // internal optimization, may include a partial final RecordBatch.
-func ReadRecordBatches(in []byte) []RecordBatch {
+func ReadRecordBatches(in []byte) ([]RecordBatch, error) {
 	var bs []RecordBatch
 	for len(in) > 12 {
 		length := int32(binary.BigEndian.Uint32(in[8:]))
 		length += 12
 		if len(in) < int(length) {
-			return bs
+			return bs, nil
 		}
+
 		var b RecordBatch
 		if err := b.ReadFrom(in[:length]); err != nil {
-			return bs
+			return bs, nil
 		}
+
+		if int32(len(in[12:length])) != b.Length {
+			return bs, ErrEncodedLengthMismatch
+		}
+
+		// If we did not error, the length was at _least_ 21.
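+		// (A record batch's header through its crc occupies the first
+		// 21 bytes: 8 byte first offset, 4 byte length, 4 byte
+		// partition leader epoch, 1 byte magic, 4 byte crc. The crc
+		// covers everything after itself, so we checksum from byte 21.)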
+		if int32(crc32.Checksum(in[21:length], crc32c)) != b.CRC {
+			return bs, ErrEncodedCRCMismatch
+		}
+
 		bs = append(bs, b)
 		in = in[length:]
 	}
-	return bs
+	return bs, nil
 }
 
 // ReadV1Messages reads as many v1 message sets as possible from
 // in, discarding any final trailing message set. This is intended to be used
 // for processing v1 MessageSets from a FetchResponse, where Kafka, as an
 // internal optimization, may include a partial final MessageSet.
-func ReadV1Messages(in []byte) []MessageV1 {
+func ReadV1Messages(in []byte) ([]MessageV1, error) {
 	var ms []MessageV1
 	for len(in) > 12 {
 		length := int32(binary.BigEndian.Uint32(in[8:]))
 		length += 12
 		if len(in) < int(length) {
-			return ms
+			return ms, nil
 		}
 		var m MessageV1
 		if err := m.ReadFrom(in[:length]); err != nil {
-			return ms
+			return ms, nil
+		}
+		if int32(len(in[12:length])) != m.MessageSize {
+			return ms, ErrEncodedLengthMismatch
+		}
+		if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC {
+			return ms, ErrEncodedCRCMismatch
 		}
 		ms = append(ms, m)
 		in = in[length:]
 	}
-	return ms
+	return ms, nil
 }
 
 // ReadV0Messages reads as many v0 message sets as possible from
 // in, discarding any final trailing message set. This is intended to be used
 // for processing v0 MessageSets from a FetchResponse, where Kafka, as an
 // internal optimization, may include a partial final MessageSet.
-func ReadV0Messages(in []byte) []MessageV0 {
+func ReadV0Messages(in []byte) ([]MessageV0, error) {
 	var ms []MessageV0
 	for len(in) > 12 {
 		length := int32(binary.BigEndian.Uint32(in[8:]))
 		length += 12
 		if len(in) < int(length) {
-			return ms
+			return ms, nil
 		}
 		var m MessageV0
 		if err := m.ReadFrom(in[:length]); err != nil {
-			return ms
+			return ms, nil
+		}
+		if int32(len(in[12:length])) != m.MessageSize {
+			return ms, ErrEncodedLengthMismatch
+		}
+		if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC {
+			return ms, ErrEncodedCRCMismatch
 		}
 		ms = append(ms, m)
 		in = in[length:]
 	}
-	return ms
+	return ms, nil
 }
 
 // ReadFrom provides decoding various versions of sticky member metadata. A key
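
Usage note: callers of the three Read functions now receive an error alongside
the parsed batches or messages. A minimal sketch of the new contract, assuming
raw holds a fetched partition's RecordBatches bytes and fetchPart is the
partition result being built (both names are illustrative placeholders):

	batches, err := kmsg.ReadRecordBatches(raw)
	if err != nil {
		// err is ErrEncodedLengthMismatch or ErrEncodedCRCMismatch; batches
		// still holds every batch that validated before the corrupt one,
		// which is why processRespPartition above records the error on the
		// partition but keeps processing what it got.
		fetchPart.Err = err
	}
	for i := range batches {
		_ = batches[i] // every batch here passed both the length and crc checks
	}

A truncated trailing batch, as before, is silently dropped rather than treated
as an error, since Kafka may legitimately return a partial final batch.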