diff --git a/app.go b/app.go index fcfd485..29eb13e 100644 --- a/app.go +++ b/app.go @@ -15,11 +15,10 @@ type Config struct { } type App struct { - ID string - JWKS jwk.Set - Config *Config - client *ClientWithResponses - jwksCache *jwk.Cache + ID string + Config *Config + client *ClientWithResponses + jwksCacheSet jwk.Set } func New(appID string, config *Config) (*App, error) { @@ -42,15 +41,18 @@ func New(appID string, config *Config) (*App, error) { client: client, } - app.jwksCache = jwk.NewCache(context.Background()) - if err := app.jwksCache.Register(fmt.Sprintf(jwksUrl, appID)); err != nil { + url := fmt.Sprintf(jwksUrl, appID) + cache := jwk.NewCache(context.Background()) + if err := cache.Register(url); err != nil { return nil, err } - if err := app.refreshJWKSCache(); err != nil { - return nil, err + if _, err = cache.Refresh(context.Background(), url); err != nil { + return nil, Error{Message: "failed to fetch jwks"} } + app.jwksCacheSet = jwk.NewCachedSet(cache, url) + return &app, nil } diff --git a/app_test.go b/app_test.go index 64ea5ad..6f94f77 100644 --- a/app_test.go +++ b/app_test.go @@ -40,14 +40,6 @@ func TestGetApp(t *testing.T) { } -func TestAppNewJWKSCache(t *testing.T) { - psg, err := passage.New(PassageAppID, &passage.Config{ - APIKey: PassageApiKey, // An API_KEY environment variable is required for testing. - }) - require.Nil(t, err) - assert.NotNil(t, psg.JWKS) -} - // should be run with the -race flag, i.e. `go test -race -run TestAppJWKSCacheWriteConcurrency` func TestAppJWKSCacheWriteConcurrency(t *testing.T) { goRoutineCount := 2 diff --git a/authentication.go b/authentication.go index f01d6bc..b396a56 100644 --- a/authentication.go +++ b/authentication.go @@ -1,7 +1,6 @@ package passage import ( - "context" "fmt" "net/http" "strings" @@ -41,17 +40,9 @@ func (a *App) getPublicKey(token *jwt.Token) (interface{}, error) { return nil, Error{Message: "expecting JWT header to have string kid"} } - key, ok := a.JWKS.LookupKeyID(keyID) - // if key doesn't exist, re-fetch one more time to see if this jwk was just added + key, ok := a.jwksCacheSet.LookupKeyID(keyID) if !ok { - if err := a.refreshJWKSCache(); err != nil { - return nil, err - } - - key, ok = a.JWKS.LookupKeyID(keyID) - if !ok { - return nil, Error{Message: fmt.Sprintf("unable to find key %q", keyID)} - } + return nil, Error{Message: fmt.Sprintf("unable to find key %q", keyID)} } var pubKey interface{} @@ -60,15 +51,6 @@ func (a *App) getPublicKey(token *jwt.Token) (interface{}, error) { return pubKey, err } -func (a *App) refreshJWKSCache() error { - var err error - if a.JWKS, err = a.jwksCache.Refresh(context.Background(), fmt.Sprintf(jwksUrl, a.ID)); err != nil { - return Error{Message: "failed to fetch jwks"} - } - - return nil -} - // AuthenticateRequestWithCookie fetches a cookie from the request and uses it to authenticate // returns the userID (string) on success, error on failure func (a *App) AuthenticateRequestWithCookie(r *http.Request) (string, error) {