Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rlp/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func EncodeToRawList[T any](val []T) (RawList[T], error) {
bytes := make([]byte, contentSize+9)
offset := 9 - headsize(uint64(contentSize))
buf.copyTo(bytes[offset:])
return RawList[T]{enc: bytes}, nil
return RawList[T]{enc: bytes, length: len(val)}, nil
}

type listhead struct {
Expand Down
9 changes: 0 additions & 9 deletions rlp/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,6 @@ func (it *Iterator) Next() bool {
return true
}

// Count returns the remaining number of items.
// Note this is O(n) and the result may be incorrect if the list data is invalid.
// The returned count is always an upper bound on the remaining items
// that will be visited by the iterator.
func (it *Iterator) Count() int {
count, _ := CountValues(it.data)
return count
}

// Value returns the current value.
func (it *Iterator) Value() []byte {
return it.next
Expand Down
33 changes: 24 additions & 9 deletions rlp/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type RawList[T any] struct {
// The implementation code mostly works with the Content method because it
// returns something valid either way.
enc []byte

// length holds the number of items in the list.
length int
}

// Content returns the RLP-encoded data of the list.
Expand Down Expand Up @@ -87,7 +90,14 @@ func (r *RawList[T]) DecodeRLP(s *Stream) error {
if err := s.readFull(enc[9:]); err != nil {
return err
}
*r = RawList[T]{enc: enc}
n, err := CountValues(enc[9:])
if err != nil {
if err == ErrValueTooLarge {
return ErrElemTooLarge
}
return err
}
*r = RawList[T]{enc: enc, length: n}
return nil
}

Expand All @@ -105,20 +115,14 @@ func (r *RawList[T]) Items() ([]T, error) {

// Len returns the number of items in the list.
func (r *RawList[T]) Len() int {
len, _ := CountValues(r.Content())
return len
return r.length
}

// Size returns the encoded size of the list.
func (r *RawList[T]) Size() uint64 {
return ListSize(uint64(len(r.Content())))
}

// Empty returns true if the list contains no items.
func (r *RawList[T]) Empty() bool {
return len(r.Content()) == 0
}

// ContentIterator returns an iterator over the content of the list.
// Note the offsets returned by iterator.Offset are relative to the
// Content bytes of the list.
Expand All @@ -142,15 +146,26 @@ func (r *RawList[T]) Append(item T) error {
end := prevEnd + eb.size()
r.enc = slices.Grow(r.enc, eb.size())[:end]
eb.copyTo(r.enc[prevEnd:end])
r.length++
return nil
}

// AppendRaw adds an encoded item to the list.
func (r *RawList[T]) AppendRaw(b []byte) {
// The given byte slice must contain exactly one RLP value.
func (r *RawList[T]) AppendRaw(b []byte) error {
_, tagsize, contentsize, err := readKind(b)
if err != nil {
return err
}
if tagsize+contentsize != uint64(len(b)) {
return fmt.Errorf("rlp: input has trailing bytes in AppendRaw")
}
if r.enc == nil {
r.enc = make([]byte, 9)
}
r.enc = append(r.enc, b...)
r.length++
return nil
}

// StringSize returns the encoded size of a string.
Expand Down
58 changes: 52 additions & 6 deletions rlp/raw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ func (test rawListTest[T]) run(t *testing.T) {
// check iterator
it := rl.ContentIterator()
i := 0
if count := it.Count(); count != test.length {
t.Fatalf("iterator has wrong Count %d, want %d", count, test.length)
}
for it.Next() {
var item T
if err := DecodeBytes(it.Value(), &item); err != nil {
Expand Down Expand Up @@ -154,9 +151,6 @@ func TestRawListEmpty(t *testing.T) {
if !bytes.Equal(b, unhex("C0")) {
t.Fatalf("empty RawList has wrong encoding %x", b)
}
if !rl.Empty() {
t.Fatal("list should be Empty")
}
if rl.Len() != 0 {
t.Fatalf("empty list has Len %d", rl.Len())
}
Expand Down Expand Up @@ -226,6 +220,58 @@ func TestRawListAppend(t *testing.T) {
}
}

func TestRawListAppendRaw(t *testing.T) {
var rl RawList[uint64]

if err := rl.AppendRaw(unhex("01")); err != nil {
t.Fatal("AppendRaw(01) failed:", err)
}
if err := rl.AppendRaw(unhex("820102")); err != nil {
t.Fatal("AppendRaw(820102) failed:", err)
}
if rl.Len() != 2 {
t.Fatalf("wrong Len %d after valid appends", rl.Len())
}

if err := rl.AppendRaw(nil); err == nil {
t.Fatal("AppendRaw(nil) should fail")
}
if err := rl.AppendRaw(unhex("0102")); err == nil {
t.Fatal("AppendRaw(0102) should fail due to trailing bytes")
}
if err := rl.AppendRaw(unhex("8201")); err == nil {
t.Fatal("AppendRaw(8201) should fail due to truncated value")
}
if rl.Len() != 2 {
t.Fatalf("wrong Len %d after invalid appends, want 2", rl.Len())
}
}

func TestRawListDecodeInvalid(t *testing.T) {
tests := []struct {
input string
err error
}{
// Single item with non-canonical size (0x81 wrapping byte <= 0x7F).
{input: "C28142", err: ErrCanonSize},
// Single item claiming more bytes than available in the list.
{input: "C484020202", err: ErrElemTooLarge},
// Two items, second has non-canonical size.
{input: "C3018142", err: ErrCanonSize},
// Two items, second claims more bytes than remain in the list.
{input: "C401830202", err: ErrElemTooLarge},
// Item is a sub-list whose declared size exceeds available bytes.
{input: "C3C40102", err: ErrElemTooLarge},
}
for _, test := range tests {
var rl RawList[RawValue]
err := DecodeBytes(unhex(test.input), &rl)
if !errors.Is(err, test.err) {
t.Errorf("input %s: error mismatch: got %v, want %v", test.input, err, test.err)
}
}
}

func TestCountValues(t *testing.T) {
tests := []struct {
input string // note: spaces in input are stripped by unhex
Expand Down