Skip to content

Commit

Permalink
Refactor sorted map encode to use fewer buffers for nested maps.
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Luddy <[email protected]>
  • Loading branch information
benluddy committed May 19, 2024
1 parent 367b524 commit 3d0c715
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 72 deletions.
103 changes: 52 additions & 51 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1217,25 +1217,58 @@ func (me mapEncodeFunc) encode(e *bytes.Buffer, em *encMode, v reflect.Value) er
if mlen == 0 {
return e.WriteByte(byte(cborTypeMap))
}
switch em.sort {
case SortNone, SortFastShuffle:
default:
if mlen > 1 {
return me.encodeCanonical(e, em, v)
}
}

encodeHead(e, byte(cborTypeMap), uint64(mlen))
if em.sort == SortNone || em.sort == SortFastShuffle || mlen <= 1 {
return me.e(e, em, v, nil)
}

kvsp := getKeyValues(v.Len()) // for sorting keys
defer putKeyValues(kvsp)
kvs := *kvsp

kvBeginOffset := e.Len()
if err := me.e(e, em, v, kvs); err != nil {
return err
}
kvTotalLen := e.Len() - kvBeginOffset

// Use the capacity at the tail of the encode buffer as a staging area to rearrange the
// encoded pairs into sorted order.
e.Grow(kvTotalLen)
tmp := e.Bytes()[e.Len() : e.Len()+kvTotalLen] // Can use e.AvailableBuffer() in Go 1.21+.
dst := e.Bytes()[kvBeginOffset:]

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs: kvs, data: dst})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs: kvs, data: dst})
}

// This is where the encoded bytes are actually rearranged in the output buffer to reflect
// the desired order.
sortedOffset := 0
for _, kv := range kvs {
copy(tmp[sortedOffset:], dst[kv.offset:kv.nextOffset])
sortedOffset += kv.nextOffset - kv.offset
}
copy(dst, tmp[:kvTotalLen])

return nil

return me.e(e, em, v, nil)
}

// keyValue is the position of an encoded pair in a buffer. All offsets are zero-based and relative
// to the first byte of the first encoded pair.
type keyValue struct {
keyCBORData, keyValueCBORData []byte
keyLen, keyValueLen int
offset int
valueOffset int
nextOffset int
}

type bytewiseKeyValueSorter struct {
kvs []keyValue
kvs []keyValue
data []byte
}

func (x *bytewiseKeyValueSorter) Len() int {
Expand All @@ -1247,11 +1280,13 @@ func (x *bytewiseKeyValueSorter) Swap(i, j int) {
}

func (x *bytewiseKeyValueSorter) Less(i, j int) bool {
return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0
kvi, kvj := x.kvs[i], x.kvs[j]
return bytes.Compare(x.data[kvi.offset:kvi.nextOffset], x.data[kvj.offset:kvj.nextOffset]) <= 0
}

type lengthFirstKeyValueSorter struct {
kvs []keyValue
kvs []keyValue
data []byte
}

func (x *lengthFirstKeyValueSorter) Len() int {
Expand All @@ -1263,10 +1298,11 @@ func (x *lengthFirstKeyValueSorter) Swap(i, j int) {
}

func (x *lengthFirstKeyValueSorter) Less(i, j int) bool {
if len(x.kvs[i].keyCBORData) != len(x.kvs[j].keyCBORData) {
return len(x.kvs[i].keyCBORData) < len(x.kvs[j].keyCBORData)
kvi, kvj := x.kvs[i], x.kvs[j]
if keyLengthDifference := (kvi.nextOffset - kvi.offset) - (kvj.nextOffset - kvj.offset); keyLengthDifference != 0 {
return keyLengthDifference < 0
}
return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0
return bytes.Compare(x.data[kvi.offset:kvi.nextOffset], x.data[kvj.offset:kvj.nextOffset]) <= 0
}

var keyValuePool = sync.Pool{}
Expand Down Expand Up @@ -1294,41 +1330,6 @@ func putKeyValues(x *[]keyValue) {
keyValuePool.Put(x)
}

func (me mapEncodeFunc) encodeCanonical(e *bytes.Buffer, em *encMode, v reflect.Value) error {
kve := getEncodeBuffer() // accumulated cbor encoded key-values
defer putEncodeBuffer(kve)

kvsp := getKeyValues(v.Len()) // for sorting keys
defer putKeyValues(kvsp)

kvs := *kvsp

err := me.e(kve, em, v, kvs)
if err != nil {
return err
}

b := kve.Bytes()
for i, off := 0, 0; i < len(kvs); i++ {
kvs[i].keyCBORData = b[off : off+kvs[i].keyLen]
kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen]
off += kvs[i].keyValueLen
}

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs})
}

encodeHead(e, byte(cborTypeMap), uint64(len(kvs)))
for i := 0; i < len(kvs); i++ {
e.Write(kvs[i].keyValueCBORData)
}

return nil
}

func encodeStructToArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
Expand Down
34 changes: 24 additions & 10 deletions encode_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ type mapKeyValueEncodeFunc struct {
}

func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error {
trackKeyValueLength := len(kvs) == v.Len()
iterk := me.kpool.Get().(*reflect.Value)
defer func() {
iterk.SetZero()
Expand All @@ -28,24 +27,39 @@ func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v
iterv.SetZero()
me.vpool.Put(iterv)
}()
iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := e.Len()

if kvs == nil {
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)

if err := me.kf(e, em, *iterk); err != nil {
return err
}
if err := me.ef(e, em, *iterv); err != nil {
return err
}
}
return nil
}

initial := e.Len()
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)

offset := e.Len()
if err := me.kf(e, em, *iterk); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyLen = e.Len() - off
}

valueOffset := e.Len()
if err := me.ef(e, em, *iterv); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyValueLen = e.Len() - off
kvs[i] = keyValue{
offset: offset - initial,
valueOffset: valueOffset - initial,
nextOffset: e.Len() - initial,
}
}

Expand Down
30 changes: 19 additions & 11 deletions encode_map_go117.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,32 @@ type mapKeyValueEncodeFunc struct {
}

func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error {
trackKeyValueLength := len(kvs) == v.Len()

iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := e.Len()
if kvs == nil {
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
}
return nil
}

initial := e.Len()
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
offset := e.Len()
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyLen = e.Len() - off
}

valueOffset := e.Len()
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyValueLen = e.Len() - off
kvs[i] = keyValue{
offset: offset - initial,
valueOffset: valueOffset - initial,
nextOffset: e.Len() - initial,
}
}

Expand Down

0 comments on commit 3d0c715

Please sign in to comment.