Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,47 @@ import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math/big"
"reflect"
)

// Encoder scale encodes to a given io.Writer.
type Encoder struct {
encodeState
}

// NewEncoder creates a new encoder with the given writer.
func NewEncoder(writer io.Writer) (encoder *Encoder) {
return &Encoder{
encodeState: encodeState{
Writer: writer,
},
}
}

// Encode scale encodes value to the encoder writer.
func (e *Encoder) Encode(value interface{}) (err error) {
return e.marshal(value)
}

// Marshal takes in an interface{} and attempts to marshal into []byte
func Marshal(v interface{}) (b []byte, err error) {
buffer := bytes.NewBuffer(nil)
es := encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
err = es.marshal(v)
if err != nil {
return
}
b = es.Bytes()
b = buffer.Bytes()
return
}

type encodeState struct {
bytes.Buffer
io.Writer
*fieldScaleIndicesCache
}

Expand Down Expand Up @@ -64,9 +86,9 @@ func (es *encodeState) marshal(in interface{}) (err error) {
elem := reflect.ValueOf(in).Elem()
switch elem.IsValid() {
case false:
err = es.WriteByte(0)
_, err = es.Write([]byte{0})
default:
err = es.WriteByte(1)
_, err = es.Write([]byte{1})
if err != nil {
return
}
Expand Down Expand Up @@ -133,13 +155,13 @@ func (es *encodeState) encodeResult(res Result) (err error) {
var in interface{}
switch res.mode {
case OK:
err = es.WriteByte(0)
_, err = es.Write([]byte{0})
if err != nil {
return
}
in = res.ok
case Err:
err = es.WriteByte(1)
_, err = es.Write([]byte{1})
if err != nil {
return
}
Expand All @@ -159,7 +181,7 @@ func (es *encodeState) encodeCustomVaryingDataType(in interface{}) (err error) {
}

func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) {
err = es.WriteByte(byte(vdt.value.Index()))
_, err = es.Write([]byte{byte(vdt.value.Index())})
if err != nil {
return
}
Expand Down
142 changes: 109 additions & 33 deletions pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,48 @@
package scale

import (
"bytes"
"math/big"
"reflect"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Comment thread
qdm12 marked this conversation as resolved.
)

func Test_NewEncoder(t *testing.T) {
t.Parallel()

writer := bytes.NewBuffer(nil)
encoder := NewEncoder(writer)

expectedEncoder := &Encoder{
encodeState: encodeState{
Writer: writer,
},
}

assert.Equal(t, expectedEncoder, encoder)
}

func Test_Encoder_Encode(t *testing.T) {
t.Parallel()

buffer := bytes.NewBuffer(nil)
encoder := NewEncoder(buffer)

err := encoder.Encode(uint16(1))
require.NoError(t, err)

err = encoder.Encode(uint8(2))
require.NoError(t, err)

written := buffer.Bytes()
expectedWritten := []byte{1, 0, 2}
assert.Equal(t, expectedWritten, written)
}

type test struct {
name string
in interface{}
Expand Down Expand Up @@ -869,12 +905,15 @@ type MyStructWithPrivate struct {
func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
for _, tt := range fixedWidthIntegerTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -883,12 +922,15 @@ func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
for _, tt := range variableWidthIntegerTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -897,12 +939,15 @@ func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
func Test_encodeState_encodeBigInt(t *testing.T) {
for _, tt := range bigIntTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBigInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBigInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBigInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -911,12 +956,15 @@ func Test_encodeState_encodeBigInt(t *testing.T) {
func Test_encodeState_encodeUint128(t *testing.T) {
for _, tt := range uint128Tests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeUin128() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeUin128() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeUin128() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -925,12 +973,16 @@ func Test_encodeState_encodeUint128(t *testing.T) {
func Test_encodeState_encodeBytes(t *testing.T) {
for _, tt := range stringTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}

buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBytes() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBytes() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBytes() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -939,12 +991,16 @@ func Test_encodeState_encodeBytes(t *testing.T) {
func Test_encodeState_encodeBool(t *testing.T) {
for _, tt := range boolTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{}

buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeBool() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBool() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeBool() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -953,12 +1009,16 @@ func Test_encodeState_encodeBool(t *testing.T) {
func Test_encodeState_encodeStruct(t *testing.T) {
for _, tt := range structTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -967,12 +1027,16 @@ func Test_encodeState_encodeStruct(t *testing.T) {
func Test_encodeState_encodeSlice(t *testing.T) {
for _, tt := range sliceTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeSlice() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeSlice() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeSlice() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -981,12 +1045,16 @@ func Test_encodeState_encodeSlice(t *testing.T) {
func Test_encodeState_encodeArray(t *testing.T) {
for _, tt := range arrayTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeArray() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeArray() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeArray() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand All @@ -1007,12 +1075,16 @@ func Test_marshal_optionality(t *testing.T) {
}
for _, tt := range ptrTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand Down Expand Up @@ -1043,12 +1115,16 @@ func Test_marshal_optionality_nil_cases(t *testing.T) {
}
for _, tt := range ptrTests {
t.Run(tt.name, func(t *testing.T) {
es := &encodeState{fieldScaleIndicesCache: cache}
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
Expand Down
Loading