From b668106257cae1a8834f7dec60bf5568b8779aab Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 20 May 2022 11:54:31 +0200 Subject: [PATCH] introduce a http3.RoundTripOpt to prevent closing of request stream (#3411) --- http3/client.go | 18 +++----- http3/client_test.go | 90 ++++++++++++++++++------------------ http3/conn.go | 21 +++++++++ http3/request_writer.go | 10 ++-- http3/request_writer_test.go | 22 ++++++--- http3/roundtrip.go | 15 ++++-- http3/roundtrip_test.go | 2 +- http3/server_test.go | 2 +- 8 files changed, 107 insertions(+), 73 deletions(-) create mode 100644 http3/conn.go diff --git a/http3/client.go b/http3/client.go index 43d65b327e4..0bf0ca43ae7 100644 --- a/http3/client.go +++ b/http3/client.go @@ -236,10 +236,10 @@ func (c *client) maxHeaderBytes() uint64 { return uint64(c.opts.MaxHeaderBytes) } -// RoundTrip executes a request and returns a response -func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { +// RoundTripOpt executes a request and returns a response +func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { - return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host) + return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) } c.dialOnce.Do(func() { @@ -268,7 +268,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } // Request Cancellation: - // This go routine keeps running even after RoundTrip() returns. + // This go routine keeps running even after RoundTripOpt() returns. // It is shut down when the application is done processing the body. reqDone := make(chan struct{}) go func() { @@ -280,7 +280,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } }() - rsp, rerr := c.doRequest(req, str, reqDone) + rsp, rerr := c.doRequest(req, str, opt, reqDone) if rerr.err != nil { // if any error occurred close(reqDone) if rerr.streamErr != 0 { // if it was a stream error @@ -297,16 +297,12 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return rsp, rerr.err } -func (c *client) doRequest( - req *http.Request, - str quic.Stream, - reqDone chan struct{}, -) (*http.Response, requestError) { +func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true } - if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil { + if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil { return nil, newStreamError(errorInternalError, err) } diff --git a/http3/client_test.go b/http3/client_test.go index b13993f5697..3cdf3a881d2 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -72,7 +72,7 @@ var _ = Describe("Client", func() { dialAddrCalled = true return nil, errors.New("test done") } - client.RoundTrip(req) + client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) }) @@ -87,7 +87,7 @@ var _ = Describe("Client", func() { } req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) Expect(err).ToNot(HaveOccurred()) - client.RoundTrip(req) + client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) }) @@ -108,7 +108,7 @@ var _ = Describe("Client", func() { dialAddrCalled = true return nil, errors.New("test done") } - client.RoundTrip(req) + client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) // make sure the original tls.Config was not modified Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) @@ -131,7 +131,7 @@ var _ = Describe("Client", func() { } client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req.WithContext(ctx)) + _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) @@ -144,7 +144,7 @@ var _ = Describe("Client", func() { Expect(quicConf.EnableDatagrams).To(BeTrue()) return nil, testErr } - _, err = client.RoundTrip(req) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) @@ -155,7 +155,7 @@ var _ = Describe("Client", func() { dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } - _, err = client.RoundTrip(req) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) @@ -169,8 +169,8 @@ var _ = Describe("Client", func() { It("refuses to do requests for the wrong host", func() { req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) - Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) }) It("allows requests using a different scheme", func() { @@ -180,14 +180,14 @@ var _ = Describe("Client", func() { dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } - _, err = client.RoundTrip(req) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) }) Context("hijacking unidirectional streams", func() { var ( - request *http.Request + req *http.Request conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) @@ -209,7 +209,7 @@ var _ = Describe("Client", func() { return conn, nil } var err error - request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) @@ -236,7 +236,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -261,7 +261,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -270,7 +270,7 @@ var _ = Describe("Client", func() { Context("control stream handling", func() { var ( - request *http.Request + req *http.Request conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) @@ -291,7 +291,7 @@ var _ = Describe("Client", func() { return conn, nil } var err error - request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) @@ -313,7 +313,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -338,7 +338,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead }) @@ -361,7 +361,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -385,7 +385,7 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorMissingSettings)) close(done) }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -411,7 +411,7 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorFrameError)) close(done) }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -434,7 +434,7 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorIDError)) close(done) }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -461,7 +461,7 @@ var _ = Describe("Client", func() { Expect(reason).To(Equal("missing QUIC Datagram support")) close(done) }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -469,7 +469,7 @@ var _ = Describe("Client", func() { Context("Doing requests", func() { var ( - request *http.Request + req *http.Request str *mockquic.MockStream conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} @@ -540,7 +540,7 @@ var _ = Describe("Client", func() { return conn, nil } var err error - request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) @@ -554,13 +554,13 @@ var _ = Describe("Client", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) It("performs a 0-RTT request", func() { testErr := errors.New("stream open error") - request.Method = MethodGet0RTT + req.Method = MethodGet0RTT // don't EXPECT any calls to HandshakeComplete() conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} @@ -570,7 +570,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) }) @@ -585,7 +585,7 @@ var _ = Describe("Client", func() { str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTrip(request) + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3")) Expect(rsp.ProtoMajor).To(Equal(3)) @@ -604,7 +604,7 @@ var _ = Describe("Client", func() { body := &mockBody{} body.SetData([]byte("request body")) var err error - request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) + req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) Expect(err).ToNot(HaveOccurred()) str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() }) @@ -620,7 +620,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(strBuf) Expect(hfs).To(HaveKeyWithValue(":method", "POST")) @@ -628,7 +628,7 @@ var _ = Describe("Client", func() { }) It("returns the error that occurred when reading the body", func() { - request.Body.(*mockBody).readErr = errors.New("testErr") + req.Body.(*mockBody).readErr = errors.New("testErr") done := make(chan struct{}) gomock.InOrder( str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { @@ -642,7 +642,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("test done") }) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) }) @@ -660,7 +660,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors // the response body is sent asynchronously, while already reading the response str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - req, err := client.RoundTrip(request) + req, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) Expect(req.ContentLength).To(BeEquivalentTo(1337)) Eventually(done).Should(BeClosed()) @@ -673,7 +673,7 @@ var _ = Describe("Client", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) Eventually(closed).Should(BeClosed()) }) @@ -685,7 +685,7 @@ var _ = Describe("Client", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) }) @@ -694,12 +694,12 @@ var _ = Describe("Client", func() { Context("request cancellations", func() { It("cancels a request while waiting for the handshake to complete", func() { ctx, cancel := context.WithCancel(context.Background()) - req := request.WithContext(ctx) + req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(context.Background()) errChan := make(chan error) go func() { - _, err := client.RoundTrip(req) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) errChan <- err }() Consistently(errChan).ShouldNot(Receive()) @@ -709,7 +709,7 @@ var _ = Describe("Client", func() { It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) - req := request.WithContext(ctx) + req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(handshakeCtx) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} @@ -729,7 +729,7 @@ var _ = Describe("Client", func() { <-canceled return 0, errors.New("test done") }) - _, err := client.RoundTrip(req) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) Eventually(done).Should(BeClosed()) }) @@ -738,7 +738,7 @@ var _ = Describe("Client", func() { rspBuf := bytes.NewBuffer(getResponse(404)) ctx, cancel := context.WithCancel(context.Background()) - req := request.WithContext(ctx) + req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(handshakeCtx) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) @@ -750,7 +750,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) - _, err := client.RoundTrip(req) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(done).Should(BeClosed()) @@ -771,7 +771,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err := client.RoundTrip(request) + _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) @@ -788,7 +788,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err = client.RoundTrip(request) + _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).ToNot(HaveKey("accept-encoding")) @@ -810,7 +810,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := client.RoundTrip(request) + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) @@ -833,7 +833,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := client.RoundTrip(request) + rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/conn.go b/http3/conn.go new file mode 100644 index 00000000000..c4ec749e6d9 --- /dev/null +++ b/http3/conn.go @@ -0,0 +1,21 @@ +package http3 + +import "github.com/lucas-clemente/quic-go" + +type ConnState struct { + SupportsDatagram bool +} + +type Conn struct { + conn quic.Connection + + supportsDatagram bool +} + +func (c *Conn) State() ConnState { + return ConnState{SupportsDatagram: c.supportsDatagram} +} + +func (c *Conn) SendDatagram(b []byte) error { + return c.conn.SendMessage(b) +} diff --git a/http3/request_writer.go b/http3/request_writer.go index aebb640b17e..bde141bee91 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -38,7 +38,7 @@ func newRequestWriter(logger utils.Logger) *requestWriter { } } -func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error { +func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error { buf := &bytes.Buffer{} if err := w.writeHeaders(buf, req, gzip); err != nil { return err @@ -48,7 +48,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo } // TODO: add support for trailers if req.Body == nil { - str.Close() + if !dontCloseStr { + str.Close() + } return nil } @@ -84,7 +86,9 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo return } } - str.Close() + if !dontCloseStr { + str.Close() + } }() return nil diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 9a1e718e289..e2c80cdc1d9 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -60,7 +60,7 @@ var _ = Describe("Request Writer", func() { str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) @@ -69,13 +69,21 @@ var _ = Describe("Request Writer", func() { Expect(headerFields).ToNot(HaveKey("accept-encoding")) }) + It("writes a GET request without closing the stream", func() { + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequest(str, req, true, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + }) + It("writes a POST request", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) postData := bytes.NewReader([]byte("foobar")) req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) Eventually(closed).Should(BeClosed()) headerFields := decode(strBuf) @@ -96,7 +104,7 @@ var _ = Describe("Request Writer", func() { str.EXPECT().Close().Do(func() { close(closed) }) req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{}) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) Eventually(closed).Should(BeClosed()) headerFields := decode(strBuf) @@ -122,7 +130,7 @@ var _ = Describe("Request Writer", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) }) @@ -131,7 +139,7 @@ var _ = Describe("Request Writer", func() { str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, true)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, true)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) @@ -140,7 +148,7 @@ var _ = Describe("Request Writer", func() { str.EXPECT().Close() req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) @@ -154,7 +162,7 @@ var _ = Describe("Request Writer", func() { req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) Expect(err).ToNot(HaveOccurred()) req.Proto = "webtransport" - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 6ba251cbb43..f5fe2ab9451 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -16,7 +16,7 @@ import ( ) type roundTripCloser interface { - http.RoundTripper + RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) io.Closer } @@ -77,11 +77,16 @@ type RoundTripper struct { // RoundTripOpt are options for the Transport.RoundTripOpt method. type RoundTripOpt struct { // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. - // If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. + DontCloseRequestStream bool } -var _ roundTripCloser = &RoundTripper{} +var ( + _ http.RoundTripper = &RoundTripper{} + _ io.Closer = &RoundTripper{} +) // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set var ErrNoCachedConn = errors.New("http3: no cached connection was available") @@ -127,7 +132,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } - return cl.RoundTrip(req) + return cl.RoundTripOpt(req, opt) } // RoundTrip does a round trip. @@ -135,7 +140,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { r.mutex.Lock() defer r.mutex.Unlock() diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index a17cf4db087..4596975b659 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -20,7 +20,7 @@ type mockClient struct { closed bool } -func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) { +func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) { return &http.Response{Request: req}, nil } diff --git a/http3/server_test.go b/http3/server_test.go index d3a360afd32..b5e23f775aa 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -138,7 +138,7 @@ var _ = Describe("Server", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) Eventually(closed).Should(BeClosed()) return buf.Bytes() }