Skip to content

Commit

Permalink
oauth2: move global auth style cache to be per-Config
Browse files Browse the repository at this point in the history
In 80673b4 (https://go.dev/cl/157820) I added a never-shrinking
package-global cache to remember which auto-detected auth style (HTTP
headers vs POST) was supported by a certain OAuth2 server, keyed by
its URL.

Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and
have one global OpenID configuration document for all of their
customers which says ("we support all auth styles! you pick!") but
then give each customer control of which style they specifically
accept. This is bogus behavior on their part, but the oauth2 package's
global caching per URL isn't helping. (It's also bad to have a
package-global cache that can never be GC'ed)

So, this change moves the cache to hang off the oauth *Configs
instead. Unfortunately, it does so with some backwards compatiblity
compromises (an atomic.Value hack), lest people are using old versions
of Go still or copying a Config by value, both of which this package
previously accidentally supported, even though they weren't tested.

This change also means that anybody that's repeatedly making ephemeral
oauth.Configs without an explicit auth style will be losing &
reinitializing their cache on any auth style failures + fallbacks to
the other style. I think that should be pretty rare. People seem to
make an oauth2.Config once earlier and stash it away somewhere (often
deep in a token fetcher or HTTP client/transport).

Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b
Signed-off-by: Brad Fitzpatrick <[email protected]>
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675
TryBot-Result: Gopher Robot <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
Reviewed-by: Adrian Dewhurst <[email protected]>
Reviewed-by: Michael Knyszek <[email protected]>
  • Loading branch information
bradfitz committed Aug 9, 2023
1 parent 2e4a4e2 commit a835fc4
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 39 deletions.
6 changes: 5 additions & 1 deletion clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ type Config struct {
// client ID & client secret sent. The zero value means to
// auto-detect.
AuthStyle oauth2.AuthStyle

// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}

// Token uses client credentials to retrieve a token.
Expand Down Expand Up @@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v[k] = p
}

tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr)
Expand Down
3 changes: 0 additions & 3 deletions clientcredentials/clientcredentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"net/http/httptest"
"net/url"
"testing"

"golang.org/x/oauth2/internal"
)

func newConf(serverURL string) *Config {
Expand Down Expand Up @@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
}

func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return
Expand Down
70 changes: 45 additions & 25 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -115,41 +116,60 @@ const (
AuthStyleInHeader AuthStyle = 2
)

// authStyleCache is the set of tokenURLs we've successfully used via
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
type LazyAuthStyleCache struct {
v atomic.Value // of *AuthStyleCache
}

func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
return c
}
c := new(AuthStyleCache)
if !lc.v.CompareAndSwap(nil, c) {
c = lc.v.Load().(*AuthStyleCache)
}
return c
}

// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and
// small.
var authStyleCache struct {
sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
}

// ResetAuthCache resets the global authentication style cache used
// for AuthStyleUnknown token requests.
func ResetAuthCache() {
authStyleCache.Lock()
defer authStyleCache.Unlock()
authStyleCache.m = nil
type AuthStyleCache struct {
mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
}

// lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so.
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
style, ok = authStyleCache.m[tokenURL]
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
style, ok = c.m[tokenURL]
return
}

// setAuthStyle adds an entry to authStyleCache, documented above.
func setAuthStyle(tokenURL string, v AuthStyle) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
if authStyleCache.m == nil {
authStyleCache.m = make(map[string]AuthStyle)
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
c.mu.Lock()
defer c.mu.Unlock()
if c.m == nil {
c.m = make(map[string]AuthStyle)
}
authStyleCache.m[tokenURL] = v
c.m[tokenURL] = v
}

// newTokenRequest returns a new *http.Request to retrieve a new token
Expand Down Expand Up @@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
return v2
}

func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe {
if style, ok := lookupAuthStyle(tokenURL); ok {
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
Expand Down Expand Up @@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req)
}
if needsAuthStyleProbe && err == nil {
setAuthStyle(tokenURL, authStyle)
styleCache.setAuthStyle(tokenURL, authStyle)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.
Expand Down
10 changes: 5 additions & 5 deletions internal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

func TestRetrieveToken_InParams(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want {
Expand All @@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
}))
defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil {
t.Errorf("RetrieveToken = %v; want no error", err)
}
}

func TestRetrieveTokenWithContexts(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
}))
defer ts.Close()

_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
}
Expand All @@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
close(retrieved)
if err == nil {
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
Expand Down
4 changes: 4 additions & 0 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type Config struct {

// Scope specifies optional requested permissions.
Scopes []string

// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}

// A TokenSource is anything that can return a token.
Expand Down
4 changes: 0 additions & 4 deletions oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import (
"net/url"
"testing"
"time"

"golang.org/x/oauth2/internal"
)

type mockTransport struct {
Expand Down Expand Up @@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
}

func TestExchangeRequest_NonBasicAuth(t *testing.T) {
internal.ResetAuthCache()
tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization")
Expand Down Expand Up @@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
}

func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return
Expand Down
2 changes: 1 addition & 1 deletion token.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
Expand Down

0 comments on commit a835fc4

Please sign in to comment.