diff --git a/beacon_chain/networking/eth2_network.nim b/beacon_chain/networking/eth2_network.nim index de4571bfa6..43de82ff3c 100644 --- a/beacon_chain/networking/eth2_network.nim +++ b/beacon_chain/networking/eth2_network.nim @@ -135,7 +135,8 @@ type ## Protocol requests using this type will produce request-making ## client-side procs that return `NetRes[MsgType]` - MultipleChunksResponse*[MsgType; maxLen: static Limit] = distinct UntypedResponse + MultipleChunksResponse*[ + MsgType; maxLen: static Limit] = distinct UntypedResponse ## Protocol requests using this type will produce request-making ## client-side procs that return `NetRes[List[MsgType, maxLen]]`. ## In the future, such procs will return an `InputStream[NetRes[MsgType]]`. @@ -952,9 +953,11 @@ proc readResponseChunk( return await readChunkPayload(conn, peer, MsgType) -proc readResponse(conn: Connection, peer: Peer, - MsgType: type, timeout: Duration): Future[NetRes[MsgType]] - {.async: (raises: [CancelledError]).} = +proc readResponse( + conn: Connection, peer: Peer, + MsgType: type, maxResponseItems: Limit, + timeout: Duration +): Future[NetRes[MsgType]] {.async: (raises: [CancelledError]).} = when MsgType is List: type E = MsgType.T var results: MsgType @@ -979,18 +982,20 @@ proc readResponse(conn: Connection, peer: Peer, return err nextRes.error else: trace "Got chunk", conn - if not results.add nextRes.value: + if results.len >= maxResponseItems or not results.add nextRes.value: return neterr(ResponseChunkOverflow) else: + discard maxResponseItems # Always set to 1 for non-`List` responses let nextFut = conn.readResponseChunk(peer, MsgType) if not await nextFut.withTimeout(timeout): return neterr(ReadResponseTimeout) return await nextFut # Guaranteed to complete without waiting -proc makeEth2Request(peer: Peer, protocolId: string, requestBytes: seq[byte], - ResponseMsg: type, - timeout: Duration): Future[NetRes[ResponseMsg]] - {.async: (raises: [CancelledError]).} = +proc doMakeEth2Request( + peer: Peer, protocolId: string, requestBytes: seq[byte], + ResponseMsg: type, maxResponseItems: Limit, + timeout: Duration +): Future[NetRes[ResponseMsg]] {.async: (raises: [CancelledError]).} = let deadline = sleepAsync timeout streamRes = @@ -1017,7 +1022,8 @@ proc makeEth2Request(peer: Peer, protocolId: string, requestBytes: seq[byte], nbc_reqresp_messages_sent.inc(1, [shortProtocolId(protocolId)]) # Read the response - let res = await readResponse(stream, peer, ResponseMsg, timeout) + let res = await readResponse( + stream, peer, ResponseMsg, maxResponseItems, timeout) if res.isErr(): if res.error().kind in ProtocolViolations: peer.updateScore(PeerScoreInvalidRequest) @@ -1036,6 +1042,31 @@ proc makeEth2Request(peer: Peer, protocolId: string, requestBytes: seq[byte], debug "Unexpected error while closing stream", peer, protocolId, exc = exc.msg +proc makeEth2Request( + peer: Peer, protocolId: string, requestBytes: seq[byte], + ResponseMsg: type, + timeout: Duration +): Future[NetRes[ResponseMsg]] {. + async: (raises: [CancelledError], raw: true).} = + when ResponseMsg is List: + doMakeEth2Request( + peer, protocolId, requestBytes, ResponseMsg, ResponseMsg.maxLen, timeout) + else: + doMakeEth2Request( + peer, protocolId, requestBytes, ResponseMsg, 1.Limit, timeout) + +proc makeEth2Request( + peer: Peer, protocolId: string, requestBytes: seq[byte], + ResponseMsg: type, maxResponseItems: Limit, + timeout: Duration +): Future[NetRes[ResponseMsg]] {. + async: (raises: [CancelledError], raw: true).} = + when ResponseMsg is List: + doMakeEth2Request( + peer, protocolId, requestBytes, ResponseMsg, maxResponseItems, timeout) + else: + static: raiseAssert $ResponseMsg & " does not support `maxResponseItems`" + func init*(T: type MultipleChunksResponse, peer: Peer, conn: Connection): T = T(UntypedResponse(peer: peer, stream: conn)) @@ -1098,7 +1129,7 @@ func setEventHandlers(p: ProtocolInfo, p.onPeerConnected = onPeerConnected p.onPeerDisconnected = onPeerDisconnected -proc implementSendProcBody(sendProc: SendProc) = +proc implementSendProcBody(sendProc: SendProc, isChunkStream: bool) = let msg = sendProc.msg UntypedResponse = bindSym "UntypedResponse" @@ -1109,9 +1140,16 @@ proc implementSendProcBody(sendProc: SendProc) = case msg.kind of msgRequest: let ResponseRecord = msg.response.recName - quote: - makeEth2Request(`peer`, `msgProto`, `bytes`, - `ResponseRecord`, `timeoutVar`) + if isChunkStream: + quote: + makeEth2Request( + `peer`, `msgProto`, `bytes`, + `ResponseRecord`, maxResponseItems, `timeoutVar`) + else: + quote: + makeEth2Request( + `peer`, `msgProto`, `bytes`, + `ResponseRecord`, `timeoutVar`) else: quote: sendNotificationMsg(`peer`, `msgProto`, `bytes`) else: @@ -2019,7 +2057,9 @@ proc p2pProtocolBackendImpl*(p: P2PProtocol): Backend = ## initialize the network object by creating handlers bound to the ## specific network. ## - var userHandlerCall = newTree(nnkDiscardStmt) + var + userHandlerCall = newTree(nnkDiscardStmt) + maxResponseItems: Opt[NimNode] if msg.userHandler != nil: var OutputParamType = if msg.kind == msgRequest: msg.outputParamType @@ -2037,6 +2077,7 @@ proc p2pProtocolBackendImpl*(p: P2PProtocol): Backend = let isChunkStream = eqIdent(OutputParamType[0], "MultipleChunksResponse") msg.response.recName = if isChunkStream: + maxResponseItems.ok OutputParamType[2] newTree(nnkBracketExpr, ident"List", OutputParamType[1], OutputParamType[2]) else: OutputParamType[1] @@ -2074,7 +2115,16 @@ proc p2pProtocolBackendImpl*(p: P2PProtocol): Backend = ## var sendProc = msg.createSendProc() - implementSendProcBody sendProc + if maxResponseItems.isSome: + sendProc.def.params.insert( + sendProc.def.params.len - 1, # Insert before implicit `timeout` param + newTree( + nnkIdentDefs, + ident"maxResponseItems", + bindSym"Limit", + maxResponseItems.get)) + + implementSendProcBody(sendProc, maxResponseItems.isSome) protocol.outProcRegistrations.add( newCall(registerMsg,