Skip to content

Commit 80673b4

Browse files
committed
oauth2: auto-detect auth style by default, add Endpoint.AuthStyle
Instead of maintaining a global map of which OAuth2 servers do which auth style and/or requiring the user to tell us, just try both ways and remember which way worked. But if users want to tell us in the Endpoint, this CL also add Endpoint.AuthStyle. Fixes golang#111 Fixes golang#365 Fixes golang#362 Fixes golang#357 Fixes golang#353 Fixes golang#345 Fixes golang#326 Fixes golang#352 Fixes golang#268 Fixes https://go-review.googlesource.com/c/oauth2/+/58510 (... and surely many more ...) Change-Id: I7b4d98ba1900ee2d3e11e629316b0bf867f7d237 Reviewed-on: https://go-review.googlesource.com/c/157820 Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Ross Light <[email protected]>
1 parent 99b60b7 commit 80673b4

File tree

9 files changed

+200
-183
lines changed

9 files changed

+200
-183
lines changed

clientcredentials/clientcredentials.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ type Config struct {
4242

4343
// EndpointParams specifies additional parameters for requests to the token endpoint.
4444
EndpointParams url.Values
45+
46+
// AuthStyle optionally specifies how the endpoint wants the
47+
// client ID & client secret sent. The zero value means to
48+
// auto-detect.
49+
AuthStyle oauth2.AuthStyle
4550
}
4651

4752
// Token uses client credentials to retrieve a token.
@@ -97,7 +102,8 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
97102
}
98103
v[k] = p
99104
}
100-
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v)
105+
106+
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
101107
if err != nil {
102108
if rErr, ok := err.(*internal.RetrieveError); ok {
103109
return nil, (*oauth2.RetrieveError)(rErr)

clientcredentials/clientcredentials_test.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ package clientcredentials
66

77
import (
88
"context"
9+
"io"
910
"io/ioutil"
1011
"net/http"
1112
"net/http/httptest"
1213
"net/url"
1314
"testing"
15+
16+
"golang.org/x/oauth2/internal"
1417
)
1518

1619
func newConf(serverURL string) *Config {
@@ -111,21 +114,25 @@ func TestTokenRequest(t *testing.T) {
111114
}
112115

113116
func TestTokenRefreshRequest(t *testing.T) {
117+
internal.ResetAuthCache()
114118
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115119
if r.URL.String() == "/somethingelse" {
116120
return
117121
}
118122
if r.URL.String() != "/token" {
119-
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
123+
t.Errorf("Unexpected token refresh request URL: %q", r.URL)
120124
}
121125
headerContentType := r.Header.Get("Content-Type")
122-
if headerContentType != "application/x-www-form-urlencoded" {
123-
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
126+
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
127+
t.Errorf("Content-Type = %q; want %q", got, want)
124128
}
125129
body, _ := ioutil.ReadAll(r.Body)
126-
if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" {
127-
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
130+
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
131+
if string(body) != want {
132+
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)
128133
}
134+
w.Header().Set("Content-Type", "application/json")
135+
io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
129136
}))
130137
defer ts.Close()
131138
conf := newConf(ts.URL)

google/google.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ import (
1919

2020
// Endpoint is Google's OAuth 2.0 endpoint.
2121
var Endpoint = oauth2.Endpoint{
22-
AuthURL: "https://accounts.google.com/o/oauth2/auth",
23-
TokenURL: "https://accounts.google.com/o/oauth2/token",
22+
AuthURL: "https://accounts.google.com/o/oauth2/auth",
23+
TokenURL: "https://accounts.google.com/o/oauth2/token",
24+
AuthStyle: oauth2.AuthStyleInParams,
2425
}
2526

2627
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.

internal/token.go

+115-95
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/url"
1717
"strconv"
1818
"strings"
19+
"sync"
1920
"time"
2021

2122
"golang.org/x/net/context/ctxhttp"
@@ -90,102 +91,71 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
9091
return nil
9192
}
9293

93-
var brokenAuthHeaderProviders = []string{
94-
"https://accounts.google.com/",
95-
"https://api.codeswholesale.com/oauth/token",
96-
"https://api.dropbox.com/",
97-
"https://api.dropboxapi.com/",
98-
"https://api.instagram.com/",
99-
"https://api.netatmo.net/",
100-
"https://api.odnoklassniki.ru/",
101-
"https://api.pushbullet.com/",
102-
"https://api.soundcloud.com/",
103-
"https://api.twitch.tv/",
104-
"https://id.twitch.tv/",
105-
"https://app.box.com/",
106-
"https://api.box.com/",
107-
"https://connect.stripe.com/",
108-
"https://login.mailchimp.com/",
109-
"https://login.microsoftonline.com/",
110-
"https://login.salesforce.com/",
111-
"https://login.windows.net",
112-
"https://login.live.com/",
113-
"https://login.live-int.com/",
114-
"https://oauth.sandbox.trainingpeaks.com/",
115-
"https://oauth.trainingpeaks.com/",
116-
"https://oauth.vk.com/",
117-
"https://openapi.baidu.com/",
118-
"https://slack.com/",
119-
"https://test-sandbox.auth.corp.google.com",
120-
"https://test.salesforce.com/",
121-
"https://user.gini.net/",
122-
"https://www.douban.com/",
123-
"https://www.googleapis.com/",
124-
"https://www.linkedin.com/",
125-
"https://www.strava.com/oauth/",
126-
"https://www.wunderlist.com/oauth/",
127-
"https://api.patreon.com/",
128-
"https://sandbox.codeswholesale.com/oauth/token",
129-
"https://api.sipgate.com/v1/authorization/oauth",
130-
"https://api.medium.com/v1/tokens",
131-
"https://log.finalsurge.com/oauth/token",
132-
"https://multisport.todaysplan.com.au/rest/oauth/access_token",
133-
"https://whats.todaysplan.com.au/rest/oauth/access_token",
134-
"https://stackoverflow.com/oauth/access_token",
135-
"https://account.health.nokia.com",
136-
"https://accounts.zoho.com",
137-
"https://gitter.im/login/oauth/token",
138-
"https://openid-connect.onelogin.com/oidc",
139-
"https://api.dailymotion.com/oauth/token",
94+
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
95+
//
96+
// Deprecated: this function no longer does anything. Caller code that
97+
// wants to avoid potential extra HTTP requests made during
98+
// auto-probing of the provider's auth style should set
99+
// Endpoint.AuthStyle.
100+
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
101+
102+
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
103+
type AuthStyle int
104+
105+
const (
106+
AuthStyleUnknown AuthStyle = 0
107+
AuthStyleInParams AuthStyle = 1
108+
AuthStyleInHeader AuthStyle = 2
109+
)
110+
111+
// authStyleCache is the set of tokenURLs we've successfully used via
112+
// RetrieveToken and which style auth we ended up using.
113+
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
114+
// the set of OAuth2 servers a program contacts over time is fixed and
115+
// small.
116+
var authStyleCache struct {
117+
sync.Mutex
118+
m map[string]AuthStyle // keyed by tokenURL
140119
}
141120

142-
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
143-
var brokenAuthHeaderDomains = []string{
144-
".auth0.com",
145-
".force.com",
146-
".myshopify.com",
147-
".okta.com",
148-
".oktapreview.com",
121+
// ResetAuthCache resets the global authentication style cache used
122+
// for AuthStyleUnknown token requests.
123+
func ResetAuthCache() {
124+
authStyleCache.Lock()
125+
defer authStyleCache.Unlock()
126+
authStyleCache.m = nil
149127
}
150128

151-
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
152-
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
129+
// lookupAuthStyle reports which auth style we last used with tokenURL
130+
// when calling RetrieveToken and whether we have ever done so.
131+
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
132+
authStyleCache.Lock()
133+
defer authStyleCache.Unlock()
134+
style, ok = authStyleCache.m[tokenURL]
135+
return
153136
}
154137

155-
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
156-
// implements the OAuth2 spec correctly
157-
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
158-
// In summary:
159-
// - Reddit only accepts client secret in the Authorization header
160-
// - Dropbox accepts either it in URL param or Auth header, but not both.
161-
// - Google only accepts URL param (not spec compliant?), not Auth header
162-
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
163-
func providerAuthHeaderWorks(tokenURL string) bool {
164-
for _, s := range brokenAuthHeaderProviders {
165-
if strings.HasPrefix(tokenURL, s) {
166-
// Some sites fail to implement the OAuth2 spec fully.
167-
return false
168-
}
138+
// setAuthStyle adds an entry to authStyleCache, documented above.
139+
func setAuthStyle(tokenURL string, v AuthStyle) {
140+
authStyleCache.Lock()
141+
defer authStyleCache.Unlock()
142+
if authStyleCache.m == nil {
143+
authStyleCache.m = make(map[string]AuthStyle)
169144
}
170-
171-
if u, err := url.Parse(tokenURL); err == nil {
172-
for _, s := range brokenAuthHeaderDomains {
173-
if strings.HasSuffix(u.Host, s) {
174-
return false
175-
}
176-
}
177-
}
178-
179-
// Assume the provider implements the spec properly
180-
// otherwise. We can add more exceptions as they're
181-
// discovered. We will _not_ be adding configurable hooks
182-
// to this package to let users select server bugs.
183-
return true
145+
authStyleCache.m[tokenURL] = v
184146
}
185147

186-
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
187-
bustedAuth := !providerAuthHeaderWorks(tokenURL)
188-
if bustedAuth {
148+
// newTokenRequest returns a new *http.Request to retrieve a new token
149+
// from tokenURL using the provided clientID, clientSecret, and POST
150+
// body parameters.
151+
//
152+
// inParams is whether the clientID & clientSecret should be encoded
153+
// as the POST body. An 'inParams' value of true means to send it in
154+
// the POST body (along with any values in v); false means to send it
155+
// in the Authorization header.
156+
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
157+
if authStyle == AuthStyleInParams {
158+
v = cloneURLValues(v)
189159
if clientID != "" {
190160
v.Set("client_id", clientID)
191161
}
@@ -198,15 +168,70 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
198168
return nil, err
199169
}
200170
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
201-
if !bustedAuth {
171+
if authStyle == AuthStyleInHeader {
202172
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
203173
}
174+
return req, nil
175+
}
176+
177+
func cloneURLValues(v url.Values) url.Values {
178+
v2 := make(url.Values, len(v))
179+
for k, vv := range v {
180+
v2[k] = append([]string(nil), vv...)
181+
}
182+
return v2
183+
}
184+
185+
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
186+
needsAuthStyleProbe := authStyle == 0
187+
if needsAuthStyleProbe {
188+
if style, ok := lookupAuthStyle(tokenURL); ok {
189+
authStyle = style
190+
needsAuthStyleProbe = false
191+
} else {
192+
authStyle = AuthStyleInHeader // the first way we'll try
193+
}
194+
}
195+
req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
196+
if err != nil {
197+
return nil, err
198+
}
199+
token, err := doTokenRoundTrip(ctx, req)
200+
if err != nil && needsAuthStyleProbe {
201+
// If we get an error, assume the server wants the
202+
// clientID & clientSecret in a different form.
203+
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
204+
// In summary:
205+
// - Reddit only accepts client secret in the Authorization header
206+
// - Dropbox accepts either it in URL param or Auth header, but not both.
207+
// - Google only accepts URL param (not spec compliant?), not Auth header
208+
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
209+
//
210+
// We used to maintain a big table in this code of all the sites and which way
211+
// they went, but maintaining it didn't scale & got annoying.
212+
// So just try both ways.
213+
authStyle = AuthStyleInParams // the second way we'll try
214+
req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
215+
token, err = doTokenRoundTrip(ctx, req)
216+
}
217+
if needsAuthStyleProbe && err == nil {
218+
setAuthStyle(tokenURL, authStyle)
219+
}
220+
// Don't overwrite `RefreshToken` with an empty value
221+
// if this was a token refreshing request.
222+
if token != nil && token.RefreshToken == "" {
223+
token.RefreshToken = v.Get("refresh_token")
224+
}
225+
return token, err
226+
}
227+
228+
func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
204229
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
205230
if err != nil {
206231
return nil, err
207232
}
208-
defer r.Body.Close()
209233
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
234+
r.Body.Close()
210235
if err != nil {
211236
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
212237
}
@@ -256,13 +281,8 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
256281
}
257282
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
258283
}
259-
// Don't overwrite `RefreshToken` with an empty value
260-
// if this was a token refreshing request.
261-
if token.RefreshToken == "" {
262-
token.RefreshToken = v.Get("refresh_token")
263-
}
264284
if token.AccessToken == "" {
265-
return token, errors.New("oauth2: server response missing access_token")
285+
return nil, errors.New("oauth2: server response missing access_token")
266286
}
267287
return token, nil
268288
}

0 commit comments

Comments
 (0)