From 0cdc2b67f6defb64b1e60bd527f958f406a88926 Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Mon, 22 Feb 2021 15:44:55 -0700 Subject: [PATCH] kgo / kmsg: process the RecordBatches field in FetchResponses properly As it turns out, a FetchResponse can be of high enough version to use actual record batches, but still use message sets. Rather than relying on the response version, we need to check the magic in each batch. It may be possible that each batch has a different magic from one batch to the next, so we cannot just check the first batch and use that for all decoding. For message set v1, it is also possible that a client could have used messages v0, compressed them, and used that as the "inner messages" in a message v1. All of these cases are now handled properly, which necessitated the removal of some functions from kmsg. The functions in kmsg were not necessarily correct. Further, it was a bit odd to stuff message decoding and validating into kmsg. This makes kmsg a more dedicated encoding / decoding package. --- pkg/kgo/client.go | 3 + pkg/kgo/sink.go | 2 - pkg/kgo/source.go | 298 +++++++++++++++++++++++++++++++----------- pkg/kmsg/interface.go | 121 ----------------- 4 files changed, 226 insertions(+), 198 deletions(-) diff --git a/pkg/kgo/client.go b/pkg/kgo/client.go index ce42f555..46926be4 100644 --- a/pkg/kgo/client.go +++ b/pkg/kgo/client.go @@ -20,6 +20,7 @@ package kgo import ( "context" "fmt" + "hash/crc32" "math/rand" "reflect" "sort" @@ -33,6 +34,8 @@ import ( "github.com/twmb/franz-go/pkg/kmsg" ) +var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table; for consuming/producing + // Client issues requests and handles responses to a Kafka cluster. type Client struct { cfg cfg diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go index 798129f7..055c659a 100644 --- a/pkg/kgo/sink.go +++ b/pkg/kgo/sink.go @@ -1581,8 +1581,6 @@ func (r seqRecBatch) appendTo( return dst } -var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table - func (pnr promisedNumberedRecord) appendTo(dst []byte, offsetDelta int32) []byte { dst = kbin.AppendVarint(dst, pnr.lengthField) dst = kbin.AppendInt8(dst, 0) // attributes, currently unused diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 77deaf1f..d7b6e7a3 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -2,15 +2,22 @@ package kgo import ( "context" + "encoding/binary" "fmt" + "hash/crc32" "sync" "sync/atomic" "time" + "github.com/twmb/franz-go/pkg/kbin" "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kmsg" ) +type readerFrom interface { + ReadFrom([]byte) error +} + // A source consumes from an individual broker. // // As long as there is at least one active cursor, a source aims to have *one* @@ -764,37 +771,91 @@ func (o *cursorOffsetNext) processRespPartition(version int16, rp *kmsg.FetchRes LogStartOffset: rp.LogStartOffset, } - switch version { - case 0, 1: - msgs, err := kmsg.ReadV0Messages(rp.RecordBatches) - if err != nil { - fp.Err = err + aborter := buildAborter(rp) + + // A response could contain any of message v0, message v1, or record + // batches, and this is solely dictated by the magic byte (not the + // fetch response version). The magic byte is located at byte 17. + // + // 0 thru 8: int64 offset / first offset + // 9 thru 12: int32 length + // 13 thru 16: crc (magic 0 or 1), or partition leader epoch (magic 2) + // 17: magic + // + // We decode and validate similarly for messages and record batches, so + // we "abstract" away the high level stuff into a check function just + // below, and then switch based on the magic for how to process. + var ( + in = rp.RecordBatches + + r readerFrom + length int32 + lengthField *int32 + crcField *int32 + crcTable *crc32.Table + crcAt int + + check = func() bool { + if err := r.ReadFrom(in[:length]); err != nil { + return false + } + if length := int32(len(in[12:length])); length != *lengthField { + fp.Err = fmt.Errorf("encoded length %d does not match read length %d", *lengthField, length) + return false + } + if crcCalc := int32(crc32.Checksum(in[crcAt:length], crcTable)); crcCalc != *crcField { + fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", *crcField, crcCalc) + return false + } + return true } - o.processV0Messages(&fp, msgs, decompressor) + ) - case 2, 3: - msgs, err := kmsg.ReadV1Messages(rp.RecordBatches) - if err != nil { - fp.Err = err + for len(in) > 17 && fp.Err == nil { + length = int32(binary.BigEndian.Uint32(in[8:])) + length += 12 // for the int64 offset we skipped and int32 length field itself + if len(in) < int(length) { + break } - o.processV1Messages(&fp, msgs, decompressor) - default: - batches, err := kmsg.ReadRecordBatches(rp.RecordBatches) - if err != nil { - fp.Err = err + switch magic := in[16]; magic { + case 0: + m := new(kmsg.MessageV0) + lengthField = &m.MessageSize + crcField = &m.CRC + crcTable = crc32.IEEETable + crcAt = 16 + r = m + case 1: + m := new(kmsg.MessageV1) + lengthField = &m.MessageSize + crcField = &m.CRC + crcTable = crc32.IEEETable + crcAt = 16 + r = m + case 2: + rb := new(kmsg.RecordBatch) + lengthField = &rb.Length + crcField = &rb.CRC + crcTable = crc32c + crcAt = 21 + r = rb + } - var numPartitionRecords int - for i := range batches { - numPartitionRecords += int(batches[i].NumRecords) + + if !check() { + break } - fp.Records = make([]*Record, 0, numPartitionRecords) - aborter := buildAborter(rp) - for i := range batches { - o.processRecordBatch(&fp, &batches[i], aborter, decompressor) - if fp.Err != nil { - break - } + + in = in[length:] + + switch t := r.(type) { + case *kmsg.MessageV0: + o.processV0OuterMessage(&fp, t, decompressor) + case *kmsg.MessageV1: + o.processV1OuterMessage(&fp, t, decompressor) + case *kmsg.RecordBatch: + o.processRecordBatch(&fp, t, aborter, decompressor) } } @@ -843,6 +904,24 @@ func (a aborter) trackAbortedPID(producerID int64) { // processing records to fetch part // ////////////////////////////////////// +// readRawRecords reads n records from in and returns them, returning +// kbin.ErrNotEnoughData if in does not contain enough data. +func readRawRecords(n int, in []byte) ([]kmsg.Record, error) { + rs := make([]kmsg.Record, n) + for i := 0; i < n; i++ { + length, used := kbin.Varint(in) + total := used + int(length) + if used == 0 || length < 0 || len(in) < total { + return nil, kbin.ErrNotEnoughData + } + if err := (&rs[i]).ReadFrom(in[:total]); err != nil { + return nil, err + } + in = in[total:] + } + return rs, nil +} + func (o *cursorOffsetNext) processRecordBatch( fp *FetchPartition, batch *kmsg.RecordBatch, @@ -870,7 +949,7 @@ func (o *cursorOffsetNext) processRecordBatch( return } - krecords, err := kmsg.ReadRecords(int(batch.NumRecords), rawRecords) + krecords, err := readRawRecords(int(batch.NumRecords), rawRecords) if err != nil { fp.Err = fmt.Errorf("invalid record batch: %v", err) return @@ -903,36 +982,88 @@ func (o *cursorOffsetNext) processRecordBatch( } } -func (o *cursorOffsetNext) processV1Messages( +// Processes an outer v1 message. There could be no inner message, which makes +// this easy, but if not, we decompress and process each inner message as +// either v0 or v1. We only expect the inner message to be v1, but technically +// a crazy pipeline could have v0 anywhere. +func (o *cursorOffsetNext) processV1OuterMessage( fp *FetchPartition, - messages []kmsg.MessageV1, + message *kmsg.MessageV1, decompressor *decompressor, ) { - for i := range messages { - message := &messages[i] - compression := byte(message.Attributes & 0x0003) - if compression == 0 { - if !o.processV1Message(fp, message) { - return - } - continue + compression := byte(message.Attributes & 0x0003) + if compression == 0 { + o.processV1Message(fp, message) + return + } + + rawInner, err := decompressor.decompress(message.Value, compression) + if err != nil { + fp.Err = fmt.Errorf("unable to decompress messages: %v", err) + return + } + + var innerMessages []readerFrom + for len(rawInner) > 17 { // magic at byte 17 + length := int32(binary.BigEndian.Uint32(rawInner[8:])) + length += 12 // skip offset and length fields + if len(rawInner) < int(length) { + break } - rawMessages, err := decompressor.decompress(message.Value, compression) - if err != nil { - fp.Err = fmt.Errorf("unable to decompress messages: %v", err) - return + var ( + magic = rawInner[16] + + msg readerFrom + lengthField *int32 + crcField *int32 + ) + + switch magic { + case 0: + m := new(kmsg.MessageV0) + msg = m + lengthField = &m.MessageSize + crcField = &m.CRC + case 1: + m := new(kmsg.MessageV1) + msg = m + lengthField = &m.MessageSize + crcField = &m.CRC + + default: + fp.Err = fmt.Errorf("message set v1 has inner message with invalid magic %d", magic) + break } - innerMessages, err := kmsg.ReadV1Messages(rawMessages) - if err != nil { - fp.Err = err + + if err := msg.ReadFrom(rawInner[:length]); err != nil { + break } - if len(innerMessages) == 0 { - return + if length := int32(len(rawInner[12:length])); length != *lengthField { + fp.Err = fmt.Errorf("encoded length %d does not match read length %d", *lengthField, length) + break } - firstOffset := message.Offset - int64(len(innerMessages)) + 1 - for i := range innerMessages { - innerMessage := &innerMessages[i] + if crcCalc := int32(crc32.ChecksumIEEE(rawInner[16:length])); crcCalc != *crcField { + fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", *crcField, crcCalc) + break + } + innerMessages = append(innerMessages, msg) + rawInner = rawInner[length:] + } + if len(innerMessages) == 0 { + return + } + + firstOffset := message.Offset - int64(len(innerMessages)) + 1 + for i := range innerMessages { + innerMessage := innerMessages[i] + switch innerMessage := innerMessage.(type) { + case *kmsg.MessageV0: + innerMessage.Offset = firstOffset + int64(i) + if !o.processV0Message(fp, innerMessage) { + return + } + case *kmsg.MessageV1: innerMessage.Offset = firstOffset + int64(i) if !o.processV1Message(fp, innerMessage) { return @@ -958,40 +1089,57 @@ func (o *cursorOffsetNext) processV1Message( return true } -func (o *cursorOffsetNext) processV0Messages( +// Processes an outer v0 message. We expect inner messages to be entirely v0 as +// well, so this only tries v0 always. +func (o *cursorOffsetNext) processV0OuterMessage( fp *FetchPartition, - messages []kmsg.MessageV0, + message *kmsg.MessageV0, decompressor *decompressor, ) { - for i := range messages { - message := &messages[i] - compression := byte(message.Attributes & 0x0003) - if compression == 0 { - if !o.processV0Message(fp, message) { - return - } - continue - } + compression := byte(message.Attributes & 0x0003) + if compression == 0 { + o.processV0Message(fp, message) + return + } - rawMessages, err := decompressor.decompress(message.Value, compression) - if err != nil { - fp.Err = fmt.Errorf("unable to decompress messages: %v", err) - return + rawInner, err := decompressor.decompress(message.Value, compression) + if err != nil { + fp.Err = fmt.Errorf("unable to decompress messages: %v", err) + return + } + + var innerMessages []kmsg.MessageV0 + for len(rawInner) > 17 { // magic at byte 17 + length := int32(binary.BigEndian.Uint32(rawInner[8:])) + length += 12 // skip offset and length fields + if len(rawInner) < int(length) { + break } - innerMessages, err := kmsg.ReadV0Messages(rawMessages) - if err != nil { - fp.Err = err + var m kmsg.MessageV0 + if err := m.ReadFrom(rawInner[:length]); err != nil { + break } - if len(innerMessages) == 0 { - return + if length := int32(len(rawInner[12:length])); length != m.MessageSize { + fp.Err = fmt.Errorf("encoded length %d does not match read length %d", m.MessageSize, length) + break } - firstOffset := message.Offset - int64(len(innerMessages)) + 1 - for i := range innerMessages { - innerMessage := &innerMessages[i] - innerMessage.Offset = firstOffset + int64(i) - if !o.processV0Message(fp, innerMessage) { - return - } + if crcCalc := int32(crc32.ChecksumIEEE(rawInner[16:length])); crcCalc != m.CRC { + fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", m.CRC, crcCalc) + break + } + innerMessages = append(innerMessages, m) + rawInner = rawInner[length:] + } + if len(innerMessages) == 0 { + return + } + + firstOffset := message.Offset - int64(len(innerMessages)) + 1 + for i := range innerMessages { + innerMessage := &innerMessages[i] + innerMessage.Offset = firstOffset + int64(i) + if !o.processV0Message(fp, innerMessage) { + return } } } diff --git a/pkg/kmsg/interface.go b/pkg/kmsg/interface.go index 6712400e..89f30946 100644 --- a/pkg/kmsg/interface.go +++ b/pkg/kmsg/interface.go @@ -34,9 +34,6 @@ package kmsg import ( "context" - "encoding/binary" - "errors" - "hash/crc32" "github.com/twmb/franz-go/pkg/kbin" ) @@ -221,124 +218,6 @@ func StringPtr(in string) *string { return &in } -// ReadRecords reads n records from in and returns them, returning -// kerr.ErrNotEnoughData if in does not contain enough data. -func ReadRecords(n int, in []byte) ([]Record, error) { - rs := make([]Record, n) - for i := 0; i < n; i++ { - length, used := kbin.Varint(in) - total := used + int(length) - if used == 0 || length < 0 || len(in) < total { - return nil, kbin.ErrNotEnoughData - } - if err := (&rs[i]).ReadFrom(in[:total]); err != nil { - return nil, err - } - in = in[total:] - } - return rs, nil -} - -// ErrEncodedCRCMismatch is returned from reading record batches or message sets when -// any batch or set has an encoded crc that does not match a calculated crc. -var ErrEncodedCRCMismatch = errors.New("encoded crc does not match calculated crc") - -// ErrEncodedLengthMismatch is returned from reading record batches or message -// sets when any batch or set has an encoded length that does not match the -// earlier read length of the batch / set. -var ErrEncodedLengthMismatch = errors.New("encoded length does not match read length") - -var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table - -// ReadRecordBatches reads as many record batches as possible from in, -// discarding any final trailing record batch. This is intended to be used -// for processing RecordBatches from a FetchResponse, where Kafka, as an -// internal optimization, may include a partial final RecordBatch. -func ReadRecordBatches(in []byte) ([]RecordBatch, error) { - var bs []RecordBatch - for len(in) > 12 { - length := int32(binary.BigEndian.Uint32(in[8:])) - length += 12 - if len(in) < int(length) { - return bs, nil - } - - var b RecordBatch - if err := b.ReadFrom(in[:length]); err != nil { - return bs, nil - } - - if int32(len(in[12:length])) != b.Length { - return bs, ErrEncodedLengthMismatch - } - - // If we did not error, the length was at _least_ 21. - if int32(crc32.Checksum(in[21:length], crc32c)) != b.CRC { - return bs, ErrEncodedCRCMismatch - } - - bs = append(bs, b) - in = in[length:] - } - return bs, nil -} - -// ReadV1Messages reads as many v1 message sets as possible from -// in, discarding any final trailing message set. This is intended to be used -// for processing v1 MessageSets from a FetchResponse, where Kafka, as an -// internal optimization, may include a partial final MessageSet. -func ReadV1Messages(in []byte) ([]MessageV1, error) { - var ms []MessageV1 - for len(in) > 12 { - length := int32(binary.BigEndian.Uint32(in[8:])) - length += 12 - if len(in) < int(length) { - return ms, nil - } - var m MessageV1 - if err := m.ReadFrom(in[:length]); err != nil { - return ms, nil - } - if int32(len(in[12:length])) != m.MessageSize { - return ms, ErrEncodedLengthMismatch - } - if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC { - return ms, ErrEncodedCRCMismatch - } - ms = append(ms, m) - in = in[length:] - } - return ms, nil -} - -// ReadV0Messages reads as many v0 message sets as possible from -// in, discarding any final trailing message set. This is intended to be used -// for processing v0 MessageSets from a FetchResponse, where Kafka, as an -// internal optimization, may include a partial final MessageSet. -func ReadV0Messages(in []byte) ([]MessageV0, error) { - var ms []MessageV0 - for len(in) > 12 { - length := int32(binary.BigEndian.Uint32(in[8:])) - length += 12 - if len(in) < int(length) { - return ms, nil - } - var m MessageV0 - if err := m.ReadFrom(in[:length]); err != nil { - return ms, nil - } - if int32(len(in[12:length])) != m.MessageSize { - return ms, ErrEncodedLengthMismatch - } - if int32(crc32.ChecksumIEEE(in[16:length])) != m.CRC { - return ms, ErrEncodedCRCMismatch - } - ms = append(ms, m) - in = in[length:] - } - return ms, nil -} - // ReadFrom provides decoding various versions of sticky member metadata. A key // point of this type is that it does not contain a version number inside it, // but it is versioned: if decoding v1 fails, this falls back to v0.