From 891ff581e391263d26ec9285bffa4d60cb92fc4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Tue, 11 Oct 2022 17:30:00 +0200 Subject: [PATCH] net/http/httptest: add support for 1XX responses The existing implementation doesn't allow tracing 1xx responses. This patch allow using net/http/httptrace to inspect 1XX responses. Updates #26089. --- src/net/http/httptest/recorder.go | 18 ++++++++ src/net/http/httptest/recorder_test.go | 57 ++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 1c1d8801558ed7..3ad206feb88a12 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptrace" "net/textproto" "strconv" "strings" @@ -42,6 +43,9 @@ type ResponseRecorder struct { // Flushed is whether the Handler called Flush. Flushed bool + // ClientTrace is used to trace 1XX responses + ClientTrace *httptrace.ClientTrace + result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool @@ -146,6 +150,20 @@ func (rw *ResponseRecorder) WriteHeader(code int) { } checkWriteHeaderCode(code) + + if rw.ClientTrace != nil && code >= 100 && code < 200 { + if code == 100 { + rw.ClientTrace.Got100Continue() + } + // treat 101 as a terminal status, see issue 26161 + if code != http.StatusSwitchingProtocols { + if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil { + panic(err) + } + return + } + } + rw.Code = code rw.wroteHeader = true if rw.HeaderMap == nil { diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index 4782eced43e6ce..c3b5cf7aa5bba0 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "net/http" + "net/http/httptrace" + "net/textproto" "testing" ) @@ -369,3 +371,58 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { }) } } + +func TestRecorderClientTrace(t *testing.T) { + handler := func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusContinue) + + rw.Header().Add("Foo", "bar") + rw.WriteHeader(http.StatusEarlyHints) + + rw.Header().Add("Baz", "bat") + } + + var received100, received103 bool + trace := &httptrace.ClientTrace{ + Got100Continue: func() { + received100 = true + }, + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusContinue: + case http.StatusEarlyHints: + received103 = true + if header.Get("Foo") != "bar" { + t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo")) + } + if header.Get("Bar") != "" { + t.Error("Unexpected Bar header") + } + default: + t.Errorf("Unexpected status code %d", code) + } + + return nil + }, + } + + r, _ := http.NewRequest("GET", "http://example.org/", nil) + rw := NewRecorder() + rw.ClientTrace = trace + handler(rw, r) + + if !received100 { + t.Error("Got100Continue not called") + } + if !received103 { + t.Error("103 request not received") + } + + header := rw.Result().Header + if header.Get("Foo") != "bar" { + t.Errorf("Expected Foo=bar, got %s", header.Get("Foo")) + } + if header.Get("Baz") != "bat" { + t.Errorf("Expected Baz=bat, got %s", header.Get("Baz")) + } +}