Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) {
*/
public final void transportReportStatus(final Status status) {
Preconditions.checkArgument(!status.isOk(), "status must not be OK");
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(status);
Expand All @@ -300,6 +301,7 @@ public void run() {
* #transportReportStatus}.
*/
public void complete() {
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(Status.OK);
Expand Down Expand Up @@ -335,7 +337,6 @@ private void closeListener(Status newStatus) {
getTransportTracer().reportStreamClosed(closedStatus.isOk());
}
listenerClosed = true;
onStreamDeallocated();
listener().closed(newStatus);
}
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/io/grpc/internal/AbstractStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ protected final void onStreamDeallocated() {
}
}

protected boolean isStreamDeallocated() {
synchronized (onReadyLock) {
return deallocated;
}
}

/**
* Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint.
Expand Down
39 changes: 26 additions & 13 deletions netty/src/main/java/io/grpc/netty/NettyClientStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,10 @@ private void writeFrameInternal(
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
ChannelFutureListener failureListener =
future -> transportState().onWriteFrameData(future, numMessages, numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && transportState().http2Stream() != null) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages);
}
}
});
.addListener(failureListener);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(
Expand Down Expand Up @@ -306,6 +296,29 @@ protected void http2ProcessingFailed(Status status, boolean stopDelivery, Metada
handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true);
}

protected void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && http2Stream() == null) {
return;
}

if (future.isSuccess()) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
getTransportTracer().reportMessageSent(numMessages);
} else if (!isStreamDeallocated()) {
// Future failed, fail RPC.
// Normally we don't need to do anything here because the cause of a failed future
// while writing DATA frames would be an IO error and the stream is already closed.
// However, we still need handle any unexpected failures raised in Netty.
// Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple
// CancelClientStreamCommand commands.
http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata());
}
}

@Override
public void runOnTransportThread(final Runnable r) {
if (eventLoop.inEventLoop()) {
Expand Down
73 changes: 47 additions & 26 deletions netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,48 +96,46 @@ private class Sink implements AbstractServerStream.Sink {
@Override
public void writeHeaders(Metadata headers, boolean flush) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) {
writeQueue.enqueue(
SendResponseHeadersCommand.createHeaders(
transportState(),
Utils.convertServerHeaders(headers)),
flush);
Http2Headers http2headers = Utils.convertServerHeaders(headers);
SendResponseHeadersCommand headersCommand =
SendResponseHeadersCommand.createHeaders(transportState(), http2headers);
writeQueue.enqueue(headersCommand, true)
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) {
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
@Override
public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) {
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener((ChannelFutureListener) future -> {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
// TODO(sergiitk): should onSentBytes be called only on success?
transportState().onSentBytes(numBytes);
if (future.isSuccess()) {
// TODO(sergiitk): when all moved to transport state, transportTracer becomes unused
transportTracer.reportMessageSent(numMessages);
} else {
transportState().handleWriteFutureFailures(future);
}
}
});
}

@Override
public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) {
writeFrameInternal(frame, flush, numMessages);
});
}
}

@Override
public void writeTrailers(Metadata trailers, boolean headersSent, Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) {
Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent);
writeQueue.enqueue(
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status),
true);
SendResponseHeadersCommand trailersCommand =
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status);
writeQueue.enqueue(trailersCommand, true)
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

Expand Down Expand Up @@ -206,6 +204,29 @@ public void deframeFailed(Throwable cause) {
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

private void handleWriteFutureFailures(ChannelFuture future) {
// isStreamDeallocated() check protects from spamming stream resets by scheduling multiple
// CancelServerStreamCommand commands.
if (future.isSuccess() || isStreamDeallocated()) {
return;
}
// Future failed, fail RPC.
// Normally we don't need to do anything here because the cause of a failed future
// while writing DATA frames would be an IO error and the stream is already closed.
// However, we still need handle any unexpected failures raised in Netty.
// TODO(sergiitk): check if something similar to
// io.grpc.netty.NettyClientTransport#statusFromFailedFuture is needed.
http2ProcessingFailed(Utils.statusFromThrowable(future.cause()));
}

/**
* Called to process a failure in HTTP/2 processing.
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
Expand Down
42 changes: 42 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.STATUS_OK;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand All @@ -34,10 +37,12 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
Expand All @@ -62,6 +67,7 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.util.AsciiString;
import java.io.BufferedInputStream;
Expand Down Expand Up @@ -205,6 +211,42 @@ public void writeMessageShouldSendRequestUnknownLength() throws Exception {
eq(true));
}

@Test
public void writeFrameFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), eq(true))).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

// Force stream creation.
stream().transportState().setId(STREAM_ID);
// TODO(sergiitk): multiple messages?
stream.writeMessage(new ByteArrayInputStream(smallMessage()));
stream.flush();

// Spot-check single-frame message after stream creation.
verify(writeQueue, times(1)).enqueue(any(CreateStreamCommand.class), eq(false));
verify(writeQueue, times(1)).enqueue(any(SendGrpcFrameCommand.class), eq(true));
verify(writeQueue, times(1)).enqueue(any(CancelClientStreamCommand.class), eq(true));
verifyNoMoreInteractions(writeQueue);

ArgumentCaptor<QueuedCommand> commandCaptor = ArgumentCaptor.forClass(QueuedCommand.class);
verify(writeQueue, atLeastOnce()).enqueue(commandCaptor.capture(), eq(true));

// Check the last call to be CancelClientStreamCommand.
QueuedCommand command = commandCaptor.getValue();
assertWithMessage("Expected last command on the stream to be CancelClientStreamCommand")
.that(command)
.isInstanceOf(CancelClientStreamCommand.class);
CancelClientStreamCommand cancelCommand = (CancelClientStreamCommand) command;
// Check connection error info is propagated via Status.
Status cancelReason = cancelCommand.reason();
assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode());
assertThat(cancelReason.getCause()).isEqualTo(h2Error);
// Verify listener closed.
// should we expect REFUSED/MISCARRIED instead?
verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class));
}

@Test
public void setStatusWithOkShouldCloseStream() {
stream().transportState().setId(STREAM_ID);
Expand Down
63 changes: 63 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -27,6 +30,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand All @@ -41,9 +45,12 @@
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.StreamListener;
import io.grpc.internal.TransportTracer;
import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.util.AsciiString;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
Expand Down Expand Up @@ -124,6 +131,62 @@ public void writeMessageShouldSendResponse() throws Exception {
eq(true));
}

@Test
public void writeFrameFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), eq(true))).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

// Single-frame message.
stream.writeMessage(new ByteArrayInputStream(smallMessage()));
stream.flush();
verifyWriteFutureFailure(h2Error, SendGrpcFrameCommand.class);
}

@Test
public void writeHeadersFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
when(writeQueue.enqueue(any(SendResponseHeadersCommand.class), eq(true))).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

stream().writeHeaders(new Metadata(), true);
stream.flush();
verifyWriteFutureFailure(h2Error, SendResponseHeadersCommand.class);
}

@Test
public void writeTrailersFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
when(writeQueue.enqueue(any(SendResponseHeadersCommand.class), eq(true))).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

stream().close(Status.OK, trailers);
verifyWriteFutureFailure(h2Error, SendResponseHeadersCommand.class);
}

private void verifyWriteFutureFailure(
Http2Exception h2Error, Class<? extends QueuedCommand> failedCommand) {
verify(writeQueue, times(1)).enqueue(any(failedCommand), eq(true));
verify(writeQueue, times(1)).enqueue(any(CancelServerStreamCommand.class), eq(true));
verifyNoMoreInteractions(writeQueue);

ArgumentCaptor<QueuedCommand> commandCaptor = ArgumentCaptor.forClass(QueuedCommand.class);
verify(writeQueue, atLeastOnce()).enqueue(commandCaptor.capture(), eq(true));

// Check the last call to be CancelClientStreamCommand.
QueuedCommand command = commandCaptor.getValue();
assertWithMessage("Expected last command on the stream to be CancelServerStreamCommand")
.that(command)
.isInstanceOf(CancelServerStreamCommand.class);
CancelServerStreamCommand cancelCommand = (CancelServerStreamCommand) command;
// Check connection error info is propagated via Status.
Status cancelReason = cancelCommand.reason();
assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode());
assertThat(cancelReason.getCause()).isEqualTo(h2Error);
// Listener closed.
verify(serverListener).closed(same(cancelReason));
}

@Test
public void writeHeadersShouldSendHeaders() throws Exception {
Metadata headers = new Metadata();
Expand Down