diff --git a/http2/go18.go b/http2/go18.go index 8c0dd2508..633202c39 100644 --- a/http2/go18.go +++ b/http2/go18.go @@ -8,6 +8,7 @@ package http2 import ( "crypto/tls" + "io" "net/http" ) @@ -39,3 +40,11 @@ func configureServer18(h1 *http.Server, h2 *Server) error { func shouldLogPanic(panicValue interface{}) bool { return panicValue != nil && panicValue != http.ErrAbortHandler } + +func reqGetBody(req *http.Request) func() (io.ReadCloser, error) { + return req.GetBody +} + +func reqBodyIsNoBody(body io.ReadCloser) bool { + return body == http.NoBody +} diff --git a/http2/not_go18.go b/http2/not_go18.go index 2e600dc35..efbf83c32 100644 --- a/http2/not_go18.go +++ b/http2/not_go18.go @@ -6,7 +6,10 @@ package http2 -import "net/http" +import ( + "io" + "net/http" +) func configureServer18(h1 *http.Server, h2 *Server) error { // No IdleTimeout to sync prior to Go 1.8. @@ -16,3 +19,9 @@ func configureServer18(h1 *http.Server, h2 *Server) error { func shouldLogPanic(panicValue interface{}) bool { return panicValue != nil } + +func reqGetBody(req *http.Request) func() (io.ReadCloser, error) { + return nil +} + +func reqBodyIsNoBody(io.ReadCloser) bool { return false } diff --git a/http2/transport.go b/http2/transport.go index 8f5f84412..68c325ae3 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -191,6 +191,7 @@ type clientStream struct { ID uint32 resc chan resAndError bufPipe pipe // buffered pipe with the flow-controlled response payload + startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool on100 func() // optional code to run if get a 100 continue response @@ -332,8 +333,10 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } traceGotConn(req, cc) res, err := cc.RoundTrip(req) - if shouldRetryRequest(req, err) { - continue + if err != nil { + if req, err = shouldRetryRequest(req, err); err == nil { + continue + } } if err != nil { t.vlogf("RoundTrip failure: %v", err) @@ -355,12 +358,41 @@ func (t *Transport) CloseIdleConnections() { var ( errClientConnClosed = errors.New("http2: client conn is closed") errClientConnUnusable = errors.New("http2: client conn not usable") + + errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written") ) -func shouldRetryRequest(req *http.Request, err error) bool { - // TODO: retry GET requests (no bodies) more aggressively, if shutdown - // before response. - return err == errClientConnUnusable +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { + switch err { + default: + return nil, err + case errClientConnUnusable, errClientConnGotGoAway: + return req, nil + case errClientConnGotGoAwayAfterSomeReqBody: + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. + if req.Body == nil || reqBodyIsNoBody(req.Body) { + return req, nil + } + // Otherwise we depend on the Request having its GetBody + // func defined. + getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody + if getBody == nil { + return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") + } + body, err := getBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } } func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) { @@ -513,6 +545,15 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { if old != nil && old.ErrCode != ErrCodeNo { cc.goAway.ErrCode = old.ErrCode } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + select { + case cs.resc <- resAndError{err: errClientConnGotGoAway}: + default: + } + } + } } func (cc *ClientConn) CanTakeNewRequest() bool { @@ -773,6 +814,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cs.abortRequestBodyWrite(errStopReqBodyWrite) } if re.err != nil { + if re.err == errClientConnGotGoAway { + cc.mu.Lock() + if cs.startedWrite { + re.err = errClientConnGotGoAwayAfterSomeReqBody + } + cc.mu.Unlock() + } cc.forgetStreamID(cs.ID) return nil, re.err } @@ -2013,6 +2061,9 @@ func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s body resc := make(chan error, 1) s.resc = resc s.fn = func() { + cs.cc.mu.Lock() + cs.startedWrite = true + cs.cc.mu.Unlock() resc <- cs.writeRequestBody(body, cs.req.Body) } s.delay = t.expectContinueTimeout() diff --git a/http2/transport_test.go b/http2/transport_test.go index cd8e7a186..d10b5f75f 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2747,7 +2747,6 @@ func TestTransportCancelDataResponseRace(t *testing.T) { } func TestTransportRetryAfterGOAWAY(t *testing.T) { - t.Skip("to be unskipped by https://go-review.googlesource.com/c/33971/") var dialer struct { sync.Mutex count int @@ -2765,6 +2764,9 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { dialer.Lock() defer dialer.Unlock() dialer.count++ + if dialer.count == 3 { + return nil, errors.New("unexpected number of dials") + } cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { return nil, fmt.Errorf("dial error: %v", err) @@ -2797,10 +2799,20 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { go func() { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := tr.RoundTrip(req) - t.Logf("client got %T, %v", res, err) + if res != nil { + res.Body.Close() + if got := res.Header.Get("Foo"); got != "bar" { + err = fmt.Errorf("foo header = %q; want bar", got) + } + } + if err != nil { + err = fmt.Errorf("RoundTrip: %v", err) + } errs <- err }() + connToClose := make(chan io.Closer, 2) + // Server for the first request. go func() { var ct *clientTester @@ -2810,6 +2822,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { return } + connToClose <- ct.cc ct.greet() hf, err := ct.firstHeaders() if err != nil { @@ -2821,7 +2834,6 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) return } - ct.cc.(*net.TCPConn).Close() errs <- nil }() @@ -2834,17 +2846,19 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { return } + connToClose <- ct.cc ct.greet() hf, err := ct.firstHeaders() if err != nil { errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) return } - t.Logf("server2 Got %v", hf) + t.Logf("server2 got %v", hf) var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) err = ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: hf.StreamID, EndHeaders: true, @@ -2852,7 +2866,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { BlockFragment: buf.Bytes(), }) if err != nil { - errs <- fmt.Errorf("server2 failed writin responseg HEADERS: %v", err) + errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) } else { errs <- nil } @@ -2868,4 +2882,13 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { t.Errorf("timed out") } } + + for { + select { + case c := <-connToClose: + c.Close() + default: + return + } + } }