diff --git a/.pinned b/.pinned index a3f67b6fc4..d856608c92 100644 --- a/.pinned +++ b/.pinned @@ -15,7 +15,7 @@ serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2d stew;https://github.com/status-im/nim-stew@#b66168735d6f3841c5239c3169d3fe5fe98b1257 testutils;https://github.com/status-im/nim-testutils@#9e842bd58420d23044bc55e16088e8abbe93ce51 unittest2;https://github.com/status-im/nim-unittest2@#8b51e99b4a57fcfb31689230e75595f024543024 -websock;https://github.com/status-im/nim-websock@#35ae76f1559e835c80f9c1a3943bf995d3dd9eb5 +websock;https://github.com/status-im/nim-websock@#f30d4633a761c6615e679de5fa0c0e63460a9ce3 zlib;https://github.com/status-im/nim-zlib@#daa8723fd32299d4ca621c837430c29a5a11e19a jwt;https://github.com/vacp2p/nim-jwt@#18f8378de52b241f321c1f9ea905456e89b95c6f bearssl_pkey_decoder;https://github.com/vacp2p/bearssl_pkey_decoder@#21dd3710df9345ed2ad8bf8f882761e07863b8e0 diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index de52fcb35a..84ad4c5ac4 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -26,6 +26,7 @@ import ../utility, ../stream/connection, ../upgrademngrs/upgrade, + ../utils/semaphore, websock/websock logScope: @@ -34,9 +35,10 @@ logScope: export transport, websock, results const - DefaultHeadersTimeout = 3.seconds + DefaultHandshakeTimeout = 3.seconds DefaultAutotlsWaitTimeout = 3.seconds DefaultAutotlsRetries = 3 + DefaultConcurrentAccepts = 200 type WsStream = ref object of Connection @@ -111,11 +113,17 @@ method closeImpl*(s: WsStream): Future[void] {.async: (raises: []).} = method getWrapped*(s: WsStream): Connection = nil +type AcceptResult = Result[Connection, ref CatchableError] + type WsTransport* = ref object of Transport httpservers: seq[HttpServer] wsserver: WSServer connections: array[Direction, seq[WsStream]] - acceptFuts: seq[Future[HttpRequest]] + handshakeFuts: seq[Future[void]] + acceptResults: AsyncQueue[AcceptResult] + acceptLoop: Future[void] + acceptSem: AsyncSemaphore + concurrentAccepts: int tlsPrivateKey*: TLSPrivateKey tlsCertificate*: TLSCertificate @@ -129,6 +137,117 @@ type WsTransport* = ref object of Transport proc secure*(self: WsTransport): bool = not (isNil(self.tlsPrivateKey) or isNil(self.tlsCertificate)) +proc connHandler( + self: WsTransport, stream: WSSession, secure: bool, dir: Direction +): Future[Connection] {.async: (raises: [CatchableError]).} = + let (observedAddr, localAddr) = + try: + let + codec = + if secure: + MultiAddress.init("/wss") + else: + MultiAddress.init("/ws") + remoteAddr = stream.stream.reader.tsource.remoteAddress + localAddr = stream.stream.reader.tsource.localAddress + + ( + MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(), + MultiAddress.init(localAddr).tryGet() & codec.tryGet(), + ) + except CatchableError as exc: + trace "Failed to create observedAddr or listenAddr", description = exc.msg + if not (isNil(stream) and stream.stream.reader.closed): + safeClose(stream) + raise exc + + let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr)) + + self.connections[dir].add(conn) + proc onClose() {.async: (raises: []).} = + await noCancel conn.session.stream.reader.join() + self.connections[dir].keepItIf(it != conn) + trace "Cleaned up client" + + asyncSpawn onClose() + return conn + +proc addHandshakeResult(self: WsTransport, ares: AcceptResult) = + try: + self.acceptResults.addLastNoWait(ares) + except AsyncQueueFullError: # never happens but need to catch + discard + +proc handshakeWorker( + self: WsTransport, server: HttpServer, clientStream: AsyncStream +) {.async: (raises: []).} = + try: + let conn = await ( + proc(): Future[Connection] {.async.} = + let req = await server.processHttpRequest(clientStream) + let wstransp = await self.wsserver.handleRequest(req) + return await self.connHandler(wstransp, server.secure, Direction.In) + )() + .wait(self.handshakeTimeout) + self.addHandshakeResult(AcceptResult.ok(conn)) + except CatchableError as exc: + await noCancel clientStream.closeWait() + self.addHandshakeResult(AcceptResult.err(exc)) + finally: + self.acceptSem.release() + +proc acceptDispatcher(self: WsTransport) {.async: (raises: []).} = + trace "Started acceptDispatcher" + + var acceptFuts: seq[Future[AsyncStream]] = @[] + for server in self.httpservers: + acceptFuts.add(server.acceptStream()) + if acceptFuts.len == 0: + error "acceptDispatcher has no work; terminating" + return + + while self.running: + try: + if self.handshakeFuts.len > 0: + self.handshakeFuts.keepItIf(not it.finished) + await self.acceptSem.acquire() + except CancelledError: + continue + try: + let streamFut = await one(acceptFuts) + let idx = acceptFuts.find(streamFut) + if idx < 0: + self.acceptSem.release() + continue + + let httpServer = self.httpservers[idx] + acceptFuts[idx] = httpServer.acceptStream() + + if streamFut.failed: + self.acceptSem.release() + self.addHandshakeResult(AcceptResult.err(streamFut.error)) + continue + + let hFut = self.handshakeWorker(httpServer, streamFut.read()) + self.handshakeFuts.add(hFut) + except CatchableError as exc: + self.acceptSem.release() + if not self.running: + break + trace "Error in acceptDispatcher", msg = exc.msg + try: + await sleepAsync(100.milliseconds) + except CancelledError: + discard + + trace "Exiting acceptDispatcher" + for fut in acceptFuts: + if not fut.finished: + await fut.cancelAndWait() + self.addHandshakeResult( + AcceptResult.err(newException(TransportClosedError, "Server is closed")) + ) + method start*( self: WsTransport, addrs: seq[MultiAddress] ) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} = @@ -175,6 +294,9 @@ method start*( let address = ma.initTAddress().tryGet() + # allow HTTP headers to take up to 90% of the WS handshake's total time budget + let headerProcessingTimeout = self.handshakeTimeout * 9 div 10 + let httpserver = try: if isWss: @@ -183,10 +305,10 @@ method start*( tlsPrivateKey = self.tlsPrivateKey, tlsCertificate = self.tlsCertificate, flags = self.flags, - handshakeTimeout = self.handshakeTimeout, + headersTimeout = headerProcessingTimeout, ) else: - HttpServer.create(address, handshakeTimeout = self.handshakeTimeout) + HttpServer.create(address, headersTimeout = headerProcessingTimeout) except CatchableError as exc: raise (ref WsTransportError)( msg: "error in WsTransport start: " & exc.msg, parent: exc @@ -209,6 +331,10 @@ method start*( trace "Listening on", addresses = self.addrs + self.acceptSem = newAsyncSemaphore(self.concurrentAccepts) + self.acceptResults = newAsyncQueue[AcceptResult]() + self.acceptLoop = self.acceptDispatcher() + method stop*(self: WsTransport) {.async: (raises: []).} = ## stop the transport ## @@ -224,17 +350,20 @@ method stop*(self: WsTransport) {.async: (raises: []).} = self.connections[Direction.Out].mapIt(it.close()) ) + if not isNil(self.acceptLoop): + await self.acceptLoop.cancelAndWait() + var toWait: seq[Future[void]] - for fut in self.acceptFuts: - if not fut.finished: - toWait.add(fut.cancelAndWait()) - elif fut.completed: - toWait.add(fut.read().stream.closeWait()) for server in self.httpservers: server.stop() toWait.add(server.closeWait()) + for fut in self.handshakeFuts: + if not fut.finished: + fut.cancel() + toWait.add(self.handshakeFuts) + await allFutures(toWait) self.httpservers = @[] @@ -242,43 +371,6 @@ method stop*(self: WsTransport) {.async: (raises: []).} = except CatchableError as exc: trace "Error shutting down ws transport", description = exc.msg -proc connHandler( - self: WsTransport, stream: WSSession, secure: bool, dir: Direction -): Future[Connection] {.async: (raises: [CatchableError]).} = - ## Returning CatchableError is fine because we later handle different exceptions. - - let (observedAddr, localAddr) = - try: - let - codec = - if secure: - MultiAddress.init("/wss") - else: - MultiAddress.init("/ws") - remoteAddr = stream.stream.reader.tsource.remoteAddress - localAddr = stream.stream.reader.tsource.localAddress - - ( - MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(), - MultiAddress.init(localAddr).tryGet() & codec.tryGet(), - ) - except CatchableError as exc: - trace "Failed to create observedAddr or listenAddr", description = exc.msg - if not (isNil(stream) and stream.stream.reader.closed): - safeClose(stream) - raise exc - - let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr)) - - self.connections[dir].add(conn) - proc onClose() {.async: (raises: []).} = - await noCancel conn.session.stream.reader.join() - self.connections[dir].keepItIf(it != conn) - trace "Cleaned up client" - - asyncSpawn onClose() - return conn - method accept*( self: WsTransport ): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} = @@ -294,34 +386,12 @@ method accept*( if not self.running: raise newTransportClosedError() - if self.acceptFuts.len <= 0: - self.acceptFuts = self.httpservers.mapIt(it.accept()) - - if self.acceptFuts.len <= 0: - return - - let finished = - try: - await one(self.acceptFuts) - except ValueError: - raiseAssert("already checked with if") - except CancelledError as e: - raise e - - let index = self.acceptFuts.find(finished) - self.acceptFuts[index] = self.httpservers[index].accept() + let res = await self.acceptResults.popFirst() + res.isErrOr: + return value try: - let req = await finished - - try: - let wstransp = await self.wsserver.handleRequest(req).wait(self.handshakeTimeout) - let isSecure = self.httpservers[index].secure - - return await self.connHandler(wstransp, isSecure, Direction.In) - except CatchableError as exc: - await noCancel req.stream.closeWait() - raise exc + raise res.error except WebSocketError as exc: debug "Websocket Error", description = exc.msg except HttpError as exc: @@ -334,13 +404,17 @@ method accept*( debug "Connection aborted", description = exc.msg except AsyncTimeoutError as exc: debug "Timed out", description = exc.msg + except TransportOsError as exc: + debug "OS Error", description = exc.msg except TransportUseClosedError as exc: debug "Server was closed", description = exc.msg raise newTransportClosedError(exc) + except TransportClosedError as exc: + self.addHandshakeResult(res) + debug "Server was closed", description = exc.msg + raise newTransportClosedError(exc) except CancelledError as exc: raise exc - except TransportOsError as exc: - debug "OS Error", description = exc.msg except CatchableError as exc: info "Unexpected error accepting connection", description = exc.msg raise newException( @@ -392,9 +466,11 @@ proc new*( flags: set[ServerFlags] = {}, factories: openArray[ExtFactory] = [], rng: ref HmacDrbgContext = nil, - handshakeTimeout = DefaultHeadersTimeout, + handshakeTimeout = DefaultHandshakeTimeout, + concurrentAccepts = DefaultConcurrentAccepts, ): T {.raises: [].} = ## Creates a secure WebSocket transport + doAssert concurrentAccepts > 0, "must accept connections" let self = T( upgrader: upgrade, @@ -406,6 +482,7 @@ proc new*( factories: @factories, rng: rng, handshakeTimeout: handshakeTimeout, + concurrentAccepts: concurrentAccepts, ) procCall Transport(self).initialize() self @@ -416,9 +493,11 @@ proc new*( flags: set[ServerFlags] = {}, factories: openArray[ExtFactory] = [], rng: ref HmacDrbgContext = nil, - handshakeTimeout = DefaultHeadersTimeout, + handshakeTimeout = DefaultHandshakeTimeout, + concurrentAccepts = DefaultConcurrentAccepts, ): T {.raises: [].} = ## Creates a clear-text WebSocket transport + doAssert concurrentAccepts > 0, "must accept connections" T.new( upgrade = upgrade, @@ -429,4 +508,5 @@ proc new*( factories = @factories, rng = rng, handshakeTimeout = handshakeTimeout, + concurrentAccepts = concurrentAccepts, ) diff --git a/tests/libp2p/transports/stream_tests.nim b/tests/libp2p/transports/stream_tests.nim index 86294f0973..19fc4b56d1 100644 --- a/tests/libp2p/transports/stream_tests.nim +++ b/tests/libp2p/transports/stream_tests.nim @@ -448,7 +448,9 @@ template streamTransportTest*( const chunkSize = 64 const chunkCount = 32 const messageSize = chunkSize * chunkCount + const errorClientId: byte = 0xff const numConnections = 5 + doAssert numConnections < errorClientId var serverReadOrder: seq[byte] = @[] # Track when stream handlers complete @@ -479,10 +481,18 @@ template streamTransportTest*( # Doing this improves likelihood of parallel data transition on the connections. await sleepAsync(rand(20 .. 100).milliseconds) - check receivedData == newData(messageSize, byte(handlerIndex)) + let + # Get the client ID from any byte of the data; can't depend on accept/dial order. + clientId = + if receivedData.len > 0: + receivedData[0] + else: + errorClientId + + check receivedData == newData(messageSize, clientId) # Send back ID - await stream.write(@[byte(receivedData[0])]) + await stream.write(@[clientId]) # Signal that this stream handler is done serverStreamHandlerFuts[handlerIndex].complete()