Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.DeadlineExceeded
}
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
st := status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)
t.closeStream(s, st.Err(), false, http2.ErrCodeNo, st, nil, false)
}

func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
Expand Down
5 changes: 3 additions & 2 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,8 +919,9 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
}
// The server will send an RST stream frame on observing the deadline
// expiration making the client stream fail with a DeadlineExceeded status.
if _, err := s.readTo(make([]byte, 8)); err != io.EOF {
t.Fatalf("Read got unexpected error: %v, want %v", err, io.EOF)
_, err = s.readTo(make([]byte, 8))
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded)
}
if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want {
t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want)
Expand Down
76 changes: 76 additions & 0 deletions test/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ package test

import (
"context"
"encoding/binary"
"io"
"net"
"sync"
"testing"

"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -153,3 +157,75 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
t.Fatal("Timeout expired when waiting for first client transport to close")
}
}

// Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
// reading a gRPC message is correctly converted to a gRPC status with code
// CANCELLED. The test sends a data frame with a partial gRPC message, followed
// by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
// client receives the correct status.
func (s) TestRSTDuringMessageRead(t *testing.T) {
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatal(err)
}
defer lis.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err)
}
defer cc.Close()

go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("lis.Accept() = %v", err)
return
}
defer conn.Close()
framer := http2.NewFramer(conn, conn)

if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
t.Errorf("Error while reading client preface: %v", err)
return
}
if err := framer.WriteSettings(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
}
if err := framer.WriteSettingsAck(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
}
for ctx.Err() == nil {
frame, err := framer.ReadFrame()
if err != nil {
return
}
switch frame := frame.(type) {
case *http2.HeadersFrame:
// When the client creates a stream, write a partial gRPC
// message followed by an RST_STREAM.
const messageLen = 2048
buf := make([]byte, messageLen/2)
// Write the gRPC message length header.
binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen))
if err := framer.WriteData(1, false, buf); err != nil {
return
}
framer.WriteRSTStream(1, http2.ErrCodeCancel)
default:
t.Logf("Server received frame: %v", frame)
}
}
}()

// The server will send a partial gRPC message before cancelling the stream.
// The client should get a gRPC status with code CANCELLED.
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled {
t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled)
}
}