From 39ed0e753596afadad920e302ae769b28f3a982b Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 30 Nov 2021 18:24:42 +0100 Subject: [PATCH] Forward all reads before stream channel inactive (#317) ### Motivation In the AsyncHTTPClient the read event is not always forwarded right away. We have seen instances, in which we see a `HTTPClientError.remoteConnectionClosed` error, on requests that finish normally. https://github.com/swift-server/async-http-client/issues/488 On deeper inspection, I noticed: If there is no unsatisfied read event, when a stream is closed, the pending reads are not forwarded. This can lead to response bytes being ignored on successful requests. NIOHTTP2 should behave as NIO and force forward all pending reads on channelInactive. ### Changes - Forward all pending reads on channelInactive even if no read event has hit the channel ### Result All requests will receive the complete request body. --- Sources/NIOHTTP2/HTTP2StreamChannel.swift | 9 +- ...PayloadStreamMultiplexerTests+XCTest.swift | 1 + ...P2FramePayloadStreamMultiplexerTests.swift | 128 ++++++++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) diff --git a/Sources/NIOHTTP2/HTTP2StreamChannel.swift b/Sources/NIOHTTP2/HTTP2StreamChannel.swift index bda0bf5b..4e08eddd 100644 --- a/Sources/NIOHTTP2/HTTP2StreamChannel.swift +++ b/Sources/NIOHTTP2/HTTP2StreamChannel.swift @@ -812,8 +812,13 @@ internal extension HTTP2StreamChannel { // Avoid emitting any WINDOW_UPDATE frames now that we're closed. self.windowManager.closed = true - // The stream is closed, we should aim to deliver any read frames we have for it. - self.tryToRead() + // The stream is closed, we should force forward all pending frames, even without + // unsatisfied read, to ensure the handlers can see all frames before receiving + // channelInactive. + if self.pendingReads.count > 0 && self._isActive { + self.unsatisfiedRead = false + self.deliverPendingReads() + } if let reason = reason { // To receive from the network, it must be safe to force-unwrap here. diff --git a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift index 0c4def55..f3ba0a72 100644 --- a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift @@ -79,6 +79,7 @@ extension HTTP2FramePayloadStreamMultiplexerTests { ("testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame", testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame), ("testStreamChannelSupportsSyncOptions", testStreamChannelSupportsSyncOptions), ("testStreamErrorIsDeliveredToChannel", testStreamErrorIsDeliveredToChannel), + ("testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive", testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive), ] } } diff --git a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift index be29f5db..c9aaf198 100644 --- a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift +++ b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift @@ -1992,6 +1992,79 @@ final class HTTP2FramePayloadStreamMultiplexerTests: XCTestCase { frames[0].assertHeadersFrame(endStream: false, streamID: 1, headers: goodHeaders, priority: nil, type: .request) frames[1].assertHeadersFrame(endStream: false, streamID: 3, headers: badHeaders, priority: nil, type: .doNotValidate) } + + func testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive() throws { + let goodHeaders = HPACKHeaders([ + (":path", "/"), (":method", "GET"), (":scheme", "https"), (":authority", "localhost") + ]) + + let multiplexer = HTTP2StreamMultiplexer(mode: .client, channel: self.channel) { channel in + XCTFail("Server push is unexpected") + return channel.eventLoop.makeSucceededFuture(()) + } + XCTAssertNoThrow(try self.channel.pipeline.addHandler(multiplexer).wait()) + + // We need to activate the underlying channel here. + XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 80)).wait()) + + // Now create two child channels with error recording handlers in them. Save one, ignore the other. + let consumer = ReadAndFrameConsumer() + var childChannel: Channel! + multiplexer.createStreamChannel(promise: nil) { channel in + childChannel = channel + return channel.pipeline.addHandler(consumer) + } + self.channel.embeddedEventLoop.run() + + let streamID = HTTP2StreamID(1) + + let payload = HTTP2Frame.FramePayload.Headers(headers: goodHeaders, endStream: true) + XCTAssertNoThrow(try childChannel.writeAndFlush(HTTP2Frame.FramePayload.headers(payload)).wait()) + + let frames = try self.channel.sentFrames() + XCTAssertEqual(frames.count, 1) + frames.first?.assertHeadersFrameMatches(this: HTTP2Frame(streamID: streamID, payload: .headers(payload))) + + XCTAssertEqual(consumer.readCount, 1) + + // 1. pass header onwards + + let responseHeaderPayload = HTTP2Frame.FramePayload.headers(.init(headers: [":status": "200"])) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseHeaderPayload))) + XCTAssertEqual(consumer.receivedFrames.count, 1) + XCTAssertEqual(consumer.readCompleteCount, 1) + XCTAssertEqual(consumer.readCount, 2) + + consumer.forwardRead = false + + // 2. pass body onwards + + let responseBody1 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "foo")))) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody1))) + XCTAssertEqual(consumer.receivedFrames.count, 2) + XCTAssertEqual(consumer.readCompleteCount, 2) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.readPending, true) + + // 3. pass on more body - should not change a thing, since read is pending in consumer + + let responseBody2 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "bar")), endStream: true)) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody2))) + XCTAssertEqual(consumer.receivedFrames.count, 2) + XCTAssertEqual(consumer.readCompleteCount, 2) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.readPending, true) + + // 4. signal stream is closed – this should force forward all pending frames + + XCTAssertEqual(consumer.channelInactiveCount, 0) + self.channel.pipeline.fireUserInboundEventTriggered(StreamClosedEvent(streamID: streamID, reason: nil)) + XCTAssertEqual(consumer.receivedFrames.count, 3) + XCTAssertEqual(consumer.readCompleteCount, 3) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.channelInactiveCount, 1) + XCTAssertEqual(consumer.readPending, true) + } } private final class ErrorRecorder: ChannelInboundHandler { @@ -2004,3 +2077,58 @@ private final class ErrorRecorder: ChannelInboundHandler { context.fireErrorCaught(error) } } + +private final class ReadAndFrameConsumer: ChannelInboundHandler, ChannelOutboundHandler { + typealias InboundIn = HTTP2Frame.FramePayload + typealias OutboundIn = HTTP2Frame.FramePayload + + private(set) var receivedFrames: [HTTP2Frame.FramePayload] = [] + private(set) var readCount = 0 + private(set) var readCompleteCount = 0 + private(set) var channelInactiveCount = 0 + private(set) var readPending = false + + var forwardRead = true { + didSet { + if self.forwardRead, self.readPending { + self.context.read() + self.readPending = false + } + } + } + + var context: ChannelHandlerContext! + + func handlerAdded(context: ChannelHandlerContext) { + self.context = context + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.context = context + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.receivedFrames.append(self.unwrapInboundIn(data)) + context.fireChannelRead(data) + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.readCompleteCount += 1 + context.fireChannelReadComplete() + } + + func channelInactive(context: ChannelHandlerContext) { + self.channelInactiveCount += 1 + context.fireChannelInactive() + } + + func read(context: ChannelHandlerContext) { + self.readCount += 1 + if forwardRead { + context.read() + self.readPending = false + } else { + self.readPending = true + } + } +}