Skip to content

Commit

Permalink
Merge pull request #473 from fxamacker/fxamacker/refactor-map-encode-…
Browse files Browse the repository at this point in the history
…to-prepare-for-go-verion-bump

Refactor map encoding to prep for Go version bump
  • Loading branch information
fxamacker authored Jan 23, 2024
2 parents 23ec2c5 + 4918974 commit e5eaf7a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 36 deletions.
2 changes: 1 addition & 1 deletion decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7514,7 +7514,7 @@ func TestUnmarshalToInterface(t *testing.T) {
if err != nil {
t.Errorf("Marshal(%+v) returned error %v", tc.v, err)
} else if !bytes.Equal(data, tc.data) {
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.v)
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.data)
}

// Unmarshal to empty interface
Expand Down
55 changes: 20 additions & 35 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,13 @@ func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value)
return nil
}

// encodeKeyValueFunc encodes key/value pairs in map (v).
// If kvs is provided (having the same length as v), length of encoded key and value are stored in kvs.
// kvs is used for canonical encoding of map.
type encodeKeyValueFunc func(e *encoderBuffer, em *encMode, v reflect.Value, kvs []keyValue) error

type mapEncodeFunc struct {
kf, ef encodeFunc
e encodeKeyValueFunc
}

func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error {
Expand All @@ -993,16 +998,8 @@ func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) e
return me.encodeCanonical(e, em, v)
}
encodeHead(e, byte(cborTypeMap), uint64(mlen))
iter := v.MapRange()
for iter.Next() {
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
}
return nil

return me.e(e, em, v, nil)
}

type keyValue struct {
Expand Down Expand Up @@ -1071,26 +1068,17 @@ func putKeyValues(x *[]keyValue) {
}

func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error {
kve := getEncoderBuffer() // accumulated cbor encoded key-values
kve := getEncoderBuffer() // accumulated cbor encoded key-values
defer putEncoderBuffer(kve)

kvsp := getKeyValues(v.Len()) // for sorting keys
defer putKeyValues(kvsp)

kvs := *kvsp
iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := kve.Len()
if err := me.kf(kve, em, iter.Key()); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n1 := kve.Len() - off
if err := me.ef(kve, em, iter.Value()); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n2 := kve.Len() - off
// Save key and keyvalue length to create slice later.
kvs[i] = keyValue{keyLen: n1, keyValueLen: n2}

err := me.e(kve, em, v, kvs)
if err != nil {
return err
}

b := kve.Bytes()
Expand All @@ -1111,8 +1099,6 @@ func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect
e.Write(kvs[i].keyValueCBORData)
}

putEncoderBuffer(kve)
putKeyValues(kvsp)
return nil
}

Expand Down Expand Up @@ -1463,12 +1449,11 @@ func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
}
return arrayEncodeFunc{f: f}.encode, isEmptySlice
case reflect.Map:
kf, _ := getEncodeFunc(t.Key())
ef, _ := getEncodeFunc(t.Elem())
if kf == nil || ef == nil {
f := getEncodeMapFunc(t)
if f == nil {
return nil, nil
}
return mapEncodeFunc{kf: kf, ef: ef}.encode, isEmptyMap
return f, isEmptyMap
case reflect.Struct:
// Get struct's special field "_" tag options
if f, ok := t.FieldByName("_"); ok {
Expand Down
49 changes: 49 additions & 0 deletions encode_map_go117.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Faye Amacker. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

package cbor

import (
"reflect"
)

type mapKeyValueEncodeFunc struct {
kf, ef encodeFunc
}

func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *encoderBuffer, em *encMode, v reflect.Value, kvs []keyValue) error {
trackKeyValueLength := len(kvs) == v.Len()

iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := e.Len()

if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyLen = e.Len() - off
}

if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyValueLen = e.Len() - off
}
}

return nil
}

func getEncodeMapFunc(t reflect.Type) encodeFunc {
kf, _ := getEncodeFunc(t.Key())
ef, _ := getEncodeFunc(t.Elem())
if kf == nil || ef == nil {
return nil
}
mkv := &mapKeyValueEncodeFunc{kf: kf, ef: ef}
return mapEncodeFunc{
e: mkv.encodeKeyValues,
}.encode
}

0 comments on commit e5eaf7a

Please sign in to comment.