Skip to content

Commit

Permalink
new API
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas committed Nov 23, 2022
1 parent 6b9da27 commit 6eb176a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 75 deletions.
41 changes: 23 additions & 18 deletions src/net/http/httptest/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,24 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"strconv"
"strings"

"golang.org/x/net/http/httpguts"
)

// InformationalResponse is an HTTP response sent with a [1xx status code].
//
// [1xx status code]: https://httpwg.org/specs/rfc9110.html#status.1xx
type InformationalResponse struct {
// Code is the 1xx HTTP response code of this informational response.
Code int

// Header contains the headers of this informational response.
Header http.Header
}

// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
type ResponseRecorder struct {
Expand All @@ -28,6 +38,9 @@ type ResponseRecorder struct {
// method.
Code int

// Informational HTTP responses (1xx status code) send before the main response.
InformationalResponses []InformationalResponse

// HeaderMap contains the headers explicitly set by the Handler.
// It is an internal detail.
//
Expand All @@ -43,9 +56,6 @@ type ResponseRecorder struct {
// Flushed is whether the Handler called Flush.
Flushed bool

// ClientTrace is used to trace 1XX responses
ClientTrace *httptrace.ClientTrace

result *http.Response // cache of Result's return value
snapHeader http.Header // snapshot of HeaderMap at first Write
wroteHeader bool
Expand Down Expand Up @@ -151,24 +161,19 @@ func (rw *ResponseRecorder) WriteHeader(code int) {

checkWriteHeaderCode(code)

if rw.ClientTrace != nil && code >= 100 && code < 200 {
if code == 100 {
rw.ClientTrace.Got100Continue()
}
// treat 101 as a terminal status, see issue 26161
if code != http.StatusSwitchingProtocols {
if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil {
panic(err)
}
return
}
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}

if code >= 100 && code < 200 {
ir := InformationalResponse{code, rw.HeaderMap.Clone()}
rw.InformationalResponses = append(rw.InformationalResponses, ir)

return
}

rw.Code = code
rw.wroteHeader = true
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
rw.snapHeader = rw.HeaderMap.Clone()
}

Expand Down
87 changes: 30 additions & 57 deletions src/net/http/httptest/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"reflect"
"testing"
)

Expand Down Expand Up @@ -125,6 +124,15 @@ func TestRecorder(t *testing.T) {
return nil
}
}
hasInformationalResponses := func(ir []InformationalResponse) checkFunc {
return func(rec *ResponseRecorder) error {
if !reflect.DeepEqual(ir, rec.InformationalResponses) {
return fmt.Errorf("InformationalResponses = %v; want %v", rec.InformationalResponses, ir)
}

return nil
}
}

for _, tt := range [...]struct {
name string
Expand Down Expand Up @@ -296,6 +304,26 @@ func TestRecorder(t *testing.T) {
check(hasResultContents("")), // check we don't crash reading the body

},
{
"1xx status code",
func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusContinue)
rw.Header().Add("Foo", "bar")

rw.WriteHeader(http.StatusEarlyHints)
rw.Header().Add("Baz", "bat")

rw.Header().Del("Foo")
},
check(
hasInformationalResponses([]InformationalResponse{
InformationalResponse{100, http.Header{}},
InformationalResponse{103, http.Header{"Foo": []string{"bar"}}},
}),
hasHeader("Baz", "bat"),
hasNotHeaders("Foo"),
),
},
} {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
Expand Down Expand Up @@ -371,58 +399,3 @@ func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
})
}
}

func TestRecorderClientTrace(t *testing.T) {
handler := func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusContinue)

rw.Header().Add("Foo", "bar")
rw.WriteHeader(http.StatusEarlyHints)

rw.Header().Add("Baz", "bat")
}

var received100, received103 bool
trace := &httptrace.ClientTrace{
Got100Continue: func() {
received100 = true
},
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
switch code {
case http.StatusContinue:
case http.StatusEarlyHints:
received103 = true
if header.Get("Foo") != "bar" {
t.Errorf(`Expected Foo=bar, got %s`, header.Get("Foo"))
}
if header.Get("Bar") != "" {
t.Error("Unexpected Bar header")
}
default:
t.Errorf("Unexpected status code %d", code)
}

return nil
},
}

r, _ := http.NewRequest("GET", "http://example.org/", nil)
rw := NewRecorder()
rw.ClientTrace = trace
handler(rw, r)

if !received100 {
t.Error("Got100Continue not called")
}
if !received103 {
t.Error("103 request not received")
}

header := rw.Result().Header
if header.Get("Foo") != "bar" {
t.Errorf("Expected Foo=bar, got %s", header.Get("Foo"))
}
if header.Get("Baz") != "bat" {
t.Errorf("Expected Baz=bat, got %s", header.Get("Baz"))
}
}

0 comments on commit 6eb176a

Please sign in to comment.