diff --git a/docs/operator-manual/user-management/keycloak.md b/docs/operator-manual/user-management/keycloak.md index 02b0049c0b50a..b83613731a0c1 100644 --- a/docs/operator-manual/user-management/keycloak.md +++ b/docs/operator-manual/user-management/keycloak.md @@ -67,6 +67,7 @@ data: issuer: https://keycloak.example.com/realms/master clientID: argocd clientSecret: $oidc.keycloak.clientSecret + refreshTokenThreshold: 2m requestedScopes: ["openid", "profile", "email", "groups"] ``` @@ -77,6 +78,7 @@ Make sure that: - __clientID__ is set to the Client ID you configured in Keycloak - __clientSecret__ points to the right key you created in the _argocd-secret_ Secret - __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes +- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default. ## Keycloak and ArgoCD with PKCE @@ -135,6 +137,7 @@ data: issuer: https://keycloak.example.com/realms/master clientID: argocd enablePKCEAuthentication: true + refreshTokenThreshold: 2m requestedScopes: ["openid", "profile", "email", "groups"] ``` @@ -145,6 +148,7 @@ Make sure that: - __clientID__ is set to the Client ID you configured in Keycloak - __enablePKCEAuthentication__ must be set to true to enable correct ArgoCD behaviour with PKCE - __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes +- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default. ## Configuring the groups claim diff --git a/server/application/websocket.go b/server/application/websocket.go index 1afc5575f7258..3cfdc729b2c11 100644 --- a/server/application/websocket.go +++ b/server/application/websocket.go @@ -162,7 +162,7 @@ func (t *terminalSession) performValidationsAndReconnect(p []byte) (int, error) } // check if token still valid - _, newToken, err := t.sessionManager.VerifyToken(*t.token) + _, newToken, err := t.sessionManager.VerifyToken(t.ctx, *t.token) // err in case if token is revoked, newToken in case if refresh happened if err != nil || newToken != "" { // need to send reconnect code in case if token was refreshed diff --git a/server/logout/logout.go b/server/logout/logout.go index 9db19f57ae58a..75f017835cb53 100644 --- a/server/logout/logout.go +++ b/server/logout/logout.go @@ -31,7 +31,7 @@ func NewHandler(settingsMrg *settings.SettingsManager, sessionMgr *session.Sessi type Handler struct { settingsMgr *settings.SettingsManager rootPath string - verifyToken func(tokenString string) (jwt.Claims, string, error) + verifyToken func(ctx context.Context, tokenString string) (jwt.Claims, string, error) revokeToken func(ctx context.Context, id string, expiringAt time.Duration) error baseHRef string } @@ -94,7 +94,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("Set-Cookie", argocdCookie.String()) } - claims, _, err := h.verifyToken(tokenString) + claims, _, err := h.verifyToken(r.Context(), tokenString) if err != nil { http.Redirect(w, r, logoutRedirectURL, http.StatusSeeOther) return diff --git a/server/logout/logout_test.go b/server/logout/logout_test.go index e007d3a054204..42f59f8ca121e 100644 --- a/server/logout/logout_test.go +++ b/server/logout/logout_test.go @@ -1,6 +1,7 @@ package logout import ( + "context" "errors" "net/http" "net/http/httptest" @@ -245,28 +246,28 @@ func TestHandlerConstructLogoutURL(t *testing.T) { sessionManager := session.NewSessionManager(settingsManagerWithOIDCConfig, test.NewFakeProjLister(), "", nil, session.NewUserStateStorage(nil)) oidcHandler := NewHandler(settingsManagerWithOIDCConfig, sessionManager, rootPath, baseHRef) - oidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) { + oidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) { if !validJWTPattern.MatchString(tokenString) { return nil, "", errors.New("invalid jwt") } return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil } nonoidcHandler := NewHandler(settingsManagerWithoutOIDCConfig, sessionManager, "", baseHRef) - nonoidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) { + nonoidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) { if !validJWTPattern.MatchString(tokenString) { return nil, "", errors.New("invalid jwt") } return &jwt.RegisteredClaims{Issuer: session.SessionManagerClaimsIssuer}, "", nil } oidcHandlerWithoutLogoutURL := NewHandler(settingsManagerWithOIDCConfigButNoLogoutURL, sessionManager, "", baseHRef) - oidcHandlerWithoutLogoutURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) { + oidcHandlerWithoutLogoutURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) { if !validJWTPattern.MatchString(tokenString) { return nil, "", errors.New("invalid jwt") } return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil } nonoidcHandlerWithMultipleURLs := NewHandler(settingsManagerWithoutOIDCAndMultipleURLs, sessionManager, "", baseHRef) - nonoidcHandlerWithMultipleURLs.verifyToken = func(tokenString string) (jwt.Claims, string, error) { + nonoidcHandlerWithMultipleURLs.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) { if !validJWTPattern.MatchString(tokenString) { return nil, "", errors.New("invalid jwt") } @@ -274,7 +275,7 @@ func TestHandlerConstructLogoutURL(t *testing.T) { } oidcHandlerWithoutBaseURL := NewHandler(settingsManagerWithOIDCConfigButNoURL, sessionManager, "argocd", baseHRef) - oidcHandlerWithoutBaseURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) { + oidcHandlerWithoutBaseURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) { if !validJWTPattern.MatchString(tokenString) { return nil, "", errors.New("invalid jwt") } diff --git a/server/server.go b/server/server.go index 08379aa449ec7..5f02f85538822 100644 --- a/server/server.go +++ b/server/server.go @@ -323,6 +323,8 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio appsetLister := appFactory.Argoproj().V1alpha1().ApplicationSets().Lister() userStateStorage := util_session.NewUserStateStorage(opts.RedisClient) + ssoClientApp, err := oidc.NewClientApp(settings, opts.DexServerAddr, opts.DexTLSConfig, opts.BaseHRef, cacheutil.NewRedisCache(opts.RedisClient, settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) + errorsutil.CheckError(err) sessionMgr := util_session.NewSessionManager(settingsMgr, projLister, opts.DexServerAddr, opts.DexTLSConfig, userStateStorage) enf := rbac.NewEnforcer(opts.KubeClientset, opts.Namespace, common.ArgoCDRBACConfigMapName, nil) enf.EnableEnforce(!opts.DisableAuth) @@ -370,6 +372,7 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio a := &ArgoCDServer{ ArgoCDServerOpts: opts, ApplicationSetOpts: appsetOpts, + ssoClientApp: ssoClientApp, log: logger, settings: settings, sessionMgr: sessionMgr, @@ -1125,19 +1128,7 @@ func (server *ArgoCDServer) translateGrpcCookieHeader(ctx context.Context, w htt } func (server *ArgoCDServer) setTokenCookie(token string, w http.ResponseWriter) error { - cookiePath := "path=/" + strings.TrimRight(strings.TrimLeft(server.BaseHRef, "/"), "/") - flags := []string{cookiePath, "SameSite=lax", "httpOnly"} - if !server.Insecure { - flags = append(flags, "Secure") - } - cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, token, flags...) - if err != nil { - return fmt.Errorf("error creating cookie metadata: %w", err) - } - for _, cookie := range cookies { - w.Header().Add("Set-Cookie", cookie) - } - return nil + return httputil.SetTokenCookie(token, server.BaseHRef, !server.Insecure, w) } func withRootPath(handler http.Handler, a *ArgoCDServer) http.Handler { @@ -1221,9 +1212,6 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf} - // SSO ClientApp - server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) - terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts). WithFeatureFlagMiddleware(server.settingsMgr.GetSettings) th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal) @@ -1368,9 +1356,7 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { return } // Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex) - var err error mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig)) - errorsutil.CheckError(err) mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin) mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback) } @@ -1566,6 +1552,7 @@ func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context, return ctx, nil } +// getClaims extracts, validates and refreshes a JWT token from an incoming request context. func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { @@ -1575,17 +1562,29 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, if tokenString == "" { return nil, "", ErrNoSession } - claims, newToken, err := server.sessionMgr.VerifyToken(tokenString) + // A valid argocd-issued token is automatically refreshed here prior to expiration. + // OIDC tokens will be verified but will not be refreshed here. + claims, newToken, err := server.sessionMgr.VerifyToken(ctx, tokenString) if err != nil { return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } finalClaims := claims if server.settings.IsSSOConfigured() { - finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer) + updatedClaims, err := server.ssoClientApp.SetGroupsFromUserInfo(ctx, claims, util_session.SessionManagerClaimsIssuer) if err != nil { return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } + finalClaims = updatedClaims + // OIDC tokens are automatically refreshed here prior to expiration + refreshedToken, err := server.ssoClientApp.CheckAndRefreshToken(ctx, updatedClaims, server.settings.OIDCRefreshTokenThreshold) + if err != nil { + log.Errorf("error checking and refreshing token: %v", err) + } + if refreshedToken != "" && refreshedToken != tokenString { + newToken = refreshedToken + log.Infof("refreshed token for subject: %v", jwtutil.StringField(updatedClaims, "sub")) + } } return finalClaims, newToken, nil diff --git a/util/http/http.go b/util/http/http.go index ceafca09f011d..3adcbe52a49e2 100644 --- a/util/http/http.go +++ b/util/http/http.go @@ -241,3 +241,23 @@ func drainBody(body io.ReadCloser) { log.Warnf("error reading response body: %s", err.Error()) } } + +func SetTokenCookie(token string, baseHRef string, isSecure bool, w http.ResponseWriter) error { + var path string + if baseHRef != "" { + path = strings.TrimRight(strings.TrimLeft(baseHRef, "/"), "/") + } + cookiePath := "path=/" + path + flags := []string{cookiePath, "SameSite=lax", "httpOnly"} + if isSecure { + flags = append(flags, "Secure") + } + cookies, err := MakeCookieMetadata(common.AuthCookieName, token, flags...) + if err != nil { + return fmt.Errorf("error creating cookie metadata: %w", err) + } + for _, cookie := range cookies { + w.Header().Add("Set-Cookie", cookie) + } + return nil +} diff --git a/util/http/http_test.go b/util/http/http_test.go index bc97742611bfc..515c6b2be0a78 100644 --- a/util/http/http_test.go +++ b/util/http/http_test.go @@ -1,10 +1,13 @@ package http import ( + "fmt" "net/http" "strings" "testing" + "github.com/argoproj/argo-cd/v3/common" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -49,6 +52,101 @@ func TestSplitCookie(t *testing.T) { assert.Equal(t, cookieValue, token) } +// mockResponseWriter is a mock implementation of http.ResponseWriter. +// It captures added headers for verification in tests. +type mockResponseWriter struct { + header http.Header +} + +func (m *mockResponseWriter) Header() http.Header { + if m.header == nil { + m.header = make(http.Header) + } + return m.header +} +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) WriteHeader(_ int) {} + +func TestSetTokenCookie(t *testing.T) { + tests := []struct { + name string + token string + baseHRef string + isSecure bool + expectedCookies []string // Expected Set-Cookie header values + }{ + { + name: "Insecure cookie", + token: "insecure-token", + baseHRef: "", + isSecure: false, + expectedCookies: []string{ + fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly", common.AuthCookieName, "insecure-token"), + }, + }, + { + name: "Secure cookie", + token: "secure-token", + baseHRef: "", + isSecure: true, + expectedCookies: []string{ + fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token"), + }, + }, + { + name: "Insecure cookie with baseHRef", + token: "token-with-path", + baseHRef: "/app", + isSecure: false, + expectedCookies: []string{ + fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly", common.AuthCookieName, "token-with-path"), + }, + }, + { + name: "Secure cookie with baseHRef", + token: "secure-token-with-path", + baseHRef: "app/", + isSecure: true, + expectedCookies: []string{ + fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token-with-path"), + }, + }, + { + name: "Unsecured cookie, baseHRef with multiple segments and mixed slashes", + token: "complex-path-token", + baseHRef: "///api/v1/auth///", + isSecure: false, + expectedCookies: []string{ + fmt.Sprintf("%s=%s; path=/api/v1/auth; SameSite=lax; httpOnly", common.AuthCookieName, "complex-path-token"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &mockResponseWriter{} + + err := SetTokenCookie(tt.token, tt.baseHRef, tt.isSecure, w) + if err != nil { + t.Fatalf("%s: Unexpected error: %v", tt.name, err) + } + + setCookieHeaders := w.Header()["Set-Cookie"] + + if len(setCookieHeaders) != len(tt.expectedCookies) { + t.Errorf("Mistmatch in Set-Cookie header length: %s\nExpected: %d\nGot: %d", + tt.name, len(tt.expectedCookies), len(setCookieHeaders)) + return + } + + if len(tt.expectedCookies) > 0 && setCookieHeaders[0] != tt.expectedCookies[0] { + t.Errorf("Mismatch in Set-Cookie header: %s\nExpected: %s\nGot: %s", + tt.name, tt.expectedCookies[0], setCookieHeaders[0]) + } + }) + } +} + // TestRoundTripper just copy request headers to the resposne. type TestRoundTripper struct{} diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 76c3dfd10f9ab..61f63152a770a 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -43,6 +43,7 @@ const ( ResponseTypeCode = "code" UserInfoResponseCachePrefix = "userinfo_response" AccessTokenCachePrefix = "access_token" + OidcTokenCachePrefix = "oidc_token" ) // OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec @@ -87,6 +88,8 @@ type ClientApp struct { clientCache cache.CacheClient // properties for azure workload identity. azure azureApp + // preemptive token refresh threshold + refreshTokenThreshold time.Duration } type azureApp struct { @@ -98,6 +101,63 @@ type azureApp struct { mtx *sync.RWMutex } +// OidcTokenCache is a serialization wrapper around oauth2 provider configuration needed to generate a TokenSource +type OidcTokenCache struct { + // Redirect URL is needed for oauth2 config initialization + RedirectURL string `json:"redirect_url"` + // oauth2 Token + Token *oauth2.Token `json:"token"` + // TokenExtraIdToken captures value of id_token + TokenExtraIdToken string `json:"token_extra_id_token"` +} + +// NewOidcTokenCache initializes the struct from a redirect URL and an existing token +func NewOidcTokenCache(redirectURL string, token *oauth2.Token) *OidcTokenCache { + var idToken string + if token.Extra("id_token") == nil { + idToken = "" + } else { + idToken = token.Extra("id_token").(string) + } + return &OidcTokenCache{ + RedirectURL: redirectURL, + Token: token, + TokenExtraIdToken: idToken, + } +} + +// GetOidcTokenCacheFromJSON deserializes the json representation of OidcTokenCache. The Token extra map is updated from +// the serialization wrapper to propagate the id_token. This will ensure that the TokenSource always retrieves a usable token. +func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) { + var newToken OidcTokenCache + err := json.Unmarshal(jsonBytes, &newToken) + if err != nil { + return nil, err + } + if newToken.Token == nil { + return nil, errors.New("empty token") + } + newToken.Token = newToken.Token.WithExtra(map[string]any{ + "id_token": newToken.TokenExtraIdToken, + }) + return &newToken, nil +} + +// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured +// with an early expiration based on the refreshTokenThreshold. +func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) { + if oidcTokenCache == nil { + return nil, errors.New("oidcTokenCache is required") + } + config, err := a.getOauth2ConfigForRedirectURI(oidcTokenCache.RedirectURL) + if err != nil { + return nil, err + } + baseTokenSource := config.TokenSource(ctx, oidcTokenCache.Token) + tokenRefresher := oauth2.ReuseTokenSourceWithExpiry(oidcTokenCache.Token, baseTokenSource, a.refreshTokenThreshold) + return tokenRefresher, nil +} + func GetScopesOrDefault(scopes []string) []string { if len(scopes) == 0 { return []string{"openid", "profile", "email", "groups"} @@ -127,6 +187,7 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL encryptionKey: encryptionKey, clientCache: cacheClient, azure: azureApp{mtx: &sync.RWMutex{}}, + refreshTokenThreshold: settings.OIDCRefreshTokenThreshold, } log.Infof("Creating client app (%s)", a.clientID) u, err := url.Parse(settings.URL) @@ -165,23 +226,27 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL return &a, nil } -func (a *ClientApp) oauth2Config(request *http.Request, scopes []string) (*oauth2.Config, error) { - endpoint, err := a.provider.Endpoint() +func (a *ClientApp) getRedirectURIForRequest(req *http.Request) string { + redirectURI, err := a.settings.RedirectURLForRequest(req) if err != nil { - return nil, err + log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err) + redirectURI = a.redirectURI } - redirectURL, err := a.settings.RedirectURLForRequest(request) + return redirectURI +} + +func (a *ClientApp) getOauth2ConfigForRedirectURI(redirectURI string) (*oauth2.Config, error) { + endpoint, err := a.provider.Endpoint() if err != nil { - log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err) - redirectURL = a.redirectURI + return nil, err } return &oauth2.Config{ ClientID: a.clientID, ClientSecret: a.clientSecret, Endpoint: *endpoint, - Scopes: scopes, - RedirectURL: redirectURL, + Scopes: a.getScopes(), + RedirectURL: redirectURI, }, nil } @@ -315,17 +380,13 @@ func (a *ClientApp) HandleLogin(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - scopes := make([]string, 0) pkceVerifier := "" var opts []oauth2.AuthCodeOption if config := a.settings.OIDCConfig(); config != nil { - scopes = GetScopesOrDefault(config.RequestedScopes) opts = AppendClaimsAuthenticationRequestParameter(opts, config.RequestedIDTokenClaims) - } else if a.settings.IsDexConfigured() { - scopes = append(GetScopesOrDefault(nil), common.DexFederatedScope) } - oauth2Config, err := a.oauth2Config(r, scopes) + oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -406,7 +467,7 @@ func (a *azureApp) getFederatedServiceAccountToken(context.Context) (string, err // HandleCallback is the callback handler for an OAuth2 login flow func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { - oauth2Config, err := a.oauth2Config(r, nil) + oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -456,27 +517,21 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { return } + // Parse out id token idTokenRAW, ok := token.Extra("id_token").(string) if !ok { http.Error(w, "no id_token in token response", http.StatusInternalServerError) return } - idToken, err := a.provider.Verify(idTokenRAW, a.settings) + idToken, err := a.provider.Verify(ctx, idTokenRAW, a.settings) if err != nil { - log.Warnf("Failed to verify token: %s", err) + log.Warnf("Failed to verify oidc token: %s", err) http.Error(w, common.TokenVerificationError, http.StatusInternalServerError) return } - path := "/" - if a.baseHRef != "" { - path = strings.TrimRight(strings.TrimLeft(a.baseHRef, "/"), "/") - } - cookiePath := "path=/" + path - flags := []string{cookiePath, "SameSite=lax", "httpOnly"} - if a.secureCookie { - flags = append(flags, "Secure") - } + + // Set cache var claims jwt.MapClaims err = idToken.Claims(&claims) if err != nil { @@ -484,38 +539,38 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { return } // save the accessToken in memory for later use - encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey) + sub := jwtutil.StringField(claims, "sub") + err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims)) if err != nil { claimsJSON, _ := json.Marshal(claims) - http.Error(w, "failed encrypting token", http.StatusInternalServerError) - log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON) + log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON) + http.Error(w, err.Error(), http.StatusInternalServerError) return } - sub := jwtutil.StringField(claims, "sub") - err = a.clientCache.Set(&cache.Item{ - Key: FormatAccessTokenCacheKey(sub), - Object: encToken, - CacheActionOpts: cache.CacheActionOpts{ - Expiration: getTokenExpiration(claims), - }, - }) + + // Cache encrypted raw token for background refresh + oidcTokenCache := NewOidcTokenCache(a.getRedirectURIForRequest(r), token) + oidcTokenCacheJSON, err := json.Marshal(oidcTokenCache) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + sid := jwtutil.StringField(claims, "sid") + err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims)) if err != nil { claimsJSON, _ := json.Marshal(claims) - http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) + log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON) + http.Error(w, err.Error(), http.StatusInternalServerError) return } if idTokenRAW != "" { - cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...) + err = httputil.SetTokenCookie(idTokenRAW, a.baseHRef, a.secureCookie, w) if err != nil { claimsJSON, _ := json.Marshal(claims) http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) return } - - for _, cookie := range cookies { - w.Header().Add("Set-Cookie", cookie) - } } claimsJSON, _ := json.Marshal(claims) @@ -528,6 +583,109 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { } } +// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache +// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then +// check for nil. +func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) { + var encryptedValue []byte + err = a.clientCache.Get(key, &encryptedValue) + if err != nil { + if errors.Is(err, cache.ErrCacheMiss) { + // Return nil to signify a cache miss + return nil, nil + } + return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err) + } + value, err = crypto.Decrypt(encryptedValue, a.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt value from cache: %w", err) + } + return value, err +} + +// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key. +// Cache expiration is set based on input. +func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error { + encryptedValue, err := crypto.Encrypt(value, a.encryptionKey) + if err != nil { + return err + } + err = a.clientCache.Set(&cache.Item{ + Key: key, + Object: encryptedValue, + CacheActionOpts: cache.CacheActionOpts{ + Expiration: expiration, + }, + }) + if err != nil { + return err + } + return nil +} + +func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) { + sub := jwtutil.StringField(groupClaims, "sub") + sid := jwtutil.StringField(groupClaims, "sid") + if GetTokenExpiration(groupClaims) < refreshTokenThreshold { + token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid) + if err != nil { + log.Errorf("Failed to get token from cache: %v", err) + return "", err + } + if token != nil { + idTokenRAW, ok := token.Extra("id_token").(string) + if !ok { + return "", errors.New("empty id_token") + } + return idTokenRAW, nil + } + } + return "", nil +} + +// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration. +// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails. +func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) { + ctx = gooidc.ClientContext(ctx, a.client) + + // Get oauth2 config + cacheKey := formatOidcTokenCacheKey(subject, sessionId) + oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey) + if err != nil { + return nil, err + } + if oidcTokenCacheJSON == nil { + return nil, nil + } + + oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON) + if err != nil { + err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err) + return nil, err + } + tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache) + if err != nil { + err = fmt.Errorf("failed to get token source from cached oidc token: %w", err) + return nil, err + } + token, err := tokenSource.Token() + if err != nil { + return nil, fmt.Errorf("failed to refresh token from source: %w", err) + } + if token.AccessToken != oidcTokenCache.Token.AccessToken { + oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token) + oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache) + if err != nil { + return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err) + } + err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry)) + if err != nil { + return nil, err + } + } + return token, nil +} + var implicitFlowTmpl = template.Must(template.New("implicit.html").Parse(`