Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an idle write timeout #718

Merged
merged 3 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ extension Transaction {
// response body stream.
let body = TransactionBody.makeSequence(
backPressureStrategy: .init(lowWatermark: 1, highWatermark: 1),
finishOnDeinit: true,
delegate: AnyAsyncSequenceProducerDelegate(delegate)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,17 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout {
self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout)
}

if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout {
self.idleWriteTimeoutStateMachine = .init(
timeAmount: idleWriteTimeout,
isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false
)
}
} else {
self.logger = self.backgroundLogger
self.idleReadTimeoutStateMachine = nil
self.idleWriteTimeoutStateMachine = nil
}
}
}
Expand All @@ -57,6 +65,14 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
/// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed.
private var currentIdleReadTimeoutTimerID: Int = 0

private var idleWriteTimeoutStateMachine: IdleWriteStateMachine?
private var idleWriteTimeoutTimer: Scheduled<Void>?

/// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions.
/// We therefore give each timer an ID and increase the ID every time we reset or cancel it.
/// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed.
private var currentIdleWriteTimeoutTimerID: Int = 0

private let backgroundLogger: Logger
private var logger: Logger
private let eventLoop: EventLoop
Expand Down Expand Up @@ -106,6 +122,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
"ahc-channel-writable": "\(context.channel.isWritable)",
])

if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) {
self.runTimeoutAction(timeoutAction, context: context)
}

let action = self.state.writabilityChanged(writable: context.channel.isWritable)
self.run(action, context: context)
context.fireChannelWritabilityChanged()
Expand Down Expand Up @@ -150,6 +170,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
self.request = req

self.logger.debug("Request was scheduled on connection")

if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() {
self.runTimeoutAction(timeoutAction, context: context)
}

req.willExecuteRequest(self)

let action = self.state.runNewRequest(
Expand Down Expand Up @@ -196,8 +221,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
request.resumeRequestBodyStream()
}
if startIdleTimer {
if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(readTimeoutAction, context: context)
}

if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(writeTimeoutAction, context: context)
}
}
case .sendBodyPart(let part, let writePromise):
Expand All @@ -206,8 +235,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
case .sendRequestEnd(let writePromise):
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise)

if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(timeoutAction, context: context)
if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(readTimeoutAction, context: context)
}

if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() {
self.runTimeoutAction(writeTimeoutAction, context: context)
}

case .pauseRequestBodyStream:
Expand Down Expand Up @@ -380,6 +413,40 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
}
}

private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) {
switch action {
case .startIdleWriteTimeoutTimer(let timeAmount):
assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.")

let timerID = self.currentIdleWriteTimeoutTimerID
self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) {
guard self.currentIdleWriteTimeoutTimerID == timerID else { return }
let action = self.state.idleWriteTimeoutTriggered()
self.run(action, context: context)
}
case .resetIdleWriteTimeoutTimer(let timeAmount):
if let oldTimer = self.idleWriteTimeoutTimer {
oldTimer.cancel()
}

self.currentIdleWriteTimeoutTimerID &+= 1
let timerID = self.currentIdleWriteTimeoutTimerID
self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) {
guard self.currentIdleWriteTimeoutTimerID == timerID else { return }
let action = self.state.idleWriteTimeoutTriggered()
self.run(action, context: context)
}
case .clearIdleWriteTimeoutTimer:
if let oldTimer = self.idleWriteTimeoutTimer {
self.idleWriteTimeoutTimer = nil
self.currentIdleWriteTimeoutTimerID &+= 1
oldTimer.cancel()
}
case .none:
break
}
}

// MARK: Private HTTPRequestExecutor

private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise<Void>?) {
Expand All @@ -393,6 +460,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
return
}

if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() {
self.runTimeoutAction(timeoutAction, context: context)
}

let action = self.state.requestStreamPartReceived(data, promise: promise)
self.run(action, context: context)
}
Expand Down Expand Up @@ -428,6 +499,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {

self.logger.trace("Request was cancelled")

if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() {
self.runTimeoutAction(timeoutAction, context: context)
}

let action = self.state.requestCancelled(closeConnection: true)
self.run(action, context: context)
}
Expand Down Expand Up @@ -540,3 +615,87 @@ struct IdleReadStateMachine {
}
}
}

struct IdleWriteStateMachine {
enum Action {
case startIdleWriteTimeoutTimer(TimeAmount)
case resetIdleWriteTimeoutTimer(TimeAmount)
case clearIdleWriteTimeoutTimer
case none
}

enum State {
case waitingForRequestEnd
case waitingForWritabilityEnabled
case requestEndSent
}

private var state: State
private let timeAmount: TimeAmount

init(timeAmount: TimeAmount, isWritabilityEnabled: Bool) {
self.timeAmount = timeAmount
if isWritabilityEnabled {
self.state = .waitingForRequestEnd
} else {
self.state = .waitingForWritabilityEnabled
}
}

mutating func cancelRequest() -> Action {
switch self.state {
case .waitingForRequestEnd, .waitingForWritabilityEnabled:
self.state = .requestEndSent
return .clearIdleWriteTimeoutTimer
case .requestEndSent:
return .none
}
}

mutating func write() -> Action {
switch self.state {
case .waitingForRequestEnd:
return .resetIdleWriteTimeoutTimer(self.timeAmount)
case .waitingForWritabilityEnabled:
return .none
case .requestEndSent:
preconditionFailure("If the request end has been sent, we can't write more data.")
}
}

mutating func requestEndSent() -> Action {
switch self.state {
case .waitingForRequestEnd:
self.state = .requestEndSent
return .clearIdleWriteTimeoutTimer
case .waitingForWritabilityEnabled:
preconditionFailure("If the channel is not writable, we can't have sent the request end.")
case .requestEndSent:
return .none
}
}

mutating func channelWritabilityChanged(context: ChannelHandlerContext) -> Action {
if context.channel.isWritable {
switch self.state {
case .waitingForRequestEnd:
preconditionFailure("If waiting for more data, the channel was already writable.")
case .waitingForWritabilityEnabled:
self.state = .waitingForRequestEnd
return .startIdleWriteTimeoutTimer(self.timeAmount)
case .requestEndSent:
return .none
}
} else {
switch self.state {
case .waitingForRequestEnd:
self.state = .waitingForWritabilityEnabled
return .clearIdleWriteTimeoutTimer
case .waitingForWritabilityEnabled:
preconditionFailure("If the channel was writable before, then we should have been waiting for more data.")
case .requestEndSent:
return .none
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ struct HTTP1ConnectionStateMachine {
}
}

mutating func idleWriteTimeoutTriggered() -> Action {
guard case .inRequest(var requestStateMachine, let close) = self.state else {
preconditionFailure("Invalid state: \(self.state)")
}

return self.avoidingStateMachineCoW { state -> Action in
let action = requestStateMachine.idleWriteTimeoutTriggered()
state = .inRequest(requestStateMachine, close: close)
return state.modify(with: action)
}
}

mutating func headSent() -> Action {
guard case .inRequest(var requestStateMachine, let close) = self.state else {
return .wait
Expand Down
Loading