diff --git a/http2/writesched.go b/http2/writesched.go index f24d2b1e7..c7cd00173 100644 --- a/http2/writesched.go +++ b/http2/writesched.go @@ -32,7 +32,8 @@ type WriteScheduler interface { // Pop dequeues the next frame to write. Returns false if no frames can // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd. No frames should be discarded except by CloseStream. + // order they are Push'd, except RST_STREAM frames. No frames should be + // discarded except by CloseStream. Pop() (wr FrameWriteRequest, ok bool) } @@ -52,6 +53,7 @@ type FrameWriteRequest struct { // stream is the stream on which this frame will be written. // nil for non-stream frames like PING and SETTINGS. + // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. stream *stream // done, if non-nil, must be a buffered channel with space for diff --git a/http2/writesched_random.go b/http2/writesched_random.go index 9a7b9e581..f2e55e05c 100644 --- a/http2/writesched_random.go +++ b/http2/writesched_random.go @@ -45,11 +45,11 @@ func (ws *randomWriteScheduler) AdjustStream(streamID uint32, priority PriorityP } func (ws *randomWriteScheduler) Push(wr FrameWriteRequest) { - id := wr.StreamID() - if id == 0 { + if wr.isControl() { ws.zero.push(wr) return } + id := wr.StreamID() q, ok := ws.sq[id] if !ok { q = ws.queuePool.get() @@ -59,7 +59,7 @@ func (ws *randomWriteScheduler) Push(wr FrameWriteRequest) { } func (ws *randomWriteScheduler) Pop() (FrameWriteRequest, bool) { - // Control frames first. + // Control and RST_STREAM frames first. if !ws.zero.empty() { return ws.zero.shift(), true } diff --git a/http2/writesched_random_test.go b/http2/writesched_random_test.go index 1f501b4bd..02a41f36b 100644 --- a/http2/writesched_random_test.go +++ b/http2/writesched_random_test.go @@ -14,8 +14,9 @@ func TestRandomScheduler(t *testing.T) { ws.Push(makeWriteHeadersRequest(2)) ws.Push(makeWriteNonStreamRequest()) ws.Push(makeWriteNonStreamRequest()) + ws.Push(makeWriteRSTStream(1)) - // Pop all frames. Should get the non-stream requests first, + // Pop all frames. Should get the non-stream and RST stream requests first, // followed by the stream requests in any order. var order []FrameWriteRequest for { @@ -26,12 +27,15 @@ func TestRandomScheduler(t *testing.T) { order = append(order, wr) } t.Logf("got frames: %v", order) - if len(order) != 6 { + if len(order) != 7 { t.Fatalf("got %d frames, expected 6", len(order)) } if order[0].StreamID() != 0 || order[1].StreamID() != 0 { t.Fatal("expected non-stream frames first", order[0], order[1]) } + if _, ok := order[2].write.(StreamError); !ok { + t.Fatal("expected RST stream frames first", order[2]) + } got := make(map[uint32]bool) for _, wr := range order[2:] { got[wr.StreamID()] = true diff --git a/http2/writesched_test.go b/http2/writesched_test.go index 99be5a771..6dbd7f0ad 100644 --- a/http2/writesched_test.go +++ b/http2/writesched_test.go @@ -25,6 +25,10 @@ func makeHandlerPanicRST(streamID uint32) FrameWriteRequest { return FrameWriteRequest{&handlerPanicRST{StreamID: streamID}, st, nil} } +func makeWriteRSTStream(streamID uint32) FrameWriteRequest { + return FrameWriteRequest{write: streamError(streamID, ErrCodeInternal)} +} + func checkConsume(wr FrameWriteRequest, nbytes int32, want []FrameWriteRequest) error { consumed, rest, n := wr.Consume(nbytes) var wantConsumed, wantRest FrameWriteRequest @@ -52,6 +56,56 @@ func TestFrameWriteRequestNonData(t *testing.T) { if err := checkConsume(wr, 0, []FrameWriteRequest{wr}); err != nil { t.Errorf("Consume:\n%v", err) } + + wr = makeWriteRSTStream(123) + if got, want := wr.DataSize(), 0; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // RST_STREAM frames are always consumed whole. + if err := checkConsume(wr, 0, []FrameWriteRequest{wr}); err != nil { + t.Errorf("Consume:\n%v", err) + } +} + +// #49741 RST_STREAM and Control frames should have more priority than data +// frames to avoid blocking streams caused by clients not able to drain the +// queue. +func TestFrameWriteRequestWithData(t *testing.T) { + st := &stream{ + id: 1, + sc: &serverConn{maxFrameSize: 16}, + } + const size = 32 + wr := FrameWriteRequest{&writeData{st.id, make([]byte, size), true}, st, make(chan error)} + if got, want := wr.DataSize(), size; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // No flow-control bytes available: cannot consume anything. + if err := checkConsume(wr, math.MaxInt32, []FrameWriteRequest{}); err != nil { + t.Errorf("Consume(limited by flow control):\n%v", err) + } + + wr = makeWriteNonStreamRequest() + if got, want := wr.DataSize(), 0; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // Non-DATA frames are always consumed whole. + if err := checkConsume(wr, 0, []FrameWriteRequest{wr}); err != nil { + t.Errorf("Consume:\n%v", err) + } + + wr = makeWriteRSTStream(1) + if got, want := wr.DataSize(), 0; got != want { + t.Errorf("DataSize: got %v, want %v", got, want) + } + + // RST_STREAM frames are always consumed whole. + if err := checkConsume(wr, 0, []FrameWriteRequest{wr}); err != nil { + t.Errorf("Consume:\n%v", err) + } } func TestFrameWriteRequestData(t *testing.T) {