From 9e247e09acff0c3163715bf33ae78c57490ffc90 Mon Sep 17 00:00:00 2001 From: Ben Luddy <bluddy@redhat.com> Date: Fri, 3 Nov 2023 12:38:31 -0400 Subject: [PATCH] Add FieldNameMatching decode option. When decoding a CBOR map into a Go struct, FieldNameMatching controls how string keys are matched to struct fields in the destination struct and allows users to require case-sensitive matches. The default value of this option preserves the existing behavior, which prefers case-sensitive matches but will fall back to a case-insensitive match. Signed-off-by: Ben Luddy <bluddy@redhat.com> --- decode.go | 31 +++++++++++++++- decode_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/decode.go b/decode.go index c0ee14ab..65be7850 100644 --- a/decode.go +++ b/decode.go @@ -353,6 +353,26 @@ func (um UTF8Mode) valid() bool { return um < maxUTF8Mode } +// FieldNameMatchingMode specifies how string keys in CBOR maps are matched to Go struct field names. +type FieldNameMatchingMode int + +const ( + // FieldNameMatchingPreferCaseSensitive prefers to decode map items into struct fields whose names (or tag + // names) exactly match the item's key. If there is no such field, a map item will be decoded into a field whose + // name is a case-insensitive match for the item's key. + FieldNameMatchingPreferCaseSensitive = iota + + // FieldNameMatchingCaseSensitive decodes map items only into a struct field whose name (or tag name) is an + // exact match for the item's key. + FieldNameMatchingCaseSensitive + + maxFieldNameMatchingMode +) + +func (fnmm FieldNameMatchingMode) valid() bool { + return fnmm >= 0 && fnmm < maxFieldNameMatchingMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -402,6 +422,9 @@ type DecOptions struct { // UTF8 specifies if decoder should decode CBOR Text containing invalid UTF-8. // By default, unmarshal rejects CBOR text containing invalid UTF-8. UTF8 UTF8Mode + + // FieldNameMatching specifies how string keys in CBOR maps are matched to Go struct field names. + FieldNameMatching FieldNameMatchingMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -510,6 +533,9 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.UTF8.valid() { return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8))) } + if !opts.FieldNameMatching.valid() { + return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching))) + } dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -523,6 +549,7 @@ func (opts DecOptions) decMode() (*decMode, error) { extraReturnErrors: opts.ExtraReturnErrors, defaultMapType: opts.DefaultMapType, utf8: opts.UTF8, + fieldNameMatching: opts.FieldNameMatching, } return &dm, nil } @@ -587,6 +614,7 @@ type decMode struct { extraReturnErrors ExtraDecErrorCond defaultMapType reflect.Type utf8 UTF8Mode + fieldNameMatching FieldNameMatchingMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -605,6 +633,7 @@ func (dm *decMode) DecOptions() DecOptions { MapKeyByteString: dm.mapKeyByteString, ExtraReturnErrors: dm.extraReturnErrors, UTF8: dm.utf8, + FieldNameMatching: dm.fieldNameMatching, } } @@ -1681,7 +1710,7 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n } } // Find field with case-insensitive match - if f == nil { + if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive { keyString := string(keyBytes) for i := 0; i < len(structType.fields); i++ { fld := structType.fields[i] diff --git a/decode_test.go b/decode_test.go index ea3d84d4..e9dc84d3 100644 --- a/decode_test.go +++ b/decode_test.go @@ -6038,3 +6038,102 @@ func TestUnmarshalFirstInvalidItem(t *testing.T) { t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err) } } + +func TestDecModeInvalidFieldNameMatchingMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{FieldNameMatching: -1}, + wantErrorMsg: "cbor: invalid FieldNameMatching -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{FieldNameMatching: 101}, + wantErrorMsg: "cbor: invalid FieldNameMatching 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestDecodeFieldNameMatching(t *testing.T) { + type s struct { + LowerA int `cbor:"a"` + UpperB int `cbor:"B"` + LowerB int `cbor:"b"` + } + + testCases := []struct { + name string + opts DecOptions + cborData []byte + wantValue s + }{ + { + name: "case-insensitive match", + cborData: hexDecode("a1614101"), // {"A": 1} + wantValue: s{LowerA: 1}, + }, + { + name: "ignore case-insensitive match", + opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive}, + cborData: hexDecode("a1614101"), // {"A": 1} + wantValue: s{}, + }, + { + name: "exact match before case-insensitive match", + cborData: hexDecode("a2616101614102"), // {"a": 1, "A": 2} + wantValue: s{LowerA: 1}, + }, + { + name: "case-insensitive match before exact match", + cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2} + wantValue: s{LowerA: 1}, + }, + { + name: "ignore case-insensitive match before exact match", + opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive}, + cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2} + wantValue: s{LowerA: 2}, + }, + { + name: "earliest exact match wins", + opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive}, + cborData: hexDecode("a2616101616102"), // {"a": 1, "a": 2} (invalid) + wantValue: s{LowerA: 1}, + }, + { + // the field tags themselves are case-insensitive matches for each other + name: "duplicate keys decode to different fields", + cborData: hexDecode("a2614201614202"), // {"B": 1, "B": 2} (invalid) + wantValue: s{UpperB: 1, LowerB: 2}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decMode, _ := tc.opts.DecMode() + + var dst s + err := decMode.Unmarshal(tc.cborData, &dst) + if err != nil { + t.Fatalf("Unmarshal(0x%x) returned unexpected error %v", tc.cborData, err) + } + + if !reflect.DeepEqual(dst, tc.wantValue) { + t.Errorf("Unmarshal(0x%x) = %#v, want %#v", tc.cborData, dst, tc.wantValue) + } + }) + } +}