Skip to content

Commit

Permalink
Improve encoding speed, reduce mem, refactor
Browse files Browse the repository at this point in the history
Speed/mem improvements:
- Replace range loop with for loop to avoid mem copy when encoding struct.
- Change encodeState's scratch field from [64]byte to [16]byte.
- Change and reorder "field" struct's fields to reduce struct size.

Refactor:
- Rename variables and functions.
- Decouple encoding struct type cache and decoding struct type cache.
- Others.

benchmark                                                                      old ns/op     new ns/op     delta
BenchmarkMarshal/Go_struct_to_CBOR_map-2                                       828           728           -12.08%
  • Loading branch information
fxamacker committed Nov 11, 2019
1 parent d030867 commit 8ea465d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 81 deletions.
2 changes: 1 addition & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ func (d *decodeState) parseMap(t cborType, count int, v reflect.Value) error {
}

func (d *decodeState) parseStruct(t cborType, count int, v reflect.Value) error {
flds := getStructFields(v.Type())
flds := getDecodingStructType(v.Type())
foundFlds := make([]bool, len(flds))

hasSize := count >= 0
Expand Down
62 changes: 30 additions & 32 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type EncOptions struct {
// An encodeState encodes CBOR into a bytes.Buffer.
type encodeState struct {
bytes.Buffer
scratch [64]byte
scratch [16]byte
}

// encodeStatePool caches unused encodeState objects for later reuse.
Expand All @@ -74,12 +74,12 @@ var encodeStatePool = sync.Pool{
},
}

func newEncodeState() *encodeState {
func getEncodeState() *encodeState {
return encodeStatePool.Get().(*encodeState)
}

// returnEncodeState returns e to encodeStatePool.
func returnEncodeState(e *encodeState) {
// putEncodeState returns e to encodeStatePool.
func putEncodeState(e *encodeState) {
e.Reset()
encodeStatePool.Put(e)
}
Expand Down Expand Up @@ -279,7 +279,7 @@ func (v byCanonical) Less(i, j int) bool {

var byCanonicalPool = sync.Pool{}

func newByCanonical(length int) *byCanonical {
func getByCanonical(length int) *byCanonical {
v := byCanonicalPool.Get()
if v == nil {
return &byCanonical{pairs: make([]pair, 0, length)}
Expand All @@ -295,7 +295,7 @@ func newByCanonical(length int) *byCanonical {
return s
}

func returnByCanonical(s *byCanonical) {
func putByCanonical(s *byCanonical) {
s.pairs = s.pairs[:0]
byCanonicalPool.Put(s)
}
Expand All @@ -312,21 +312,21 @@ func encodeMapCanonical(e *encodeState, v reflect.Value, opts EncOptions) (int,
if v.Len() == 0 {
return encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(0)), nil
}
pairEncodeState := newEncodeState() // accumulated cbor encoded map key-value pairs
pairs := newByCanonical(v.Len()) // for sorting keys
pairEncodeState := getEncodeState() // accumulated cbor encoded map key-value pairs
pairs := getByCanonical(v.Len()) // for sorting keys

iter := v.MapRange()
for iter.Next() {
n1, err := kf(pairEncodeState, iter.Key(), opts)
if err != nil {
returnEncodeState(pairEncodeState)
returnByCanonical(pairs)
putEncodeState(pairEncodeState)
putByCanonical(pairs)
return 0, err
}
n2, err := ef(pairEncodeState, iter.Value(), opts)
if err != nil {
returnEncodeState(pairEncodeState)
returnByCanonical(pairs)
putEncodeState(pairEncodeState)
putByCanonical(pairs)
return 0, err
}
pairs.pairs = append(pairs.pairs, pair{keyLen: n1, pairLen: n1 + n2})
Expand All @@ -346,36 +346,34 @@ func encodeMapCanonical(e *encodeState, v reflect.Value, opts EncOptions) (int,
n += n1
}

returnEncodeState(pairEncodeState)
returnByCanonical(pairs)
putEncodeState(pairEncodeState)
putByCanonical(pairs)
return n, nil
}

func encodeStruct(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
flds := getEncodingStructFields(v.Type(), opts.Canonical)
flds := getEncodingStructType(v.Type(), opts.Canonical)

kve := newEncodeState() // encode key-value pairs based on struct field tag options
kve := getEncodeState() // encode key-value pairs based on struct field tag options
kvcount := 0
for _, f := range flds {
if f.ef == nil {
for i := 0; i < len(flds); i++ {
if flds[i].ef == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
fv, err := fieldByIndex(v, f.idx)
fv, err := fieldByIndex(v, flds[i].idx)
if err != nil {
returnEncodeState(kve)
putEncodeState(kve)
return 0, err
}
if !fv.IsValid() || (f.omitempty && isEmptyValue(fv)) {
if !fv.IsValid() || (flds[i].omitempty && isEmptyValue(fv)) {
continue
}
if f.cborNameLen > 0 {
kve.Write(f.cborName[:f.cborNameLen])
} else {
encodeStringInternal(kve, f.name, opts)
}
_, err = f.ef(kve, fv, opts)

kve.Write(flds[i].cborName)

_, err = flds[i].ef(kve, fv, opts)
if err != nil {
returnEncodeState(kve)
putEncodeState(kve)
return 0, err
}
kvcount++
Expand All @@ -384,7 +382,7 @@ func encodeStruct(e *encodeState, v reflect.Value, opts EncOptions) (int, error)
n := encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(kvcount))
n1, err := e.Write(kve.Bytes())

returnEncodeState(kve)
putEncodeState(kve)
return n + n1, err
}

Expand Down Expand Up @@ -588,17 +586,17 @@ func isEmptyValue(v reflect.Value) bool {
//
// Marshal supports RFC 7049 and CTAP2 canonical CBOR encoding.
func Marshal(v interface{}, encOpts EncOptions) ([]byte, error) {
e := newEncodeState()
e := getEncodeState()

err := e.marshal(v, encOpts)
if err != nil {
returnEncodeState(e)
putEncodeState(e)
return nil, err
}

buf := make([]byte, e.Len())
copy(buf, e.Bytes())

returnEncodeState(e)
putEncodeState(e)
return buf, nil
}
84 changes: 36 additions & 48 deletions structfields.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package cbor

import (
"bytes"
"errors"
"reflect"
"sort"
Expand Down Expand Up @@ -33,12 +34,11 @@ func fieldByIndex(v reflect.Value, index []int) (reflect.Value, error) {

type field struct {
name string
cborName []byte
idx []int
typ reflect.Type
isUnmarshaler bool
ef encodeFunc
cborName [64]byte
cborNameLen int
isUnmarshaler bool
tagged bool // used to choose dominant field (at the same level tagged fields dominate untagged fields)
omitempty bool // used to skip empty field
}
Expand Down Expand Up @@ -99,29 +99,21 @@ type byCanonicalRule struct {
}

func (s byCanonicalRule) Less(i, j int) bool {
if len(s.fields[i].name) != len(s.fields[j].name) {
return len(s.fields[i].name) < len(s.fields[j].name)
}
return s.fields[i].name <= s.fields[j].name
return bytes.Compare(s.fields[i].cborName, s.fields[j].cborName) <= 0
}

func getFieldNameOptionsFromTag(tag string) (string, string) {
func getFieldNameAndOptionsFromTag(tag string) (name string, omitEmpty bool) {
if len(tag) == 0 {
return "", ""
return
}
commaIdx := strings.IndexByte(tag, byte(','))
if commaIdx == -1 {
return tag, ""
}
return tag[:commaIdx], tag[commaIdx+1:]
}

func hasFieldOptionFromTag(options, key string) bool {
idx := strings.Index(options, key)
if idx == -1 {
return false
tokens := strings.Split(tag, ",")
name = tokens[0]
for _, s := range tokens[1:] {
if s == "omitempty" {
omitEmpty = true
}
}
return idx+len(key) == len(options) || options[idx+len(key)] == ','
return
}

// getFields returns a list of visible fields of struct type typ following Go
Expand Down Expand Up @@ -190,8 +182,7 @@ func getFields(typ reflect.Type) fields {
idx[len(fieldIdx)] = i

tagged := len(tag) > 0
tagFieldName, tagOptions := getFieldNameOptionsFromTag(tag)
omitempty := hasFieldOptionFromTag(tagOptions, "omitempty")
tagFieldName, omitempty := getFieldNameAndOptionsFromTag(tag)

fieldName := tagFieldName
if tagFieldName == "" {
Expand Down Expand Up @@ -230,60 +221,57 @@ func getFields(typ reflect.Type) fields {
return visibleFields
}

type structFields struct {
type encodingStructType struct {
typ reflect.Type
fields fields
canonicalFields fields
}

var (
cachedStructFields sync.Map
cachedEncodingStructFields sync.Map
decodingStructTypeCache sync.Map
encodingStructTypeCache sync.Map
)

func getStructFields(t reflect.Type) fields {
if v, _ := cachedStructFields.Load(t); v != nil {
func getDecodingStructType(t reflect.Type) fields {
if v, _ := decodingStructTypeCache.Load(t); v != nil {
return v.(fields)
}
flds := getFields(t)
for i := 0; i < len(flds); i++ {
flds[i].isUnmarshaler = implementsUnmarshaler(flds[i].typ)
}
cachedStructFields.Store(t, flds)
decodingStructTypeCache.Store(t, flds)
return flds
}

func getEncodingStructFields(t reflect.Type, canonical bool) fields {
if v, _ := cachedEncodingStructFields.Load(t); v != nil {
func getEncodingStructType(t reflect.Type, canonical bool) fields {
if v, _ := encodingStructTypeCache.Load(t); v != nil {
if canonical {
return v.(structFields).canonicalFields
return v.(encodingStructType).canonicalFields
}
return v.(structFields).fields
return v.(encodingStructType).fields
}

fldsOrig := getStructFields(t)
flds := make(fields, len(fldsOrig))
copy(flds, fldsOrig)
flds := getFields(t)

es := getEncodeState()
for i := 0; i < len(flds); i++ {
flds[i].ef = getEncodeFunc(flds[i].typ)
nameLen := len(flds[i].name)
if nameLen <= 23 {
flds[i].cborName[0] = byte(cborTypeTextString) | byte(nameLen)
copy(flds[i].cborName[1:], flds[i].name)
flds[i].cborNameLen = 1 + nameLen
} else if nameLen <= 62 {
flds[i].cborName[0] = byte(cborTypeTextString) | byte(24)
flds[i].cborName[1] = byte(nameLen)
copy(flds[i].cborName[2:], flds[i].name)
flds[i].cborNameLen = 2 + nameLen
}

encodeTypeAndAdditionalValue(es, byte(cborTypeTextString), uint64(len(flds[i].name)))
flds[i].cborName = make([]byte, es.Len()+len(flds[i].name))
copy(flds[i].cborName, es.Bytes())
copy(flds[i].cborName[es.Len():], flds[i].name)

es.Reset()
}
putEncodeState(es)

canonicalFields := make(fields, len(flds))
copy(canonicalFields, flds)
sort.Sort(byCanonicalRule{canonicalFields})

cachedEncodingStructFields.Store(t, structFields{typ: t, fields: flds, canonicalFields: canonicalFields})
encodingStructTypeCache.Store(t, encodingStructType{typ: t, fields: flds, canonicalFields: canonicalFields})

if canonical {
return canonicalFields
Expand Down

0 comments on commit 8ea465d

Please sign in to comment.