diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 31b3e4c102..b81be55ea1 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -166,7 +166,7 @@ func Test_decodeBranch(t *testing.T) { variant: branchVariant.bits, partialKeyLength: 1, errWrapped: ErrDecodeChildHash, - errMessage: "cannot decode child hash: at index 10: EOF", + errMessage: "cannot decode child hash: at index 10: reading byte: EOF", }, "success for branch variant": { reader: bytes.NewBuffer( @@ -203,7 +203,7 @@ func Test_decodeBranch(t *testing.T) { variant: branchWithValueVariant.bits, partialKeyLength: 1, errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: EOF", + errMessage: "cannot decode value: reading byte: EOF", }, "success for branch with value": { reader: bytes.NewBuffer(concatByteSlices([][]byte{ @@ -333,7 +333,7 @@ func Test_decodeLeaf(t *testing.T) { variant: leafVariant.bits, partialKeyLength: 1, errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: could not decode invalid integer", + errMessage: "cannot decode value: unknown prefix for compact uint: 255", }, "zero value": { reader: bytes.NewBuffer([]byte{ diff --git a/lib/runtime/version_test.go b/lib/runtime/version_test.go index 412941995e..88050b7b5e 100644 --- a/lib/runtime/version_test.go +++ b/lib/runtime/version_test.go @@ -39,7 +39,7 @@ func Test_DecodeVersion(t *testing.T) { {255, 255}, // error }), errWrapped: ErrDecodingVersionField, - errMessage: "decoding version field impl name: could not decode invalid integer", + errMessage: "decoding version field impl name: unknown prefix for compact uint: 255", }, // TODO add transaction version decode error once // https://github.com/ChainSafe/gossamer/pull/2683 diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 4e427e1b35..e10cf4c4e6 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -335,7 +335,7 @@ func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error if err != nil { return } - for i := 0; i < l; i++ { + for i := uint(0); i < l; i++ { vdt := vdts.VaryingDataType vdtv := reflect.New(reflect.TypeOf(vdt)) vdtv.Elem().Set(reflect.ValueOf(vdt)) @@ -397,7 +397,7 @@ func (ds *decodeState) decodeSlice(dstv reflect.Value) (err error) { } in := dstv.Interface() temp := reflect.New(reflect.ValueOf(in).Type()) - for i := 0; i < l; i++ { + for i := uint(0); i < l; i++ { tempElemType := reflect.TypeOf(in).Elem() tempElem := reflect.New(tempElemType).Elem() @@ -478,59 +478,90 @@ func (ds *decodeState) decodeBool(dstv reflect.Value) (err error) { // decodeUint will decode unsigned integer func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) { - b, err := ds.ReadByte() + const maxUint32 = ^uint32(0) + const maxUint64 = ^uint64(0) + prefix, err := ds.ReadByte() if err != nil { - return + return fmt.Errorf("reading byte: %w", err) } in := dstv.Interface() temp := reflect.New(reflect.TypeOf(in)) // check mode of encoding, stored at 2 least significant bits - mode := b & 3 - switch { - case mode <= 2: - var val int64 - val, err = ds.decodeSmallInt(b, mode) + mode := prefix % 4 + var value uint64 + switch mode { + case 0: + value = uint64(prefix >> 2) + case 1: + buf, err := ds.ReadByte() if err != nil { - return + return fmt.Errorf("reading byte: %w", err) } - temp.Elem().Set(reflect.ValueOf(val).Convert(reflect.TypeOf(in))) - dstv.Set(temp.Elem()) - default: - // >4 byte mode - topSixBits := b >> 2 - byteLen := uint(topSixBits) + 4 - + value = uint64(binary.LittleEndian.Uint16([]byte{prefix, buf}) >> 2) + if value <= 0b0011_1111 || value > 0b0111_1111_1111_1111 { + return fmt.Errorf("%w: %d (%b)", ErrU16OutOfRange, value, value) + } + case 2: + buf := make([]byte, 3) + _, err = ds.Read(buf) + if err != nil { + return fmt.Errorf("reading bytes: %w", err) + } + value = uint64(binary.LittleEndian.Uint32(append([]byte{prefix}, buf...)) >> 2) + if value <= 0b0011_1111_1111_1111 || value > uint64(maxUint32>>2) { + return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value) + } + case 3: + byteLen := (prefix >> 2) + 4 buf := make([]byte, byteLen) _, err = ds.Read(buf) if err != nil { - return + return fmt.Errorf("reading bytes: %w", err) } - - var o uint64 - if byteLen == 4 { - o = uint64(binary.LittleEndian.Uint32(buf)) - } else if byteLen > 4 && byteLen <= 8 { + switch byteLen { + case 4: + value = uint64(binary.LittleEndian.Uint32(buf)) + if value <= uint64(maxUint32>>2) { + return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value) + } + case 8: + const uintSize = 32 << (^uint(0) >> 32 & 1) + if uintSize == 32 { + return ErrU64NotSupported + } tmp := make([]byte, 8) copy(tmp, buf) - o = binary.LittleEndian.Uint64(tmp) - } else { - err = errors.New("could not decode invalid integer") - return + value = binary.LittleEndian.Uint64(tmp) + if value <= maxUint64>>8 { + return fmt.Errorf("%w: %d (%b)", ErrU64OutOfRange, value, value) + } + default: + return fmt.Errorf("%w: %d", ErrCompactUintPrefixUnknown, prefix) + } - dstv.Set(reflect.ValueOf(o).Convert(reflect.TypeOf(in))) } + temp.Elem().Set(reflect.ValueOf(value).Convert(reflect.TypeOf(in))) + dstv.Set(temp.Elem()) return } +var ( + ErrU16OutOfRange = errors.New("uint16 out of range") + ErrU32OutOfRange = errors.New("uint32 out of range") + ErrU64OutOfRange = errors.New("uint64 out of range") + ErrU64NotSupported = errors.New("uint64 is not supported") + ErrCompactUintPrefixUnknown = errors.New("unknown prefix for compact uint") +) + // decodeLength is helper method which calls decodeUint and casts to int -func (ds *decodeState) decodeLength() (l int, err error) { +func (ds *decodeState) decodeLength() (l uint, err error) { dstv := reflect.New(reflect.TypeOf(l)) err = ds.decodeUint(dstv.Elem()) if err != nil { return } - l = dstv.Elem().Interface().(int) + l = dstv.Elem().Interface().(uint) return } diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 669bddb3a0..da808d8514 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" ) func Test_decodeState_decodeFixedWidthInt(t *testing.T) { @@ -302,3 +303,101 @@ func Test_Decoder_Decode_MultipleCalls(t *testing.T) { }) } } + +func Test_decodeState_decodeUint(t *testing.T) { + t.Parallel() + decodeUint32Tests := tests{ + { + name: "int(1) mode 0", + in: uint32(1), + want: []byte{0x04}, + }, + { + name: "int(16383) mode 1", + in: int(16383), + want: []byte{0xfd, 0xff}, + }, + { + name: "int(1073741823) mode 2", + in: int(1073741823), + want: []byte{0xfe, 0xff, 0xff, 0xff}, + }, + { + name: "int(4294967295) mode 3", + in: int(4294967295), + want: []byte{0x3, 0xff, 0xff, 0xff, 0xff}, + }, + { + name: "myCustomInt(9223372036854775807) mode 3, 64bit", + in: myCustomInt(9223372036854775807), + want: []byte{19, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, + }, + { + name: "uint(overload)", + in: int(0), + want: []byte{0x07, 0x08, 0x09, 0x10, 0x0, 0x40}, + wantErr: true, + }, + { + name: "uint(16384) mode 2", + in: int(16384), + want: []byte{0x02, 0x00, 0x01, 0x0}, + }, + { + name: "uint(0) mode 1, error", + in: int(0), + want: []byte{0x01, 0x00}, + wantErr: true, + }, + { + name: "uint(0) mode 2, error", + in: int(0), + want: []byte{0x02, 0x00, 0x00, 0x0}, + wantErr: true, + }, + { + name: "uint(0) mode 3, error", + in: int(0), + want: []byte{0x03, 0x00, 0x00, 0x0}, + wantErr: true, + }, + { + name: "mode 3, 64bit, error", + in: int(0), + want: []byte{19, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + wantErr: true, + }, + { + name: "[]int{1 << 32, 2, 3, 1 << 32}", + in: uint(4), + want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01}, + }, + { + name: "[4]int{1 << 32, 2, 3, 1 << 32}", + in: [4]int{0, 0, 0, 0}, + want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01}, + wantErr: true, + }, + } + + for _, tt := range decodeUint32Tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface() + dstv := reflect.ValueOf(&dst) + elem := indirect(dstv) + + ds := decodeState{ + Reader: bytes.NewBuffer(tt.want), + } + err := ds.decodeUint(elem) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.in, dst) + }) + } +} diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 4d56b40e77..fd0d6bc13d 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -176,6 +176,11 @@ var ( in: int(1), want: []byte{0x04}, }, + { + name: "int(42)", + in: int(42), + want: []byte{0xa8}, + }, { name: "int(16383)", in: int(16383), @@ -820,9 +825,11 @@ var ( want: []byte{0x10, 0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10}, }, { - name: "[]int{1 << 32, 2, 3, 1 << 32}", - in: []int{1 << 32, 2, 3, 1 << 32}, - want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01}, + name: "[]int64{1 << 32, 2, 3, 1 << 32}", + in: []int64{1 << 32, 2, 3, 1 << 32}, + want: []byte{0x10, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00}, }, { name: "[]bool{true, false, true}", @@ -863,9 +870,11 @@ var ( want: []byte{0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10}, }, { - name: "[4]int{1 << 32, 2, 3, 1 << 32}", - in: [4]int{1 << 32, 2, 3, 1 << 32}, - want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01}, + name: "[4]int64{1 << 32, 2, 3, 1 << 32}", + in: [4]int64{1 << 32, 2, 3, 1 << 32}, + want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00}, }, { name: "[3]bool{true, false, true}",