diff --git a/generate/gen.go b/generate/gen.go index 5eee0586..77b15a62 100644 --- a/generate/gen.go +++ b/generate/gen.go @@ -172,11 +172,6 @@ func (s Struct) WriteAppend(l *LineWriter) { l.Write("if isFlexible {") defer l.Write("}") - if len(tags) == 0 { - l.Write("dst = append(dst, 0)") - return - } - var tagsCanDefault bool for i := 0; i < len(tags); i++ { f, exists := tags[i] @@ -188,6 +183,8 @@ func (s Struct) WriteAppend(l *LineWriter) { } } + defer l.Write("dst = v.UnknownTags.AppendEach(dst)") + if tagsCanDefault { l.Write("var toEncode []uint32") for i := 0; i < len(tags); i++ { @@ -199,7 +196,12 @@ func (s Struct) WriteAppend(l *LineWriter) { if !has { def = d.GetTypeDefault() } - l.Write("if v.%s != %v {", f.FieldName, def) + switch f.Type.(type) { + case Struct: + l.Write("if !reflect.DeepEqual(v.%s, %v) {", f.FieldName, def) + default: + l.Write("if v.%s != %v {", f.FieldName, def) + } } l.Write("toEncode = append(toEncode, %d)", i) if canDefault { @@ -207,13 +209,13 @@ func (s Struct) WriteAppend(l *LineWriter) { } } - l.Write("dst = kbin.AppendUvarint(dst, uint32(len(toEncode)))") + l.Write("dst = kbin.AppendUvarint(dst, uint32(len(toEncode) + v.UnknownTags.Len()))") l.Write("for _, tag := range toEncode {") l.Write("switch tag {") defer l.Write("}") defer l.Write("}") } else { - l.Write("dst = kbin.AppendUvarint(dst, %d)", len(tags)) + l.Write("dst = kbin.AppendUvarint(dst, %d + uint32(v.UnknownTags.Len()))", len(tags)) } for i := 0; i < len(tags); i++ { @@ -492,7 +494,7 @@ func (s Struct) WriteDecode(l *LineWriter) { l.Write("if isFlexible {") if len(tags) == 0 { - l.Write("SkipTags(&b)") + l.Write("s.UnknownTags = ReadTags(&b)") l.Write("}") return } @@ -501,11 +503,11 @@ func (s Struct) WriteDecode(l *LineWriter) { l.Write("for i := b.Uvarint(); i > 0; i-- {") defer l.Write("}") - l.Write("switch b.Uvarint() {") + l.Write("switch key := b.Uvarint(); key {") defer l.Write("}") l.Write("default:") - l.Write("b.Span(int(b.Uvarint()))") // unknown tag + l.Write("s.UnknownTags.Set(key, b.Span(int(b.Uvarint())))") for i := 0; i < len(tags); i++ { f, exists := tags[i] @@ -566,7 +568,7 @@ func (s Struct) WriteDefn(l *LineWriter) { l.Write("type %s struct {", s.Name) if s.TopLevel { // Top level messages always have a Version field. - l.Write("\t// Version is the version of this message used with a Kafka broker.") + l.Write("// Version is the version of this message used with a Kafka broker.") l.Write("Version int16") l.Write("") } @@ -591,6 +593,16 @@ func (s Struct) WriteDefn(l *LineWriter) { l.Write("") // blank between fields } } + if s.FlexibleAt >= 0 { + l.Write("") + l.Write("// UnknownTags are tags Kafka sent that we do not know the purpose of.") + if s.FlexibleAt == 0 { + l.Write("UnknownTags Tags") + } else { + l.Write("UnknownTags Tags // v%d+", s.FlexibleAt) + } + l.Write("") + } l.Write("}") } diff --git a/generate/main.go b/generate/main.go index 27d2d15f..d89f2b66 100644 --- a/generate/main.go +++ b/generate/main.go @@ -391,6 +391,7 @@ func main() { l.Write("package kmsg") l.Write("import (") l.Write(`"context"`) + l.Write(`"reflect"`) l.Write("") l.Write(`"github.com/twmb/franz-go/pkg/kbin"`) l.Write(")")