Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward all reads before stream channel inactive #317

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Sources/NIOHTTP2/HTTP2StreamChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,13 @@ internal extension HTTP2StreamChannel {
// Avoid emitting any WINDOW_UPDATE frames now that we're closed.
self.windowManager.closed = true

// The stream is closed, we should aim to deliver any read frames we have for it.
self.tryToRead()
// The stream is closed, we should force forward all pending frames, even without
// unsatisfied read, to ensure the handlers can see all frames before receiving
// channelInactive.
if self.pendingReads.count > 0 && self._isActive {
self.unsatisfiedRead = false
self.deliverPendingReads()
}

if let reason = reason {
// To receive from the network, it must be safe to force-unwrap here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ extension HTTP2FramePayloadStreamMultiplexerTests {
("testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame", testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame),
("testStreamChannelSupportsSyncOptions", testStreamChannelSupportsSyncOptions),
("testStreamErrorIsDeliveredToChannel", testStreamErrorIsDeliveredToChannel),
("testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive", testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive),
]
}
}
Expand Down
128 changes: 128 additions & 0 deletions Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,79 @@ final class HTTP2FramePayloadStreamMultiplexerTests: XCTestCase {
frames[0].assertHeadersFrame(endStream: false, streamID: 1, headers: goodHeaders, priority: nil, type: .request)
frames[1].assertHeadersFrame(endStream: false, streamID: 3, headers: badHeaders, priority: nil, type: .doNotValidate)
}

func testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive() throws {
let goodHeaders = HPACKHeaders([
(":path", "/"), (":method", "GET"), (":scheme", "https"), (":authority", "localhost")
])

let multiplexer = HTTP2StreamMultiplexer(mode: .client, channel: self.channel) { channel in
XCTFail("Server push is unexpected")
return channel.eventLoop.makeSucceededFuture(())
}
XCTAssertNoThrow(try self.channel.pipeline.addHandler(multiplexer).wait())

// We need to activate the underlying channel here.
XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 80)).wait())

// Now create two child channels with error recording handlers in them. Save one, ignore the other.
let consumer = ReadAndFrameConsumer()
var childChannel: Channel!
multiplexer.createStreamChannel(promise: nil) { channel in
childChannel = channel
return channel.pipeline.addHandler(consumer)
}
self.channel.embeddedEventLoop.run()

let streamID = HTTP2StreamID(1)

let payload = HTTP2Frame.FramePayload.Headers(headers: goodHeaders, endStream: true)
XCTAssertNoThrow(try childChannel.writeAndFlush(HTTP2Frame.FramePayload.headers(payload)).wait())

let frames = try self.channel.sentFrames()
XCTAssertEqual(frames.count, 1)
frames.first?.assertHeadersFrameMatches(this: HTTP2Frame(streamID: streamID, payload: .headers(payload)))

XCTAssertEqual(consumer.readCount, 1)

// 1. pass header onwards

let responseHeaderPayload = HTTP2Frame.FramePayload.headers(.init(headers: [":status": "200"]))
XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseHeaderPayload)))
XCTAssertEqual(consumer.receivedFrames.count, 1)
XCTAssertEqual(consumer.readCompleteCount, 1)
XCTAssertEqual(consumer.readCount, 2)

consumer.forwardRead = false

// 2. pass body onwards

let responseBody1 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "foo"))))
XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody1)))
XCTAssertEqual(consumer.receivedFrames.count, 2)
XCTAssertEqual(consumer.readCompleteCount, 2)
XCTAssertEqual(consumer.readCount, 3)
XCTAssertEqual(consumer.readPending, true)

// 3. pass on more body - should not change a thing, since read is pending in consumer

let responseBody2 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "bar")), endStream: true))
XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody2)))
XCTAssertEqual(consumer.receivedFrames.count, 2)
XCTAssertEqual(consumer.readCompleteCount, 2)
XCTAssertEqual(consumer.readCount, 3)
XCTAssertEqual(consumer.readPending, true)

// 4. signal stream is closed – this should force forward all pending frames

XCTAssertEqual(consumer.channelInactiveCount, 0)
self.channel.pipeline.fireUserInboundEventTriggered(StreamClosedEvent(streamID: streamID, reason: nil))
XCTAssertEqual(consumer.receivedFrames.count, 3)
XCTAssertEqual(consumer.readCompleteCount, 3)
XCTAssertEqual(consumer.readCount, 3)
XCTAssertEqual(consumer.channelInactiveCount, 1)
XCTAssertEqual(consumer.readPending, true)
}
}

private final class ErrorRecorder: ChannelInboundHandler {
Expand All @@ -2004,3 +2077,58 @@ private final class ErrorRecorder: ChannelInboundHandler {
context.fireErrorCaught(error)
}
}

private final class ReadAndFrameConsumer: ChannelInboundHandler, ChannelOutboundHandler {
typealias InboundIn = HTTP2Frame.FramePayload
typealias OutboundIn = HTTP2Frame.FramePayload

private(set) var receivedFrames: [HTTP2Frame.FramePayload] = []
private(set) var readCount = 0
private(set) var readCompleteCount = 0
private(set) var channelInactiveCount = 0
private(set) var readPending = false

var forwardRead = true {
didSet {
if self.forwardRead, self.readPending {
self.context.read()
self.readPending = false
}
}
}

var context: ChannelHandlerContext!

func handlerAdded(context: ChannelHandlerContext) {
self.context = context
}

func handlerRemoved(context: ChannelHandlerContext) {
self.context = context
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.receivedFrames.append(self.unwrapInboundIn(data))
context.fireChannelRead(data)
}

func channelReadComplete(context: ChannelHandlerContext) {
self.readCompleteCount += 1
context.fireChannelReadComplete()
}

func channelInactive(context: ChannelHandlerContext) {
self.channelInactiveCount += 1
context.fireChannelInactive()
}

func read(context: ChannelHandlerContext) {
self.readCount += 1
if forwardRead {
context.read()
self.readPending = false
} else {
self.readPending = true
}
}
}