From 9e247e09acff0c3163715bf33ae78c57490ffc90 Mon Sep 17 00:00:00 2001
From: Ben Luddy <bluddy@redhat.com>
Date: Fri, 3 Nov 2023 12:38:31 -0400
Subject: [PATCH] Add FieldNameMatching decode option.

When decoding a CBOR map into a Go struct, FieldNameMatching controls how string
keys are matched to struct fields in the destination struct and allows users to
require case-sensitive matches. The default value of this option preserves the
existing behavior, which prefers case-sensitive matches but will fall back to a
case-insensitive match.

Signed-off-by: Ben Luddy <bluddy@redhat.com>
---
 decode.go      | 31 +++++++++++++++-
 decode_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 129 insertions(+), 1 deletion(-)

diff --git a/decode.go b/decode.go
index c0ee14ab..65be7850 100644
--- a/decode.go
+++ b/decode.go
@@ -353,6 +353,26 @@ func (um UTF8Mode) valid() bool {
 	return um < maxUTF8Mode
 }
 
+// FieldNameMatchingMode specifies how string keys in CBOR maps are matched to Go struct field names.
+type FieldNameMatchingMode int
+
+const (
+	// FieldNameMatchingPreferCaseSensitive prefers to decode map items into struct fields whose names (or tag
+	// names) exactly match the item's key. If there is no such field, a map item will be decoded into a field whose
+	// name is a case-insensitive match for the item's key.
+	FieldNameMatchingPreferCaseSensitive = iota
+
+	// FieldNameMatchingCaseSensitive decodes map items only into a struct field whose name (or tag name) is an
+	// exact match for the item's key.
+	FieldNameMatchingCaseSensitive
+
+	maxFieldNameMatchingMode
+)
+
+func (fnmm FieldNameMatchingMode) valid() bool {
+	return fnmm >= 0 && fnmm < maxFieldNameMatchingMode
+}
+
 // DecOptions specifies decoding options.
 type DecOptions struct {
 	// DupMapKey specifies whether to enforce duplicate map key.
@@ -402,6 +422,9 @@ type DecOptions struct {
 	// UTF8 specifies if decoder should decode CBOR Text containing invalid UTF-8.
 	// By default, unmarshal rejects CBOR text containing invalid UTF-8.
 	UTF8 UTF8Mode
+
+	// FieldNameMatching specifies how string keys in CBOR maps are matched to Go struct field names.
+	FieldNameMatching FieldNameMatchingMode
 }
 
 // DecMode returns DecMode with immutable options and no tags (safe for concurrency).
@@ -510,6 +533,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
 	if !opts.UTF8.valid() {
 		return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8)))
 	}
+	if !opts.FieldNameMatching.valid() {
+		return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching)))
+	}
 	dm := decMode{
 		dupMapKey:         opts.DupMapKey,
 		timeTag:           opts.TimeTag,
@@ -523,6 +549,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
 		extraReturnErrors: opts.ExtraReturnErrors,
 		defaultMapType:    opts.DefaultMapType,
 		utf8:              opts.UTF8,
+		fieldNameMatching: opts.FieldNameMatching,
 	}
 	return &dm, nil
 }
@@ -587,6 +614,7 @@ type decMode struct {
 	extraReturnErrors ExtraDecErrorCond
 	defaultMapType    reflect.Type
 	utf8              UTF8Mode
+	fieldNameMatching FieldNameMatchingMode
 }
 
 var defaultDecMode, _ = DecOptions{}.decMode()
@@ -605,6 +633,7 @@ func (dm *decMode) DecOptions() DecOptions {
 		MapKeyByteString:  dm.mapKeyByteString,
 		ExtraReturnErrors: dm.extraReturnErrors,
 		UTF8:              dm.utf8,
+		FieldNameMatching: dm.fieldNameMatching,
 	}
 }
 
@@ -1681,7 +1710,7 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n
 				}
 			}
 			// Find field with case-insensitive match
-			if f == nil {
+			if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive {
 				keyString := string(keyBytes)
 				for i := 0; i < len(structType.fields); i++ {
 					fld := structType.fields[i]
diff --git a/decode_test.go b/decode_test.go
index ea3d84d4..e9dc84d3 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -6038,3 +6038,102 @@ func TestUnmarshalFirstInvalidItem(t *testing.T) {
 		t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
 	}
 }
+
+func TestDecModeInvalidFieldNameMatchingMode(t *testing.T) {
+	for _, tc := range []struct {
+		name         string
+		opts         DecOptions
+		wantErrorMsg string
+	}{
+		{
+			name:         "below range of valid modes",
+			opts:         DecOptions{FieldNameMatching: -1},
+			wantErrorMsg: "cbor: invalid FieldNameMatching -1",
+		},
+		{
+			name:         "above range of valid modes",
+			opts:         DecOptions{FieldNameMatching: 101},
+			wantErrorMsg: "cbor: invalid FieldNameMatching 101",
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			_, err := tc.opts.DecMode()
+			if err == nil {
+				t.Errorf("DecMode() didn't return an error")
+			} else if err.Error() != tc.wantErrorMsg {
+				t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
+			}
+		})
+	}
+}
+
+func TestDecodeFieldNameMatching(t *testing.T) {
+	type s struct {
+		LowerA int `cbor:"a"`
+		UpperB int `cbor:"B"`
+		LowerB int `cbor:"b"`
+	}
+
+	testCases := []struct {
+		name      string
+		opts      DecOptions
+		cborData  []byte
+		wantValue s
+	}{
+		{
+			name:      "case-insensitive match",
+			cborData:  hexDecode("a1614101"), // {"A": 1}
+			wantValue: s{LowerA: 1},
+		},
+		{
+			name:      "ignore case-insensitive match",
+			opts:      DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
+			cborData:  hexDecode("a1614101"), // {"A": 1}
+			wantValue: s{},
+		},
+		{
+			name:      "exact match before case-insensitive match",
+			cborData:  hexDecode("a2616101614102"), // {"a": 1, "A": 2}
+			wantValue: s{LowerA: 1},
+		},
+		{
+			name:      "case-insensitive match before exact match",
+			cborData:  hexDecode("a2614101616102"), // {"A": 1, "a": 2}
+			wantValue: s{LowerA: 1},
+		},
+		{
+			name:      "ignore case-insensitive match before exact match",
+			opts:      DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
+			cborData:  hexDecode("a2614101616102"), // {"A": 1, "a": 2}
+			wantValue: s{LowerA: 2},
+		},
+		{
+			name:      "earliest exact match wins",
+			opts:      DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
+			cborData:  hexDecode("a2616101616102"), // {"a": 1, "a": 2} (invalid)
+			wantValue: s{LowerA: 1},
+		},
+		{
+			// the field tags themselves are case-insensitive matches for each other
+			name:      "duplicate keys decode to different fields",
+			cborData:  hexDecode("a2614201614202"), // {"B": 1, "B": 2} (invalid)
+			wantValue: s{UpperB: 1, LowerB: 2},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			decMode, _ := tc.opts.DecMode()
+
+			var dst s
+			err := decMode.Unmarshal(tc.cborData, &dst)
+			if err != nil {
+				t.Fatalf("Unmarshal(0x%x) returned unexpected error %v", tc.cborData, err)
+			}
+
+			if !reflect.DeepEqual(dst, tc.wantValue) {
+				t.Errorf("Unmarshal(0x%x) = %#v, want %#v", tc.cborData, dst, tc.wantValue)
+			}
+		})
+	}
+}