Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions cmd/devp2p/internal/ethtest/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,3 @@ func protoOffset(proto Proto) uint64 {
panic("unhandled protocol")
}
}

// msgTypePtr is the constraint for protocol message types.
type msgTypePtr[U any] interface {
*U
Kind() byte
}
7 changes: 6 additions & 1 deletion cmd/devp2p/internal/ethtest/snap.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
)
Expand Down Expand Up @@ -937,10 +938,14 @@ func (s *Suite) snapGetTrieNodes(t *utesting.T, tc *trieNodesTest) error {
}

// write0 request
paths, err := rlp.EncodeToRawList(tc.paths)
if err != nil {
panic(err)
}
req := &snap.GetTrieNodesPacket{
ID: uint64(rand.Int63()),
Root: tc.root,
Paths: tc.paths,
Paths: paths,
Bytes: tc.nBytes,
}
msg, err := conn.snapRequest(snap.GetTrieNodesMsg, req)
Expand Down
99 changes: 56 additions & 43 deletions cmd/devp2p/internal/ethtest/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256"
)

Expand Down Expand Up @@ -151,7 +152,11 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
if err != nil {
t.Fatalf("failed to get headers for given request: %v", err)
}
if !headersMatch(expected, headers.BlockHeadersRequest) {
received, err := headers.List.Items()
if err != nil {
t.Fatalf("invalid headers received: %v", err)
}
if !headersMatch(expected, received) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
}
}
Expand Down Expand Up @@ -237,7 +242,7 @@ concurrently, with different request IDs.`)

// Wait for responses.
// Note they can arrive in either order.
resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
if msg.RequestId != 111 && msg.RequestId != 222 {
t.Fatalf("response with unknown request ID: %v", msg.RequestId)
}
Expand All @@ -248,17 +253,11 @@ concurrently, with different request IDs.`)
}

// Check if headers match.
resp1 := resp[111]
if expected, err := s.chain.GetHeaders(req1); err != nil {
t.Fatalf("failed to get expected headers for request 1: %v", err)
} else if !headersMatch(expected, resp1.BlockHeadersRequest) {
t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 111, expected, resp1)
}
resp2 := resp[222]
if expected, err := s.chain.GetHeaders(req2); err != nil {
t.Fatalf("failed to get expected headers for request 2: %v", err)
} else if !headersMatch(expected, resp2.BlockHeadersRequest) {
t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 222, expected, resp2)
if err := s.checkHeadersAgainstChain(req1, resp[111]); err != nil {
t.Fatal(err)
}
if err := s.checkHeadersAgainstChain(req2, resp[222]); err != nil {
t.Fatal(err)
}
}

Expand Down Expand Up @@ -303,8 +302,8 @@ same request ID. The node should handle the request by responding to both reques

// Wait for the responses. They can arrive in either order, and we can't tell them
// apart by their request ID, so use the number of headers instead.
resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
id := uint64(len(msg.BlockHeadersRequest))
resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
id := uint64(msg.List.Len())
if id != 2 && id != 3 {
t.Fatalf("invalid number of headers in response: %d", id)
}
Expand All @@ -315,26 +314,35 @@ same request ID. The node should handle the request by responding to both reques
}

// Check if headers match.
resp1 := resp[2]
if expected, err := s.chain.GetHeaders(request1); err != nil {
t.Fatalf("failed to get expected headers for request 1: %v", err)
} else if !headersMatch(expected, resp1.BlockHeadersRequest) {
t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp1)
}
resp2 := resp[3]
if expected, err := s.chain.GetHeaders(request2); err != nil {
t.Fatalf("failed to get expected headers for request 2: %v", err)
} else if !headersMatch(expected, resp2.BlockHeadersRequest) {
t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp2)
if err := s.checkHeadersAgainstChain(request1, resp[2]); err != nil {
t.Fatal(err)
}
if err := s.checkHeadersAgainstChain(request2, resp[3]); err != nil {
t.Fatal(err)
}
}

func (s *Suite) checkHeadersAgainstChain(req *eth.GetBlockHeadersPacket, resp *eth.BlockHeadersPacket) error {
received2, err := resp.List.Items()
if err != nil {
return fmt.Errorf("invalid headers in response with request ID %v (%d items): %v", resp.RequestId, resp.List.Len(), err)
}
if expected, err := s.chain.GetHeaders(req); err != nil {
return fmt.Errorf("test chain failed to get expected headers for request: %v", err)
} else if !headersMatch(expected, received2) {
return fmt.Errorf("header mismatch for request ID %v (%d items): \nexpected %v \ngot %v", resp.RequestId, resp.List.Len(), expected, resp)
}
return nil
}

// collectResponses waits for n messages of type T on the given connection.
// The messsages are collected according to the 'identity' function.
func collectResponses[T any, P msgTypePtr[T]](conn *Conn, n int, identity func(P) uint64) (map[uint64]P, error) {
resp := make(map[uint64]P, n)
//
// This function is written in a generic way to handle
func collectHeaderResponses(conn *Conn, n int, identity func(*eth.BlockHeadersPacket) uint64) (map[uint64]*eth.BlockHeadersPacket, error) {
resp := make(map[uint64]*eth.BlockHeadersPacket, n)
for range n {
r := new(T)
r := new(eth.BlockHeadersPacket)
if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil {
return resp, fmt.Errorf("read error: %v", err)
}
Expand Down Expand Up @@ -373,10 +381,8 @@ and expects a response.`)
if got, want := headers.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id")
}
if expected, err := s.chain.GetHeaders(req); err != nil {
t.Fatalf("failed to get expected block headers: %v", err)
} else if !headersMatch(expected, headers.BlockHeadersRequest) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
if err := s.checkHeadersAgainstChain(req, headers); err != nil {
t.Fatal(err)
}
}

Expand Down Expand Up @@ -407,9 +413,8 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) {
if got, want := resp.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in respond", got, want)
}
bodies := resp.BlockBodiesResponse
if len(bodies) != len(req.GetBlockBodiesRequest) {
t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), len(bodies))
if resp.List.Len() != len(req.GetBlockBodiesRequest) {
t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), resp.List.Len())
}
}

Expand All @@ -433,7 +438,7 @@ func (s *Suite) TestGetReceipts(t *utesting.T) {
}
}

// Create block bodies request.
// Create receipts request.
req := &eth.GetReceiptsPacket{
RequestId: 66,
GetReceiptsRequest: (eth.GetReceiptsRequest)(hashes),
Expand All @@ -449,8 +454,8 @@ func (s *Suite) TestGetReceipts(t *utesting.T) {
if got, want := resp.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in respond", got, want)
}
if len(resp.List) != len(req.GetReceiptsRequest) {
t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetReceiptsRequest), len(resp.List))
if resp.List.Len() != len(req.GetReceiptsRequest) {
t.Fatalf("wrong receipts in response: expected %d receipts, got %d", len(req.GetReceiptsRequest), resp.List.Len())
}
}

Expand Down Expand Up @@ -804,7 +809,11 @@ on another peer connection using GetPooledTransactions.`)
if got, want := msg.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in response: got %d, want %d", got, want)
}
for _, got := range msg.PooledTransactionsResponse {
responseTxs, err := msg.List.Items()
if err != nil {
t.Fatalf("invalid transactions in response: %v", err)
}
for _, got := range responseTxs {
if _, exists := set[got.Hash()]; !exists {
t.Fatalf("unexpected tx received: %v", got.Hash())
}
Expand Down Expand Up @@ -976,7 +985,9 @@ func (s *Suite) TestBlobViolations(t *utesting.T) {
if err := conn.ReadMsg(ethProto, eth.GetPooledTransactionsMsg, req); err != nil {
t.Fatalf("reading pooled tx request failed: %v", err)
}
resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: test.resp}

encTxs, _ := rlp.EncodeToRawList(test.resp)
resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
t.Fatalf("writing pooled tx response failed: %v", err)
}
Expand Down Expand Up @@ -1104,7 +1115,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types
// the good peer is connected, and has announced the tx.
// proceed to send the incorrect one from the bad peer.

resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{badTx})}
encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{badTx})
resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
errc <- fmt.Errorf("writing pooled tx response failed: %v", err)
return
Expand Down Expand Up @@ -1164,7 +1176,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types
return
}

resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{tx})}
encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{tx})
resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
errc <- fmt.Errorf("writing pooled tx response failed: %v", err)
return
Expand Down
17 changes: 11 additions & 6 deletions cmd/devp2p/internal/ethtest/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/rlp"
)

// sendTxs sends the given transactions to the node and
Expand All @@ -51,7 +52,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
return fmt.Errorf("peering failed: %v", err)
}

if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket(txs)); err != nil {
encTxs, _ := rlp.EncodeToRawList(txs)
if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket{RawList: encTxs}); err != nil {
return fmt.Errorf("failed to write message to connection: %v", err)
}

Expand All @@ -68,7 +70,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
}
switch msg := msg.(type) {
case *eth.TransactionsPacket:
for _, tx := range *msg {
txs, _ := msg.Items()
for _, tx := range txs {
got[tx.Hash()] = true
}
case *eth.NewPooledTransactionHashesPacket:
Expand All @@ -80,9 +83,10 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, &eth.BlockHeadersPacket{
RequestId: msg.RequestId,
BlockHeadersRequest: headers,
RequestId: msg.RequestId,
List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth wire msg: %s", pretty.Sdump(msg))
Expand Down Expand Up @@ -167,9 +171,10 @@ func (s *Suite) sendInvalidTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, &eth.BlockHeadersPacket{
RequestId: msg.RequestId,
BlockHeadersRequest: headers,
RequestId: msg.RequestId,
List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth message: %v", pretty.Sdump(msg))
Expand Down
41 changes: 25 additions & 16 deletions eth/downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,12 @@ func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int,
func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) {
blobs := eth.ServiceGetBlockBodiesQuery(dlp.chain, hashes)

bodies := make([]*eth.BlockBody, len(blobs))
bodies := make([]*types.Body, len(blobs))
ethbodies := make([]eth.BlockBody, len(blobs))
for i, blob := range blobs {
bodies[i] = new(eth.BlockBody)
bodies[i] = new(types.Body)
rlp.DecodeBytes(blob, bodies[i])
rlp.DecodeBytes(blob, &ethbodies[i])
}
var (
txsHashes = make([]common.Hash, len(bodies))
Expand All @@ -239,9 +241,13 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et
Peer: dlp.id,
}
res := &eth.Response{
Req: req,
Res: (*eth.BlockBodiesResponse)(&bodies),
Meta: [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes},
Req: req,
Res: (*eth.BlockBodiesResponse)(&ethbodies),
Meta: eth.BlockBodyHashes{
TransactionRoots: txsHashes,
UncleHashes: uncleHashes,
WithdrawalRoots: withdrawalHashes,
},
Time: 1,
Done: make(chan error, 1), // Ignore the returned status
}
Expand Down Expand Up @@ -290,14 +296,14 @@ func (dlp *downloadTesterPeer) ID() string {

// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error {
// Create the request and service it
req := &snap.GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
Bytes: bytes,
Bytes: uint64(bytes),
}
slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req)

Expand All @@ -316,15 +322,15 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi
// RequestStorageRanges fetches a batch of storage slots belonging to one or
// more accounts. If slots from only one account is requested, an origin marker
// may also be used to retrieve from there.
func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
// Create the request and service it
req := &snap.GetStorageRangesPacket{
ID: id,
Accounts: accounts,
Root: root,
Origin: origin,
Limit: limit,
Bytes: bytes,
Bytes: uint64(bytes),
}
storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req)

Expand All @@ -341,25 +347,28 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash,
}

// RequestByteCodes fetches a batch of bytecodes by hash.
func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
req := &snap.GetByteCodesPacket{
ID: id,
Hashes: hashes,
Bytes: bytes,
Bytes: uint64(bytes),
}
codes := snap.ServiceGetByteCodesQuery(dlp.chain, req)
go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes)
return nil
}

// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
// a specific state trie.
func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, paths []snap.TrieNodePathSet, bytes uint64) error {
// RequestTrieNodes fetches a batch of account or storage trie nodes.
func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []snap.TrieNodePathSet, bytes int) error {
encPaths, err := rlp.EncodeToRawList(paths)
if err != nil {
panic(err)
}
req := &snap.GetTrieNodesPacket{
ID: id,
Root: root,
Paths: paths,
Bytes: bytes,
Paths: encPaths,
Bytes: uint64(bytes),
}
nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now())
go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes)
Expand Down
Loading