Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 4 additions & 32 deletions internal/transport/controlbuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,13 @@ import (
"fmt"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"

"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/pretty"
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
Expand Down Expand Up @@ -150,11 +144,9 @@ type cleanupStream struct {
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM

type earlyAbortStream struct {
httpStatus uint32
streamID uint32
contentSubtype string
status *status.Status
rst bool
streamID uint32
rst bool
hf []hpack.HeaderField // Pre-built header fields
}

func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
Expand Down Expand Up @@ -846,27 +838,7 @@ func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if l.side == clientSide {
return errors.New("earlyAbortStream not handled on client")
}
// In case the caller forgets to set the http status, default to 200.
if eas.httpStatus == 0 {
eas.httpStatus = 200
}
headerFields := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
}

if p := istatus.RawStatusProto(eas.status); len(p.GetDetails()) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
l.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err)
} else {
headerFields = append(headerFields, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)})
}
}

if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
if err := l.writeHeader(eas.streamID, true, eas.hf, nil); err != nil {
return err
}
if eas.rst {
Expand Down
94 changes: 47 additions & 47 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusBadRequest, !frame.StreamEnded())
return nil
}

Expand All @@ -499,23 +493,11 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
return nil
}
if !isGRPC {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusUnsupportedMediaType,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), http.StatusUnsupportedMediaType, !frame.StreamEnded())
return nil
}
if headerError != nil {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: headerError,
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, headerError, http.StatusBadRequest, !frame.StreamEnded())
return nil
}

Expand Down Expand Up @@ -569,13 +551,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusMethodNotAllowed,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusMethodNotAllowed, !frame.StreamEnded())
s.cancel()
return nil
}
Expand All @@ -590,27 +566,16 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if !ok {
stat = status.New(codes.PermissionDenied, err.Error())
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(s.id, s.contentSubtype, stat, http.StatusOK, !frame.StreamEnded())
return nil
}
}

if s.ctx.Err() != nil {
t.mu.Unlock()
st := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
// Early abort in case the timeout was zero or so low it already fired.
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(s.id, s.contentSubtype, st, http.StatusOK, !frame.StreamEnded())
return nil
}

Expand Down Expand Up @@ -969,13 +934,12 @@ func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD)
return headerFields
}

func (t *http2Server) checkForHeaderListSize(it any) bool {
func (t *http2Server) checkForHeaderListSize(hf []hpack.HeaderField) bool {
if t.maxSendHeaderListSize == nil {
return true
}
hdrFrame := it.(*headerFrame)
var sz int64
for _, f := range hdrFrame.hf {
for _, f := range hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
if t.logger.V(logLevel) {
t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize)
Expand All @@ -986,6 +950,42 @@ func (t *http2Server) checkForHeaderListSize(it any) bool {
return true
}

// writeEarlyAbort sends an early abort response with the given HTTP status and
// gRPC status. If the header list size exceeds the peer's limit, it sends a
// RST_STREAM instead.
func (t *http2Server) writeEarlyAbort(streamID uint32, contentSubtype string, stat *status.Status, httpStatus uint32, rst bool) {
hf := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(stat.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(stat.Message())},
}
if p := istatus.RawStatusProto(stat); len(p.GetDetails()) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err)
}
if err == nil {
Comment thread
arjan-bal marked this conversation as resolved.
hf = append(hf, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)})
}
}
success, _ := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(hf)
}, &earlyAbortStream{
streamID: streamID,
rst: rst,
hf: hf,
})
if !success {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeInternal,
onWrite: func() {},
})
}
}

func (t *http2Server) streamContextErr(s *ServerStream) error {
select {
case <-t.done:
Expand Down Expand Up @@ -1041,7 +1041,7 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
endStream: false,
onWrite: t.setResetPingStrikes,
}
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf)
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf.hf) }, hf)
if !success {
if err != nil {
return err
Expand Down Expand Up @@ -1111,7 +1111,7 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
}

success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader)
return t.checkForHeaderListSize(trailingHeader.hf)
}, nil)
if !success {
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6134,6 +6134,43 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env)
}
}

func (s) TestEarlyAbortStreamHeaderListSizeCheck(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
s := grpc.NewServer()
defer s.Stop()
go s.Serve(lis)

conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout)
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer conn.Close()
st := newServerTesterFromConn(t, conn)

// Set a very small MaxHeaderListSize that any response headers would violate.
st.greetWithSettings(http2.Setting{ID: http2.SettingMaxHeaderListSize, Val: 1})

// Send a request with an invalid content-type to trigger early abort.
st.writeHeaders(http2.HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(
":method", "POST",
":path", "/grpc.testing.TestService/UnaryCall",
"content-type", "text/plain", // Invalid content-type to trigger early abort
"te", "trailers",
),
EndStream: true,
EndHeaders: true,
})

// We should receive a RST_STREAM with ErrCodeInternal because the response
// headers exceed the MaxHeaderListSize limit.
st.wantRSTStream(http2.ErrCodeInternal)
}

func (s) TestNetPipeConn(t *testing.T) {
// This test will block indefinitely if grpc writes both client and server
// prefaces without either reading from the Conn.
Expand Down
13 changes: 12 additions & 1 deletion test/servertester.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,19 @@ func (st *serverTester) readFrame() (http2.Frame, error) {
// greet initiates the client's HTTP/2 connection into a state where
// frames may be sent.
func (st *serverTester) greet() {
st.greetWithSettings()
}

// greetWithSettings initiates the client's HTTP/2 connection with custom settings.
func (st *serverTester) greetWithSettings(settings ...http2.Setting) {
st.writePreface()
st.writeInitialSettings()
if len(settings) > 0 {
if err := st.fr.WriteSettings(settings...); err != nil {
st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
}
} else {
st.writeInitialSettings()
}
st.wantSettings()
st.writeSettingsAck()
for {
Expand Down