Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions lib/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"github.com/ChainSafe/gossamer/pkg/scale"
)

// node is the interface for trie methods
Expand Down Expand Up @@ -337,26 +337,28 @@ func (b *branch) decode(r io.Reader, header byte) (err error) {
return err
}

sd := &scale.Decoder{Reader: r}
sd := scale.NewDecoder(r)

if nodeType == 3 {
var value []byte
// branch w/ value
value, err := sd.Decode([]byte{})
err := sd.Decode(&value)
if err != nil {
return err
}
b.value = value.([]byte)
b.value = value
}

for i := 0; i < 16; i++ {
if (childrenBitmap[i/8]>>(i%8))&1 == 1 {
hash, err := sd.Decode([]byte{})
var hash []byte
err := sd.Decode(&hash)
if err != nil {
return err
}

b.children[i] = &leaf{
hash: hash.([]byte),
hash: hash,
}
}
}
Expand Down Expand Up @@ -386,14 +388,15 @@ func (l *leaf) decode(r io.Reader, header byte) (err error) {
return err
}

sd := &scale.Decoder{Reader: r}
value, err := sd.Decode([]byte{})
sd := scale.NewDecoder(r)
var value []byte
err = sd.Decode(&value)
if err != nil {
return err
}

if len(value.([]byte)) > 0 {
l.value = value.([]byte)
if len(value) > 0 {
l.value = value
}

l.dirty = true
Expand Down
14 changes: 5 additions & 9 deletions lib/trie/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"testing"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"github.com/ChainSafe/gossamer/pkg/scale"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -160,14 +160,12 @@ func TestBranchEncode(t *testing.T) {
expected = append(expected, nibblesToKeyLE(b.key)...)
expected = append(expected, common.Uint16ToBytes(b.childrenBitmap())...)

buf := bytes.Buffer{}
encoder := &scale.Encoder{Writer: &buf}
_, err = encoder.Encode(b.value)
enc, err := scale.Marshal(b.value)
if err != nil {
t.Fatalf("Fail when encoding value with scale: %s", err)
}

expected = append(expected, buf.Bytes()...)
expected = append(expected, enc...)

for _, child := range b.children {
if child != nil {
Expand Down Expand Up @@ -207,14 +205,12 @@ func TestLeafEncode(t *testing.T) {
expected = append(expected, header...)
expected = append(expected, nibblesToKeyLE(l.key)...)

buf := bytes.Buffer{}
encoder := &scale.Encoder{Writer: &buf}
_, err = encoder.Encode(l.value)
enc, err := scale.Marshal(l.value)
if err != nil {
t.Fatalf("Fail when encoding value with scale: %s", err)
}

expected = append(expected, buf.Bytes()...)
expected = append(expected, enc...)

hasher := newHasher(false)
defer hasher.returnToPool()
Expand Down
46 changes: 42 additions & 4 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"reflect"
)
Expand Down Expand Up @@ -87,7 +89,7 @@ func Unmarshal(data []byte, dst interface{}) (err error) {
if err != nil {
return
}
ds.Buffer = *buf
ds.Reader = buf

err = ds.unmarshal(elem)
if err != nil {
Expand All @@ -96,8 +98,36 @@ func Unmarshal(data []byte, dst interface{}) (err error) {
return
}

// Decoder is used to decode from an io.Reader
type Decoder struct {
decodeState
}

// Decode accepts a pointer to a destination and decodes into supplied destination
func (d *Decoder) Decode(dst interface{}) (err error) {
dstv := reflect.ValueOf(dst)
if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
err = fmt.Errorf("unsupported dst: %T, must be a pointer to a destination", dst)
return
}

elem := indirect(dstv)
if err != nil {
return
}
return d.unmarshal(elem)
}

// NewDecoder is constructor for Decoder
func NewDecoder(r io.Reader) (d *Decoder) {
d = &Decoder{
decodeState{r},
}
return
}

type decodeState struct {
bytes.Buffer
io.Reader
}

func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
Expand Down Expand Up @@ -230,6 +260,12 @@ func (ds *decodeState) decodeCustomPrimitive(dstv reflect.Value) (err error) {
return
}

func (ds *decodeState) ReadByte() (byte, error) {
b := make([]byte, 1) // make buffer
_, err := ds.Reader.Read(b) // read what's in the Decoder's underlying buffer to our new buffer b
return b[0], err
}

func (ds *decodeState) decodeResult(dstv reflect.Value) (err error) {
res := dstv.Interface().(Result)
var rb byte
Expand Down Expand Up @@ -263,7 +299,8 @@ func (ds *decodeState) decodeResult(dstv reflect.Value) (err error) {
}
dstv.Set(reflect.ValueOf(res))
default:
err = fmt.Errorf("unsupported Result value: %v, bytes: %v", rb, ds.Bytes())
bytes, _ := ioutil.ReadAll(ds.Reader)
err = fmt.Errorf("unsupported Result value: %v, bytes: %v", rb, bytes)
}
return
}
Expand Down Expand Up @@ -295,7 +332,8 @@ func (ds *decodeState) decodePointer(dstv reflect.Value) (err error) {
dstv.Set(tempElem)
}
default:
err = fmt.Errorf("unsupported Option value: %v, bytes: %v", rb, ds.Bytes())
bytes, _ := ioutil.ReadAll(ds.Reader)
err = fmt.Errorf("unsupported Option value: %v, bytes: %v", rb, bytes)
}
return
}
Expand Down
65 changes: 64 additions & 1 deletion pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package scale

import (
"bytes"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -189,7 +190,6 @@ func Test_unmarshal_optionality(t *testing.T) {
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)
}

}
})
}
Expand Down Expand Up @@ -238,3 +238,66 @@ func Test_unmarshal_optionality_nil_case(t *testing.T) {
})
}
}

func Test_Decoder_Decode(t *testing.T) {
for _, tt := range newTests(fixedWidthIntegerTests, variableWidthIntegerTests, stringTests,
boolTests, sliceTests, arrayTests,
) {
t.Run(tt.name, func(t *testing.T) {
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
wantBuf := bytes.NewBuffer(tt.want)
d := NewDecoder(wantBuf)
if err := d.Decode(&dst); (err != nil) != tt.wantErr {
t.Errorf("Decoder.Decode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(dst, tt.in) {
t.Errorf("Decoder.Decode() = %v, want %v", dst, tt.in)
}
})
}
}

func Test_Decoder_Decode_MultipleCalls(t *testing.T) {
tests := []struct {
name string
ins []interface{}
want []byte
wantErr []bool
}{
{
name: "int64 and []byte",
ins: []interface{}{int64(9223372036854775807), []byte{0x01}},
want: append([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, []byte{0x04, 0x01}...),
},
{
name: "eof error",
ins: []interface{}{int64(9223372036854775807), []byte{0x01}},
want: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
wantErr: []bool{false, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := bytes.NewBuffer(tt.want)
d := NewDecoder(buf)

for i := range tt.ins {
in := tt.ins[i]
dst := reflect.New(reflect.TypeOf(in)).Elem().Interface()
var wantErr bool
if len(tt.wantErr) > i {
wantErr = tt.wantErr[i]
}
if err := d.Decode(&dst); (err != nil) != wantErr {
t.Errorf("Decoder.Decode() error = %v, wantErr %v", err, tt.wantErr[i])
return
}
if !wantErr && !reflect.DeepEqual(dst, in) {
t.Errorf("Decoder.Decode() = %v, want %v", dst, in)
return
}
}
})
}
}