Skip to content

Commit

Permalink
Fixes #10277 - Review read failures impacting writes.
Browse files Browse the repository at this point in the history
Separated read failures from write failures.
In this way it is possible to read even if the write side is failed and write even if the read side is failed.

Signed-off-by: Simone Bordet <[email protected]>
  • Loading branch information
sbordet committed Nov 29, 2023
1 parent 82fbf3d commit b1d5c03
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritePendingException;
import java.util.HashMap;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -111,17 +112,8 @@ private enum StreamSendState
private long _committedContentLength = -1;
private Runnable _onContentAvailable;
private Predicate<TimeoutException> _onIdleTimeout;
/**
* Failure passed to {@link #onFailure(Throwable)}
*/
private Content.Chunk _failure;
/**
* Listener for {@link #onFailure(Throwable)} events
*/
private Content.Chunk _readFailure;
private Consumer<Throwable> _onFailure;
/**
* Failure passed to {@link ChannelCallback#failed(Throwable)}
*/
private Throwable _callbackFailure;
private Attributes _cache;

Expand Down Expand Up @@ -158,7 +150,7 @@ public void recycle()
_committedContentLength = -1;
_onContentAvailable = null;
_onIdleTimeout = null;
_failure = null;
_readFailure = null;
_onFailure = null;
_callbackFailure = null;
}
Expand Down Expand Up @@ -330,15 +322,15 @@ public Runnable onIdleTimeout(TimeoutException t)
LOG.debug("onIdleTimeout {}", this, t);

// if not already a failure,
if (_failure == null)
if (_readFailure == null)
{
// if we are currently demanding, take the onContentAvailable runnable to invoke below.
Runnable invokeOnContentAvailable = _onContentAvailable;
_onContentAvailable = null;

// If demand was in process, then arrange for the next read to return the idle timeout, if no other error
if (invokeOnContentAvailable != null)
_failure = Content.Chunk.from(t, false);
_readFailure = Content.Chunk.from(t, false);

// If a write call is in progress, take the writeCallback to fail below
Runnable invokeWriteFailure = _response.lockedFailWrite(t);
Expand Down Expand Up @@ -394,10 +386,10 @@ public Runnable onFailure(Throwable x)
}

// Set the error to arrange for any subsequent reads, demands or writes to fail.
if (_failure == null)
_failure = Content.Chunk.from(x, true);
else if (ExceptionUtil.areNotAssociated(_failure.getFailure(), x) && _failure.getFailure().getClass() != x.getClass())
_failure.getFailure().addSuppressed(x);
if (_readFailure == null)
_readFailure = Content.Chunk.from(x, true);
else if (ExceptionUtil.areNotAssociated(_readFailure.getFailure(), x) && _readFailure.getFailure().getClass() != x.getClass())
_readFailure.getFailure().addSuppressed(x);

// If not handled, then we just fail the request callback
if (!_handled && _handling == null)
Expand Down Expand Up @@ -850,8 +842,8 @@ public Content.Chunk read()
{
HttpChannelState httpChannel = lockedGetHttpChannelState();

Content.Chunk error = httpChannel._failure;
httpChannel._failure = Content.Chunk.next(error);
Content.Chunk error = httpChannel._readFailure;
httpChannel._readFailure = Content.Chunk.next(error);
if (error != null)
return error;

Expand Down Expand Up @@ -898,7 +890,7 @@ public void demand(Runnable demandCallback)
if (LOG.isDebugEnabled())
LOG.debug("demand {}", httpChannelState);

error = httpChannelState._failure != null;
error = httpChannelState._readFailure != null;
if (!error)
{
if (httpChannelState._onContentAvailable != null)
Expand Down Expand Up @@ -936,7 +928,7 @@ public void addIdleTimeoutListener(Predicate<TimeoutException> onIdleTimeout)
{
HttpChannelState httpChannel = lockedGetHttpChannelState();

if (httpChannel._failure != null)
if (httpChannel._readFailure != null)
return;

if (httpChannel._onIdleTimeout == null)
Expand All @@ -963,7 +955,7 @@ public void addFailureListener(Consumer<Throwable> onFailure)
{
HttpChannelState httpChannel = lockedGetHttpChannelState();

if (httpChannel._failure != null)
if (httpChannel._readFailure != null)
return;

if (httpChannel._onFailure == null)
Expand Down Expand Up @@ -1031,6 +1023,7 @@ public static class ChannelResponse implements Response, Callback
private long _contentBytesWritten;
private Supplier<HttpFields> _trailers;
private Callback _writeCallback;
private Throwable _writeFailure;

private ChannelResponse(ChannelRequest request)
{
Expand Down Expand Up @@ -1059,7 +1052,10 @@ private Runnable lockedFailWrite(Throwable x)
assert _request._lock.isHeldByCurrentThread();
Callback writeCallback = _writeCallback;
_writeCallback = null;
return writeCallback == null ? null : () -> writeCallback.failed(x);
if (writeCallback == null)
return null;
_writeFailure = x;
return () -> writeCallback.failed(x);
}

public long getContentBytesWritten()
Expand Down Expand Up @@ -1115,78 +1111,76 @@ public void write(boolean last, ByteBuffer content, Callback callback)
{
long length = BufferUtil.length(content);

HttpChannelState httpChannelState;
HttpChannelState httpChannel;
HttpStream stream;
Throwable failure;
Throwable writeFailure;
MetaData.Response responseMetaData = null;
try (AutoLock ignored = _request._lock.lock())
{
httpChannelState = _request.lockedGetHttpChannelState();
long committedContentLength = httpChannelState._committedContentLength;
httpChannel = _request.lockedGetHttpChannelState();
long totalWritten = _contentBytesWritten + length;
long contentLength = committedContentLength >= 0 ? committedContentLength : getHeaders().getLongField(HttpHeader.CONTENT_LENGTH);
writeFailure = _writeFailure;

if (_writeCallback != null)
{
failure = new IllegalStateException("write pending");
}
else
if (writeFailure == null)
{
failure = getFailure(httpChannelState);
if (failure == null && contentLength >= 0 && totalWritten != contentLength)
if (_writeCallback != null)
{
// If the content length were not compatible with what was written, then we need to abort.
String lengthError = null;
if (totalWritten > contentLength)
lengthError = "written %d > %d content-length";
else if (last && !(totalWritten == 0 && HttpMethod.HEAD.is(_request.getMethod())))
lengthError = "written %d < %d content-length";
if (lengthError != null)
writeFailure = new WritePendingException();
}
else
{
long committedContentLength = httpChannel._committedContentLength;
long contentLength = committedContentLength >= 0 ? committedContentLength : getHeaders().getLongField(HttpHeader.CONTENT_LENGTH);

if (contentLength >= 0 && totalWritten != contentLength)
{
String message = lengthError.formatted(totalWritten, contentLength);
if (LOG.isDebugEnabled())
LOG.debug("fail {} {}", callback, message);
failure = new IOException(message);
// If the content length were not compatible with what was written, then we need to abort.
String lengthError = null;
if (totalWritten > contentLength)
lengthError = "written %d > %d content-length";
else if (last && !(totalWritten == 0 && HttpMethod.HEAD.is(_request.getMethod())))
lengthError = "written %d < %d content-length";
if (lengthError != null)
{
String message = lengthError.formatted(totalWritten, contentLength);
if (LOG.isDebugEnabled())
LOG.debug("fail {} {}", callback, message);
writeFailure = new IOException(message);
}
}
}
}

// If no failure by this point, we can try to switch to sending state.
if (failure == null)
failure = httpChannelState.lockedStreamSend(last, length);
if (writeFailure == null)
writeFailure = httpChannel.lockedStreamSend(last, length);

if (failure == NOTHING_TO_SEND)
if (writeFailure == NOTHING_TO_SEND)
{
httpChannelState._serializedInvoker.run(callback::succeeded);
httpChannel._serializedInvoker.run(callback::succeeded);
return;
}
// Have we failed in some way?
if (failure != null)
if (writeFailure != null)
{
Throwable throwable = failure;
httpChannelState._serializedInvoker.run(() -> callback.failed(throwable));
Throwable failure = writeFailure;
httpChannel._serializedInvoker.run(() -> callback.failed(failure));
return;
}

// No failure, do the actual stream send using the ChannelResponse as the callback.
_writeCallback = callback;
_contentBytesWritten = totalWritten;
stream = httpChannelState._stream;
stream = httpChannel._stream;
if (_httpFields.commit())
responseMetaData = lockedPrepareResponse(httpChannelState, last);
responseMetaData = lockedPrepareResponse(httpChannel, last);
}

if (LOG.isDebugEnabled())
LOG.debug("writing last={} {} {}", last, BufferUtil.toDetailString(content), this);
stream.send(_request._metaData, responseMetaData, last, content, this);
}

protected Throwable getFailure(HttpChannelState httpChannelState)
{
Content.Chunk failure = httpChannelState._failure;
return failure == null ? null : failure.getFailure();
}

/**
* Called when the call to
* {@link HttpStream#send(MetaData.Request, MetaData.Response, boolean, ByteBuffer, Callback)}
Expand All @@ -1199,14 +1193,13 @@ public void succeeded()
{
if (LOG.isDebugEnabled())
LOG.debug("write succeeded {}", this);
// Called when an individual write succeeds.
Callback callback;
HttpChannelState httpChannel;
try (AutoLock ignored = _request._lock.lock())
{
httpChannel = _request.lockedGetHttpChannelState();
callback = _writeCallback;
_writeCallback = null;
httpChannel = _request.lockedGetHttpChannelState();
httpChannel.lockedStreamSendCompleted(true);
}
if (callback != null)
Expand All @@ -1227,14 +1220,14 @@ public void failed(Throwable x)
{
if (LOG.isDebugEnabled())
LOG.debug("write failed {}", this, x);
// Called when an individual write succeeds.
Callback callback;
HttpChannelState httpChannel;
try (AutoLock ignored = _request._lock.lock())
{
httpChannel = _request.lockedGetHttpChannelState();
_writeFailure = x;
callback = _writeCallback;
_writeCallback = null;
httpChannel = _request.lockedGetHttpChannelState();
httpChannel.lockedStreamSendCompleted(false);
}
if (callback != null)
Expand Down Expand Up @@ -1520,13 +1513,6 @@ public ErrorResponse(ChannelRequest request)
_status = HttpStatus.INTERNAL_SERVER_ERROR_500;
}

@Override
protected Throwable getFailure(HttpChannelState httpChannelState)
{
// we ignore channel failures so we can try to generate an error response.
return null;
}

@Override
protected ResponseHttpFields getResponseHttpFields(HttpChannelState httpChannelState)
{
Expand Down Expand Up @@ -1649,23 +1635,27 @@ private class HttpChannelSerializedInvoker extends SerializedInvoker
protected void onError(Runnable task, Throwable failure)
{
ChannelRequest request;
Content.Chunk error;
Throwable error;
boolean callbackCompleted;
try (AutoLock ignore = _lock.lock())
{
callbackCompleted = _callbackCompleted;
request = _request;
error = _request == null ? null : _failure;
error = _response == null ? null : _response._writeFailure;
if (error == null)
error = _readFailure == null ? null : _readFailure.getFailure();
}

if (request == null || callbackCompleted)
{
// It is too late to handle error, so just log it
// It is too late to handle error.
super.onError(task, failure);
return;
}
else if (error == null)

if (error == null)
{
// Try to fail the request, but we might lose a race.
// Try to fail the request, but we might lose the race.
try
{
request._callback.failed(failure);
Expand All @@ -1680,9 +1670,8 @@ else if (error == null)
{
// We are already in error, so we will not handle this one,
// but we will add as suppressed if we have not seen it already.
Throwable cause = error.getFailure();
if (ExceptionUtil.areNotAssociated(cause, failure))
error.getFailure().addSuppressed(failure);
ExceptionUtil.addSuppressedIfNotAssociated(error, failure);
super.onError(task, failure);
}
}
}
Expand Down
Loading

0 comments on commit b1d5c03

Please sign in to comment.