Skip to content

Commit 30b72df

Browse files
Tim Cooperbradfitz
Tim Cooper
authored andcommitted
oauth2: close request body if errors occur before base RoundTripper is invoked
Fixes golang/oauth#269 Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a Reviewed-on: https://go-review.googlesource.com/114956 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent bee4e0a commit 30b72df

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

transport.go

+13
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ type Transport struct {
3434
// access token. If no token exists or token is expired,
3535
// tries to refresh/fetch a new token.
3636
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
37+
reqBodyClosed := false
38+
if req.Body != nil {
39+
defer func() {
40+
if !reqBodyClosed {
41+
req.Body.Close()
42+
}
43+
}()
44+
}
45+
3746
if t.Source == nil {
3847
return nil, errors.New("oauth2: Transport's Source is nil")
3948
}
@@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
4655
token.SetAuthHeader(req2)
4756
t.setModReq(req, req2)
4857
res, err := t.base().RoundTrip(req2)
58+
59+
// req.Body is assumed to have been closed by the base RoundTripper.
60+
reqBodyClosed = true
61+
4962
if err != nil {
5063
t.setModReq(req, nil)
5164
return nil, err

transport_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package oauth2
22

33
import (
4+
"errors"
5+
"io"
46
"net/http"
57
"net/http/httptest"
68
"testing"
@@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) {
2729
}
2830
}
2931

32+
type readCloseCounter struct {
33+
CloseCount int
34+
ReadErr error
35+
}
36+
37+
func (r *readCloseCounter) Read(b []byte) (int, error) {
38+
return 0, r.ReadErr
39+
}
40+
41+
func (r *readCloseCounter) Close() error {
42+
r.CloseCount++
43+
return nil
44+
}
45+
46+
func TestTransportCloseRequestBody(t *testing.T) {
47+
tr := &Transport{}
48+
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
49+
defer server.Close()
50+
client := &http.Client{Transport: tr}
51+
body := &readCloseCounter{
52+
ReadErr: errors.New("readCloseCounter.Read not implemented"),
53+
}
54+
resp, err := client.Post(server.URL, "application/json", body)
55+
if err == nil {
56+
t.Errorf("got no errors, want an error with nil token source")
57+
}
58+
if resp != nil {
59+
t.Errorf("Response = %v; want nil", resp)
60+
}
61+
if expected := 1; body.CloseCount != expected {
62+
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
63+
}
64+
}
65+
66+
func TestTransportCloseRequestBodySuccess(t *testing.T) {
67+
tr := &Transport{
68+
Source: StaticTokenSource(&Token{
69+
AccessToken: "abc",
70+
}),
71+
}
72+
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
73+
defer server.Close()
74+
client := &http.Client{Transport: tr}
75+
body := &readCloseCounter{
76+
ReadErr: io.EOF,
77+
}
78+
resp, err := client.Post(server.URL, "application/json", body)
79+
if err != nil {
80+
t.Errorf("got error %v; expected none", err)
81+
}
82+
if resp == nil {
83+
t.Errorf("Response is nil; expected non-nil")
84+
}
85+
if expected := 1; body.CloseCount != expected {
86+
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
87+
}
88+
}
89+
3090
func TestTransportTokenSource(t *testing.T) {
3191
ts := &tokenSource{
3292
token: &Token{

0 commit comments

Comments
 (0)