diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index ee2e8c21e533..171e690a3f22 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1242,7 +1242,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { statusCode = codes.DeadlineExceeded } } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) + st := status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode) + t.closeStream(s, st.Err(), false, http2.ErrCodeNo, st, nil, false) } func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 8b1219597912..d2c4fac3443c 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -919,8 +919,9 @@ func (s) TestLargeMessageSuspension(t *testing.T) { } // The server will send an RST stream frame on observing the deadline // expiration making the client stream fail with a DeadlineExceeded status. - if _, err := s.readTo(make([]byte, 8)); err != io.EOF { - t.Fatalf("Read got unexpected error: %v, want %v", err, io.EOF) + _, err = s.readTo(make([]byte, 8)) + if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { + t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded) } if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want { t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want) diff --git a/test/transport_test.go b/test/transport_test.go index 2f61a6c03729..10e9ab57de30 100644 --- a/test/transport_test.go +++ b/test/transport_test.go @@ -19,16 +19,20 @@ package test import ( "context" + "encoding/binary" "io" "net" "sync" "testing" + "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" @@ -153,3 +157,75 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { t.Fatal("Timeout expired when waiting for first client transport to close") } } + +// Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while +// reading a gRPC message is correctly converted to a gRPC status with code +// CANCELLED. The test sends a data frame with a partial gRPC message, followed +// by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the +// client receives the correct status. +func (s) TestRSTDuringMessageRead(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err) + } + defer cc.Close() + + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("lis.Accept() = %v", err) + return + } + defer conn.Close() + framer := http2.NewFramer(conn, conn) + + if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil { + t.Errorf("Error while reading client preface: %v", err) + return + } + if err := framer.WriteSettings(); err != nil { + t.Errorf("Error while writing settings: %v", err) + return + } + if err := framer.WriteSettingsAck(); err != nil { + t.Errorf("Error while writing settings: %v", err) + return + } + for ctx.Err() == nil { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.HeadersFrame: + // When the client creates a stream, write a partial gRPC + // message followed by an RST_STREAM. + const messageLen = 2048 + buf := make([]byte, messageLen/2) + // Write the gRPC message length header. + binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen)) + if err := framer.WriteData(1, false, buf); err != nil { + return + } + framer.WriteRSTStream(1, http2.ErrCodeCancel) + default: + t.Logf("Server received frame: %v", frame) + } + } + }() + + // The server will send a partial gRPC message before cancelling the stream. + // The client should get a gRPC status with code CANCELLED. + client := testgrpc.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled { + t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled) + } +}