Skip to content

Commit ee1a9c9

Browse files
committed
Add DecOptions.DefaultMapType
Add option to specify Go map type to decode to when unmarshalling to interface{}.
1 parent 1ca0c31 commit ee1a9c9

File tree

2 files changed

+223
-1
lines changed

2 files changed

+223
-1
lines changed

decode.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,11 @@ type DecOptions struct {
290290

291291
// ExtraReturnErrors specifies extra conditions that should be treated as errors.
292292
ExtraReturnErrors ExtraDecErrorCond
293+
294+
// DefaultMapType specifies Go map type to create and decode to
295+
// when unmarshalling CBOR into an empty interface value.
296+
// By default, unmarshal uses map[interface{}]interface{}.
297+
DefaultMapType reflect.Type
293298
}
294299

295300
// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
@@ -389,6 +394,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
389394
if !opts.ExtraReturnErrors.valid() {
390395
return nil, errors.New("cbor: invalid ExtraReturnErrors " + strconv.Itoa(int(opts.ExtraReturnErrors)))
391396
}
397+
if opts.DefaultMapType != nil && opts.DefaultMapType.Kind() != reflect.Map {
398+
return nil, fmt.Errorf("cbor: invalid DefaultMapType %s", opts.DefaultMapType)
399+
}
392400
dm := decMode{
393401
dupMapKey: opts.DupMapKey,
394402
timeTag: opts.TimeTag,
@@ -399,6 +407,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
399407
tagsMd: opts.TagsMd,
400408
intDec: opts.IntDec,
401409
extraReturnErrors: opts.ExtraReturnErrors,
410+
defaultMapType: opts.DefaultMapType,
402411
}
403412
return &dm, nil
404413
}
@@ -430,6 +439,7 @@ type decMode struct {
430439
tagsMd TagsMode
431440
intDec IntDecMode
432441
extraReturnErrors ExtraDecErrorCond
442+
defaultMapType reflect.Type
433443
}
434444

435445
var defaultDecMode, _ = DecOptions{}.decMode()
@@ -988,6 +998,14 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
988998
case cborTypeArray:
989999
return d.parseArray()
9901000
case cborTypeMap:
1001+
if d.dm.defaultMapType != nil {
1002+
m := reflect.New(d.dm.defaultMapType)
1003+
err := d.parseToValue(m, getTypeInfo(m.Elem().Type()))
1004+
if err != nil {
1005+
return nil, err
1006+
}
1007+
return m.Elem().Interface(), nil
1008+
}
9911009
return d.parseMap()
9921010
}
9931011
return nil, nil
@@ -1117,7 +1135,7 @@ func (d *decoder) parseArrayToArray(v reflect.Value, tInfo *typeInfo) error {
11171135
return err
11181136
}
11191137

1120-
func (d *decoder) parseMap() (map[interface{}]interface{}, error) {
1138+
func (d *decoder) parseMap() (interface{}, error) {
11211139
_, ai, val := d.getHead()
11221140
hasSize := (ai != 31)
11231141
count := int(val)

decode_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5400,3 +5400,207 @@ func TestUnmarshalTaggedDataToInterface(t *testing.T) {
54005400
t.Errorf("Unmarshal(0x%x) = %v, want %v", data, v2, v)
54015401
}
54025402
}
5403+
5404+
func TestDecModeInvalidDefaultMapType(t *testing.T) {
5405+
testCases := []struct {
5406+
name string
5407+
opts DecOptions
5408+
wantErrorMsg string
5409+
}{
5410+
{
5411+
name: "byte slice",
5412+
opts: DecOptions{DefaultMapType: reflect.TypeOf([]byte(nil))},
5413+
wantErrorMsg: "cbor: invalid DefaultMapType []uint8",
5414+
},
5415+
{
5416+
name: "int slice",
5417+
opts: DecOptions{DefaultMapType: reflect.TypeOf([]int(nil))},
5418+
wantErrorMsg: "cbor: invalid DefaultMapType []int",
5419+
},
5420+
{
5421+
name: "string",
5422+
opts: DecOptions{DefaultMapType: reflect.TypeOf("")},
5423+
wantErrorMsg: "cbor: invalid DefaultMapType string",
5424+
},
5425+
{
5426+
name: "unnamed struct type",
5427+
opts: DecOptions{DefaultMapType: reflect.TypeOf(struct{}{})},
5428+
wantErrorMsg: "cbor: invalid DefaultMapType struct {}",
5429+
},
5430+
}
5431+
for _, tc := range testCases {
5432+
t.Run(tc.name, func(t *testing.T) {
5433+
_, err := tc.opts.DecMode()
5434+
if err == nil {
5435+
t.Errorf("DecMode() didn't return an error")
5436+
} else if err.Error() != tc.wantErrorMsg {
5437+
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
5438+
}
5439+
})
5440+
}
5441+
}
5442+
5443+
func TestUnmarshalToDefaultMapType(t *testing.T) {
5444+
5445+
cborDataMapIntInt := hexDecode("a201020304") // {1: 2, 3: 4}
5446+
cborDataMapStringInt := hexDecode("a2616101616202") // {"a": 1, "b": 2}
5447+
cborDataArrayOfMapStringint := hexDecode("82a2616101616202a2616303616404") // [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
5448+
cborDataNestedMap := hexDecode("a268496e744669656c6401684d61704669656c64a2616101616202") // {"IntField": 1, "MapField": {"a": 1, "b": 2}}
5449+
5450+
decOptionsDefault := DecOptions{}
5451+
decOptionsMapIntfIntfType := DecOptions{DefaultMapType: reflect.TypeOf(map[interface{}]interface{}(nil))}
5452+
decOptionsMapStringIntType := DecOptions{DefaultMapType: reflect.TypeOf(map[string]int(nil))}
5453+
decOptionsMapStringIntfType := DecOptions{DefaultMapType: reflect.TypeOf(map[string]interface{}(nil))}
5454+
5455+
testCases := []struct {
5456+
name string
5457+
opts DecOptions
5458+
cborData []byte
5459+
wantValue interface{}
5460+
wantErrorMsg string
5461+
}{
5462+
// Decode CBOR map to map[interface{}]interface{} using default options
5463+
{
5464+
name: "decode CBOR map[int]int to Go map[interface{}]interface{} (default)",
5465+
opts: decOptionsDefault,
5466+
cborData: cborDataMapIntInt,
5467+
wantValue: map[interface{}]interface{}{uint64(1): uint64(2), uint64(3): uint64(4)},
5468+
},
5469+
{
5470+
name: "decode CBOR map[string]int to Go map[interface{}]interface{} (default)",
5471+
opts: decOptionsDefault,
5472+
cborData: cborDataMapStringInt,
5473+
wantValue: map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5474+
},
5475+
{
5476+
name: "decode CBOR array of map[string]int to Go []map[interface{}]interface{} (default)",
5477+
opts: decOptionsDefault,
5478+
cborData: cborDataArrayOfMapStringint,
5479+
wantValue: []interface{}{
5480+
map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5481+
map[interface{}]interface{}{"c": uint64(3), "d": uint64(4)},
5482+
},
5483+
},
5484+
{
5485+
name: "decode CBOR nested map to Go map[interface{}]interface{} (default)",
5486+
opts: decOptionsDefault,
5487+
cborData: cborDataNestedMap,
5488+
wantValue: map[interface{}]interface{}{
5489+
"IntField": uint64(1),
5490+
"MapField": map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5491+
},
5492+
},
5493+
// Decode CBOR map to map[interface{}]interface{} using default map type option
5494+
{
5495+
name: "decode CBOR map[int]int to Go map[interface{}]interface{}",
5496+
opts: decOptionsMapIntfIntfType,
5497+
cborData: cborDataMapIntInt,
5498+
wantValue: map[interface{}]interface{}{uint64(1): uint64(2), uint64(3): uint64(4)},
5499+
},
5500+
{
5501+
name: "decode CBOR map[string]int to Go map[interface{}]interface{}",
5502+
opts: decOptionsMapIntfIntfType,
5503+
cborData: cborDataMapStringInt,
5504+
wantValue: map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5505+
},
5506+
{
5507+
name: "decode CBOR array of map[string]int to Go []map[interface{}]interface{}",
5508+
opts: decOptionsMapIntfIntfType,
5509+
cborData: cborDataArrayOfMapStringint,
5510+
wantValue: []interface{}{
5511+
map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5512+
map[interface{}]interface{}{"c": uint64(3), "d": uint64(4)},
5513+
},
5514+
},
5515+
{
5516+
name: "decode CBOR nested map to Go map[interface{}]interface{}",
5517+
opts: decOptionsMapIntfIntfType,
5518+
cborData: cborDataNestedMap,
5519+
wantValue: map[interface{}]interface{}{
5520+
"IntField": uint64(1),
5521+
"MapField": map[interface{}]interface{}{"a": uint64(1), "b": uint64(2)},
5522+
},
5523+
},
5524+
// Decode CBOR map to map[string]interface{} using default map type option
5525+
{
5526+
name: "decode CBOR map[int]int to Go map[string]interface{}",
5527+
opts: decOptionsMapStringIntfType,
5528+
cborData: cborDataMapIntInt,
5529+
wantErrorMsg: "cbor: cannot unmarshal positive integer into Go value of type string",
5530+
},
5531+
{
5532+
name: "decode CBOR map[string]int to Go map[string]interface{}",
5533+
opts: decOptionsMapStringIntfType,
5534+
cborData: cborDataMapStringInt,
5535+
wantValue: map[string]interface{}{"a": uint64(1), "b": uint64(2)},
5536+
},
5537+
{
5538+
name: "decode CBOR array of map[string]int to Go []map[string]interface{}",
5539+
opts: decOptionsMapStringIntfType,
5540+
cborData: cborDataArrayOfMapStringint,
5541+
wantValue: []interface{}{
5542+
map[string]interface{}{"a": uint64(1), "b": uint64(2)},
5543+
map[string]interface{}{"c": uint64(3), "d": uint64(4)},
5544+
},
5545+
},
5546+
{
5547+
name: "decode CBOR nested map to Go map[string]interface{}",
5548+
opts: decOptionsMapStringIntfType,
5549+
cborData: cborDataNestedMap,
5550+
wantValue: map[string]interface{}{
5551+
"IntField": uint64(1),
5552+
"MapField": map[string]interface{}{"a": uint64(1), "b": uint64(2)},
5553+
},
5554+
},
5555+
// Decode CBOR map to map[string]int using default map type option
5556+
{
5557+
name: "decode CBOR map[int]int to Go map[string]int",
5558+
opts: decOptionsMapStringIntType,
5559+
cborData: cborDataMapIntInt,
5560+
wantErrorMsg: "cbor: cannot unmarshal positive integer into Go value of type string",
5561+
},
5562+
{
5563+
name: "decode CBOR map[string]int to Go map[string]int",
5564+
opts: decOptionsMapStringIntType,
5565+
cborData: cborDataMapStringInt,
5566+
wantValue: map[string]int{"a": 1, "b": 2},
5567+
},
5568+
{
5569+
name: "decode CBOR array of map[string]int to Go []map[string]int",
5570+
opts: decOptionsMapStringIntType,
5571+
cborData: cborDataArrayOfMapStringint,
5572+
wantValue: []interface{}{
5573+
map[string]int{"a": 1, "b": 2},
5574+
map[string]int{"c": 3, "d": 4},
5575+
},
5576+
},
5577+
{
5578+
name: "decode CBOR nested map to Go map[string]int",
5579+
opts: decOptionsMapStringIntType,
5580+
cborData: cborDataNestedMap,
5581+
wantErrorMsg: "cbor: cannot unmarshal map into Go value of type int",
5582+
},
5583+
}
5584+
5585+
for _, tc := range testCases {
5586+
t.Run(tc.name, func(t *testing.T) {
5587+
decMode, _ := tc.opts.DecMode()
5588+
5589+
var v interface{}
5590+
err := decMode.Unmarshal(tc.cborData, &v)
5591+
if err != nil {
5592+
if tc.wantErrorMsg == "" {
5593+
t.Errorf("Unmarshal(0x%x) to empty interface returned error %v", tc.cborData, err)
5594+
} else if tc.wantErrorMsg != err.Error() {
5595+
t.Errorf("Unmarshal(0x%x) error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg)
5596+
}
5597+
} else {
5598+
if tc.wantValue == nil {
5599+
t.Errorf("Unmarshal(0x%x) = %v (%T), want error %q", tc.cborData, v, v, tc.wantErrorMsg)
5600+
} else if !reflect.DeepEqual(v, tc.wantValue) {
5601+
t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantValue, tc.wantValue)
5602+
}
5603+
}
5604+
})
5605+
}
5606+
}

0 commit comments

Comments
 (0)