From fc661ae5a9c777d0a7777a156abe8e1368ec80fc Mon Sep 17 00:00:00 2001 From: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> Date: Thu, 29 May 2025 08:49:31 -0600 Subject: [PATCH 1/3] fix(server): avoid unecessary claims restrictions (#22973) Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> --- cmd/argocd/commands/login.go | 21 ++--- cmd/argocd/commands/login_test.go | 16 ++-- cmd/argocd/commands/project_role.go | 19 ++-- server/rbacpolicy/rbacpolicy.go | 6 +- server/server.go | 33 +++---- util/claims/claims.go | 60 ++++--------- util/claims/claims_test.go | 129 +++++----------------------- util/rbac/rbac.go | 14 ++- util/session/sessionmanager.go | 51 ++++------- util/session/sessionmanager_test.go | 27 ++---- 10 files changed, 108 insertions(+), 268 deletions(-) diff --git a/cmd/argocd/commands/login.go b/cmd/argocd/commands/login.go index 07854f2aafbf7..a9a9a1cbe5687 100644 --- a/cmd/argocd/commands/login.go +++ b/cmd/argocd/commands/login.go @@ -12,6 +12,8 @@ import ( "strings" "time" + jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" + "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" @@ -143,9 +145,7 @@ argocd login cd.argoproj.io --core`, claims := jwt.MapClaims{} _, _, err := parser.ParseUnverified(tokenString, &claims) errors.CheckError(err) - argoClaims, err := claimsutil.MapClaimsToArgoClaims(claims) - errors.CheckError(err) - fmt.Printf("'%s' logged in successfully\n", userDisplayName(argoClaims)) + fmt.Printf("'%s' logged in successfully\n", userDisplayName(claims)) } // login successful. Persist the config @@ -192,17 +192,14 @@ argocd login cd.argoproj.io --core`, return command } -func userDisplayName(claims *claimsutil.ArgoClaims) string { - if claims == nil { - return "" - } - if claims.Email != "" { - return claims.Email +func userDisplayName(claims jwt.MapClaims) string { + if email := jwtutil.StringField(claims, "email"); email != "" { + return email } - if claims.Name != "" { - return claims.Name + if name := jwtutil.StringField(claims, "name"); name != "" { + return name } - return claims.GetUserIdentifier() + return claimsutil.GetUserIdentifier(claims) } // oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and diff --git a/cmd/argocd/commands/login_test.go b/cmd/argocd/commands/login_test.go index 43cbf1febfbef..393ddec97a6e2 100644 --- a/cmd/argocd/commands/login_test.go +++ b/cmd/argocd/commands/login_test.go @@ -5,12 +5,10 @@ import ( "os" "testing" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" utilio "github.com/argoproj/argo-cd/v3/util/io" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func captureStdout(callback func()) (string, error) { @@ -37,31 +35,28 @@ func captureStdout(callback func()) (string, error) { } func Test_userDisplayName_email(t *testing.T) { - claims, err := claimsutil.MapClaimsToArgoClaims(jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "firstname.lastname@example.com", "groups": []string{"baz"}}) - require.NoError(t, err) + claims := jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "firstname.lastname@example.com", "groups": []string{"baz"}} actualName := userDisplayName(claims) expectedName := "firstname.lastname@example.com" assert.Equal(t, expectedName, actualName) } func Test_userDisplayName_name(t *testing.T) { - claims, err := claimsutil.MapClaimsToArgoClaims(jwt.MapClaims{"iss": "qux", "sub": "foo", "name": "Firstname Lastname", "groups": []string{"baz"}}) - require.NoError(t, err) + claims := jwt.MapClaims{"iss": "qux", "sub": "foo", "name": "Firstname Lastname", "groups": []string{"baz"}} actualName := userDisplayName(claims) expectedName := "Firstname Lastname" assert.Equal(t, expectedName, actualName) } func Test_userDisplayName_sub(t *testing.T) { - claims, err := claimsutil.MapClaimsToArgoClaims(jwt.MapClaims{"iss": "qux", "sub": "foo", "groups": []string{"baz"}}) - require.NoError(t, err) + claims := jwt.MapClaims{"iss": "qux", "sub": "foo", "groups": []string{"baz"}} actualName := userDisplayName(claims) expectedName := "foo" assert.Equal(t, expectedName, actualName) } func Test_userDisplayName_federatedClaims(t *testing.T) { - claims, err := claimsutil.MapClaimsToArgoClaims(jwt.MapClaims{ + claims := jwt.MapClaims{ "iss": "qux", "sub": "foo", "groups": []string{"baz"}, @@ -69,8 +64,7 @@ func Test_userDisplayName_federatedClaims(t *testing.T) { "connector_id": "dex", "user_id": "ldap-123", }, - }) - require.NoError(t, err) + } actualName := userDisplayName(claims) expectedName := "ldap-123" assert.Equal(t, expectedName, actualName) diff --git a/cmd/argocd/commands/project_role.go b/cmd/argocd/commands/project_role.go index f8715c3c38866..1fb30623f9083 100644 --- a/cmd/argocd/commands/project_role.go +++ b/cmd/argocd/commands/project_role.go @@ -7,6 +7,8 @@ import ( "text/tabwriter" "time" + claimsutil "github.com/argoproj/argo-cd/v3/util/claims" + timeutil "github.com/argoproj/pkg/v2/time" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/spf13/cobra" @@ -16,7 +18,6 @@ import ( argocdclient "github.com/argoproj/argo-cd/v3/pkg/apiclient" projectpkg "github.com/argoproj/argo-cd/v3/pkg/apiclient/project" "github.com/argoproj/argo-cd/v3/pkg/apis/application/v1alpha1" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" "github.com/argoproj/argo-cd/v3/util/errors" utilio "github.com/argoproj/argo-cd/v3/util/io" "github.com/argoproj/argo-cd/v3/util/jwt" @@ -322,25 +323,19 @@ Create token succeeded for proj:test-project:test-role. }) errors.CheckError(err) - var claims jwtgo.MapClaims - _, _, err = jwtgo.NewParser().ParseUnverified(tokenResponse.Token, &claims) - if err != nil { + token, err := jwtgo.Parse(tokenResponse.Token, nil) + if token == nil { err = fmt.Errorf("received malformed token %w", err) errors.CheckError(err) return } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(claims) - if err != nil { - errors.CheckError(fmt.Errorf("invalid argo claims: %w", err)) - return - } + claims := token.Claims.(jwtgo.MapClaims) issuedAt, _ := jwt.IssuedAt(claims) expiresAt := int64(jwt.Float64Field(claims, "exp")) - id := argoClaims.ID - subject := argoClaims.GetUserIdentifier() - + id := jwt.StringField(claims, "jti") + subject := claimsutil.GetUserIdentifier(claims) if !outputTokenOnly { fmt.Printf("Create token succeeded for %s.\n", subject) fmt.Printf(" ID: %s\n Issued At: %s\n Expires At: %s\n", diff --git a/server/rbacpolicy/rbacpolicy.go b/server/rbacpolicy/rbacpolicy.go index 545f3636769d8..3111c82681dd7 100644 --- a/server/rbacpolicy/rbacpolicy.go +++ b/server/rbacpolicy/rbacpolicy.go @@ -62,12 +62,8 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...any) bool if err != nil { return false } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - if err != nil { - return false - } - subject := argoClaims.GetUserIdentifier() + subject := claimsutil.GetUserIdentifier(mapClaims) // Check if the request is for an application resource. We have special enforcement which takes // into consideration the project's token and group bindings var runtimePolicy string diff --git a/server/server.go b/server/server.go index 4894561d209bc..07ab64f2b3593 100644 --- a/server/server.go +++ b/server/server.go @@ -106,7 +106,6 @@ import ( "github.com/argoproj/argo-cd/v3/ui" "github.com/argoproj/argo-cd/v3/util/assets" cacheutil "github.com/argoproj/argo-cd/v3/util/cache" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" "github.com/argoproj/argo-cd/v3/util/db" dexutil "github.com/argoproj/argo-cd/v3/util/dex" "github.com/argoproj/argo-cd/v3/util/env" @@ -1560,19 +1559,19 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - mapClaims, err := jwtutil.MapClaims(claims) - if err != nil { - return claims, "", status.Errorf(codes.Internal, "invalid claims") - } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - if err != nil { - return claims, "", status.Errorf(codes.Internal, "invalid argo claims") + // Some SSO implementations (Okta) require a call to + // the OIDC user info path to get attributes like groups + // we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims + // otherwise this would cause a panic + var groupClaims jwt.MapClaims + if groupClaims, ok = claims.(jwt.MapClaims); !ok { + if tmpClaims, ok := claims.(*jwt.MapClaims); ok { + groupClaims = *tmpClaims + } } - - // Some SSO implementations (Okta) require a call to the OIDC user info path to get attributes like groups - iss := jwtutil.StringField(mapClaims, "iss") + iss := jwtutil.StringField(groupClaims, "iss") if iss != util_session.SessionManagerClaimsIssuer && server.settings.UserInfoGroupsEnabled() && server.settings.UserInfoPath() != "" { - userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(mapClaims, server.settings.IssuerURL(), server.settings.UserInfoPath()) + userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(groupClaims, server.settings.IssuerURL(), server.settings.UserInfoPath()) if unauthorized { log.Errorf("error while quering userinfo endpoint: %v", err) return claims, "", status.Errorf(codes.Unauthenticated, "invalid session") @@ -1581,17 +1580,13 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, log.Errorf("error fetching user info endpoint: %v", err) return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") } - userInfoClaims, err := claimsutil.MapClaimsToArgoClaims(userInfo) - if err != nil { - return claims, "", status.Errorf(codes.Internal, "invalid userinfo claims") - } - if argoClaims.Subject != userInfoClaims.Subject { + if groupClaims["sub"] != userInfo["sub"] { return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") } - mapClaims["groups"] = userInfo["groups"] + groupClaims["groups"] = userInfo["groups"] } - return mapClaims, newToken, nil + return groupClaims, newToken, nil } // getToken extracts the token from gRPC metadata or cookie headers diff --git a/util/claims/claims.go b/util/claims/claims.go index 2b77e53b5b9f8..db7f1f81c8f9a 100644 --- a/util/claims/claims.go +++ b/util/claims/claims.go @@ -1,56 +1,34 @@ package claims import ( - "encoding/json" - "github.com/golang-jwt/jwt/v5" ) -// ArgoClaims defines the claims structure based on Dex's documented claims -type ArgoClaims struct { - jwt.RegisteredClaims - Email string `json:"email,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - Name string `json:"name,omitempty"` - Groups []string `json:"groups,omitempty"` - // As per Dex docs, federated_claims has a specific structure - FederatedClaims *FederatedClaims `json:"federated_claims,omitempty"` -} - -// FederatedClaims represents the structure documented by Dex -type FederatedClaims struct { - ConnectorID string `json:"connector_id"` - UserID string `json:"user_id"` -} - -// MapClaimsToArgoClaims converts a jwt.MapClaims to a ArgoClaims -func MapClaimsToArgoClaims(claims jwt.MapClaims) (*ArgoClaims, error) { - if claims == nil { - return &ArgoClaims{}, nil +// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use +func GetUserIdentifier(c jwt.MapClaims) string { + if c == nil { + return "" } - claimsBytes, err := json.Marshal(claims) + // Fallback to sub if federated_claims.user_id is not set. + fallback, err := c.GetSubject() if err != nil { - return nil, err + fallback = "" } - var argoClaims ArgoClaims - err = json.Unmarshal(claimsBytes, &argoClaims) - if err != nil { - return nil, err + f := c["federated_claims"] + if f == nil { + return fallback } - return &argoClaims, nil -} - -// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use -func (c *ArgoClaims) GetUserIdentifier() string { - // Check federated claims first - if c.FederatedClaims != nil && c.FederatedClaims.UserID != "" { - return c.FederatedClaims.UserID + federatedClaims, ok := f.(map[string]any) + if !ok { + return fallback } - // Fallback to sub - if c.Subject != "" { - return c.Subject + + userId, ok := federatedClaims["user_id"].(string) + if !ok || userId == "" { + return fallback } - return "" + + return userId } diff --git a/util/claims/claims_test.go b/util/claims/claims_test.go index 6b9a3fc8a5291..7b28b526f25ac 100644 --- a/util/claims/claims_test.go +++ b/util/claims/claims_test.go @@ -1,159 +1,76 @@ package claims import ( - "reflect" "testing" - "time" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetUserIdentifier(t *testing.T) { tests := []struct { name string - claims *ArgoClaims + claims jwt.MapClaims want string }{ { name: "when both dex and sub defined - prefer dex user_id", - claims: &ArgoClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "ignored:login", - }, - FederatedClaims: &FederatedClaims{ - UserID: "dex-user", + claims: jwt.MapClaims{ + "sub": "ignored:login", + "federated_claims": map[string]any{ + "user_id": "dex-user", }, }, want: "dex-user", }, { name: "when both dex and sub defined but dex user_id empty - fallback to sub", - claims: &ArgoClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "test:apiKey", - }, - FederatedClaims: &FederatedClaims{ - UserID: "", + claims: jwt.MapClaims{ + "sub": "test:apiKey", + "federated_claims": map[string]any{ + "user_id": "", }, }, want: "test:apiKey", }, { name: "when only sub is defined (no dex) - use sub", - claims: &ArgoClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "admin:login", - }, + claims: jwt.MapClaims{ + "sub": "admin:login", }, want: "admin:login", }, { name: "when neither dex nor sub defined - return empty", - claims: &ArgoClaims{}, + claims: jwt.MapClaims{}, want: "", }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.claims.GetUserIdentifier() - assert.Equal(t, tt.want, got) - }) - } -} - -func TestMapClaimsToArgoClaims(t *testing.T) { - expectedExpiredAt := jwt.NewNumericDate(time.Now().Add(time.Hour)) - expectedIssuedAt := jwt.NewNumericDate(time.Now().Add(time.Hour * -2)) - expectedNotBefore := jwt.NewNumericDate(time.Now().Add(time.Hour * -3)) - - tests := []struct { - name string - claims jwt.MapClaims - want *ArgoClaims - wantErr bool - }{ { name: "nil claims", claims: nil, - want: &ArgoClaims{}, - }, - { - name: "empty claims", - claims: jwt.MapClaims{}, - want: &ArgoClaims{}, - }, - { - name: "invalid claims", - claims: jwt.MapClaims{ - "email_verified": "not-a-bool", - }, - wantErr: true, + want: "", }, { - name: "all registered known claims", + name: "invalid subject", claims: jwt.MapClaims{ - "jti": "jti", - "iss": "iss", - "sub": "sub", - "aud": "aud", - "iat": expectedIssuedAt.Unix(), - "exp": expectedExpiredAt.Unix(), - "nbf": expectedNotBefore.Unix(), - }, - want: &ArgoClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ID: "jti", - Issuer: "iss", - Subject: "sub", - Audience: jwt.ClaimStrings{"aud"}, - ExpiresAt: expectedExpiredAt, - IssuedAt: expectedIssuedAt, - NotBefore: expectedNotBefore, - }, + "sub": nil, }, + want: "", }, { - name: "all argo claims", + name: "invalid federated_claims", claims: jwt.MapClaims{ - "email": "email@test.com", - "email_verified": true, - "name": "the-name", - "groups": []string{ - "my-org:my-team2", - "my-org:my-team1", - }, - "federated_claims": map[string]any{ - "connector_id": "my-connector", - "user_id": "user-id", - }, - }, - want: &ArgoClaims{ - Email: "email@test.com", - EmailVerified: true, - Name: "the-name", - Groups: []string{ - "my-org:my-team2", - "my-org:my-team1", - }, - FederatedClaims: &FederatedClaims{ - ConnectorID: "my-connector", - UserID: "user-id", - }, + "sub": "test:apiKey", + "federated_claims": "invalid", }, + want: "test:apiKey", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := MapClaimsToArgoClaims(tt.claims) - if tt.wantErr { - assert.Error(t, err, "MapClaimsToArgoClaims()") - } else { - require.NoError(t, err, "MapClaimsToArgoClaims()") - assert.Truef(t, reflect.DeepEqual(got, tt.want), "MapClaimsToArgoClaims() = %v, want %v", got, tt.want) - } + got := GetUserIdentifier(tt.claims) + assert.Equal(t, tt.want, got) }) } } diff --git a/util/rbac/rbac.go b/util/rbac/rbac.go index 2da3495e6ec97..3f99bbe11da49 100644 --- a/util/rbac/rbac.go +++ b/util/rbac/rbac.go @@ -341,14 +341,12 @@ func (e *Enforcer) EnforceErr(rvals ...any) error { if s, ok := rvals[0].(jwt.Claims); ok { claims, err := jwtutil.MapClaims(s) if err == nil { - argoClaims, err := claimsutil.MapClaimsToArgoClaims(claims) - if err == nil { - if argoClaims.GetUserIdentifier() != "" { - rvalsStrs = append(rvalsStrs, "sub: "+argoClaims.GetUserIdentifier()) - } - if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil { - rvalsStrs = append(rvalsStrs, "iat: "+issuedAtTime.Format(time.RFC3339)) - } + userId := claimsutil.GetUserIdentifier(claims) + if userId != "" { + rvalsStrs = append(rvalsStrs, "sub: "+userId) + } + if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil { + rvalsStrs = append(rvalsStrs, "iat: "+issuedAtTime.Format(time.RFC3339)) } } } diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index f2335d1240b87..a0f8e90737018 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -159,14 +159,12 @@ func NewSessionManager(settingsMgr *settings.SettingsManager, projectsLister v1a // The id parameter holds an optional unique JWT token identifier and stored as a standard claim "jti" in the JWT token. func (mgr *SessionManager) Create(subject string, secondsBeforeExpiry int64, id string) (string, error) { now := time.Now().UTC() - claims := claimsutil.ArgoClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(now), - Issuer: SessionManagerClaimsIssuer, - NotBefore: jwt.NewNumericDate(now), - Subject: subject, - ID: id, - }, + claims := jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + Issuer: SessionManagerClaimsIssuer, + NotBefore: jwt.NewNumericDate(now), + Subject: subject, + ID: id, } if secondsBeforeExpiry > 0 { expires := now.Add(time.Duration(secondsBeforeExpiry) * time.Second) @@ -222,17 +220,13 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error) if err != nil { return nil, "", err } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(claims) - if err != nil { - return nil, "", err - } issuedAt, err := jwtutil.IssuedAtTime(claims) if err != nil { return nil, "", err } - subject := argoClaims.GetUserIdentifier() + subject := claimsutil.GetUserIdentifier(claims) id := jwtutil.StringField(claims, "jti") if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok { @@ -597,18 +591,19 @@ func LoggedIn(ctx context.Context) bool { // Username is a helper to extract a human readable username from a context func Username(ctx context.Context) string { - argoClaims, ok := argoClaims(ctx) + mapClaims, ok := mapClaims(ctx) if !ok { return "" } - switch argoClaims.Issuer { + switch jwtutil.StringField(mapClaims, "iss") { case SessionManagerClaimsIssuer: - return argoClaims.GetUserIdentifier() + return claimsutil.GetUserIdentifier(mapClaims) default: - if argoClaims.Email != "" { - return argoClaims.Email + e := jwtutil.StringField(mapClaims, "email") + if e != "" { + return e } - return argoClaims.GetUserIdentifier() + return claimsutil.GetUserIdentifier(mapClaims) } } @@ -634,11 +629,7 @@ func GetUserIdentifier(ctx context.Context) string { if !ok { return "" } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - if err != nil { - return "" - } - return argoClaims.GetUserIdentifier() + return claimsutil.GetUserIdentifier(mapClaims) } func Groups(ctx context.Context, scopes []string) []string { @@ -660,15 +651,3 @@ func mapClaims(ctx context.Context) (jwt.MapClaims, bool) { } return mapClaims, true } - -func argoClaims(ctx context.Context) (*claimsutil.ArgoClaims, bool) { - mapClaims, ok := mapClaims(ctx) - if !ok { - return nil, false - } - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - if err != nil { - return nil, false - } - return argoClaims, true -} diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index 7df377b14c02a..9c380007533a2 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -29,7 +29,6 @@ import ( apps "github.com/argoproj/argo-cd/v3/pkg/client/clientset/versioned/fake" "github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v3/test" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" "github.com/argoproj/argo-cd/v3/util/password" "github.com/argoproj/argo-cd/v3/util/settings" @@ -98,12 +97,11 @@ func TestSessionManager_AdminToken(t *testing.T) { require.NoError(t, err) assert.Empty(t, newToken) - mapClaims, err := jwtutil.MapClaims(claims) - require.NoError(t, err) - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - require.NoError(t, err) - - assert.Equal(t, "admin", argoClaims.Subject) + mapClaims := *(claims.(*jwt.MapClaims)) + subject := mapClaims["sub"].(string) + if subject != "admin" { + t.Errorf("Token claim subject %q does not match expected subject %q.", subject, "admin") + } } func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) { @@ -125,13 +123,9 @@ func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) { claims, _, err := mgr.Parse(newToken) require.NoError(t, err) - mapClaims, err := jwtutil.MapClaims(claims) - require.NoError(t, err) - - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - require.NoError(t, err) - - assert.Equal(t, "admin", argoClaims.Subject) + mapClaims := *(claims.(*jwt.MapClaims)) + subject := mapClaims["sub"].(string) + assert.Equal(t, "admin", subject) } func TestSessionManager_AdminToken_Revoked(t *testing.T) { @@ -202,10 +196,7 @@ func TestSessionManager_ProjectToken(t *testing.T) { mapClaims, err := jwtutil.MapClaims(claims) require.NoError(t, err) - argoClaims, err := claimsutil.MapClaimsToArgoClaims(mapClaims) - require.NoError(t, err) - - assert.Equal(t, "proj:default:test", argoClaims.Subject) + assert.Equal(t, "proj:default:test", mapClaims["sub"]) }) t.Run("Token Revoked", func(t *testing.T) { From 8265f3ea16ed24d8890761521efe3fefc7485d4e Mon Sep 17 00:00:00 2001 From: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> Date: Thu, 29 May 2025 09:47:15 -0600 Subject: [PATCH 2/3] simplify Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> --- util/claims/claims.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/util/claims/claims.go b/util/claims/claims.go index db7f1f81c8f9a..e418b3d08e693 100644 --- a/util/claims/claims.go +++ b/util/claims/claims.go @@ -2,6 +2,8 @@ package claims import ( "github.com/golang-jwt/jwt/v5" + + jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" ) // GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use @@ -11,10 +13,7 @@ func GetUserIdentifier(c jwt.MapClaims) string { } // Fallback to sub if federated_claims.user_id is not set. - fallback, err := c.GetSubject() - if err != nil { - fallback = "" - } + fallback := jwtutil.StringField(c, "sub") f := c["federated_claims"] if f == nil { From ecd7f6ffdac6733cb207912a988eb23008f94a93 Mon Sep 17 00:00:00 2001 From: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> Date: Thu, 29 May 2025 12:16:51 -0600 Subject: [PATCH 3/3] consolidate on one util module Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com> --- cmd/argocd/commands/login.go | 3 +- cmd/argocd/commands/project_role.go | 4 +- server/rbacpolicy/rbacpolicy.go | 3 +- util/claims/claims.go | 33 ------------- util/claims/claims_test.go | 76 ----------------------------- util/jwt/jwt.go | 26 ++++++++++ util/jwt/jwt_test.go | 68 ++++++++++++++++++++++++++ util/rbac/rbac.go | 3 +- util/session/sessionmanager.go | 9 ++-- 9 files changed, 102 insertions(+), 123 deletions(-) delete mode 100644 util/claims/claims.go delete mode 100644 util/claims/claims_test.go diff --git a/cmd/argocd/commands/login.go b/cmd/argocd/commands/login.go index a9a9a1cbe5687..d695064654753 100644 --- a/cmd/argocd/commands/login.go +++ b/cmd/argocd/commands/login.go @@ -25,7 +25,6 @@ import ( argocdclient "github.com/argoproj/argo-cd/v3/pkg/apiclient" sessionpkg "github.com/argoproj/argo-cd/v3/pkg/apiclient/session" settingspkg "github.com/argoproj/argo-cd/v3/pkg/apiclient/settings" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" "github.com/argoproj/argo-cd/v3/util/cli" "github.com/argoproj/argo-cd/v3/util/errors" grpc_util "github.com/argoproj/argo-cd/v3/util/grpc" @@ -199,7 +198,7 @@ func userDisplayName(claims jwt.MapClaims) string { if name := jwtutil.StringField(claims, "name"); name != "" { return name } - return claimsutil.GetUserIdentifier(claims) + return jwtutil.GetUserIdentifier(claims) } // oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and diff --git a/cmd/argocd/commands/project_role.go b/cmd/argocd/commands/project_role.go index 1fb30623f9083..d506afce9af21 100644 --- a/cmd/argocd/commands/project_role.go +++ b/cmd/argocd/commands/project_role.go @@ -7,8 +7,6 @@ import ( "text/tabwriter" "time" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" - timeutil "github.com/argoproj/pkg/v2/time" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/spf13/cobra" @@ -335,7 +333,7 @@ Create token succeeded for proj:test-project:test-role. issuedAt, _ := jwt.IssuedAt(claims) expiresAt := int64(jwt.Float64Field(claims, "exp")) id := jwt.StringField(claims, "jti") - subject := claimsutil.GetUserIdentifier(claims) + subject := jwt.GetUserIdentifier(claims) if !outputTokenOnly { fmt.Printf("Create token succeeded for %s.\n", subject) fmt.Printf(" ID: %s\n Issued At: %s\n Expires At: %s\n", diff --git a/server/rbacpolicy/rbacpolicy.go b/server/rbacpolicy/rbacpolicy.go index 3111c82681dd7..5426a66f8303c 100644 --- a/server/rbacpolicy/rbacpolicy.go +++ b/server/rbacpolicy/rbacpolicy.go @@ -8,7 +8,6 @@ import ( "github.com/argoproj/argo-cd/v3/pkg/apis/application/v1alpha1" applister "github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" "github.com/argoproj/argo-cd/v3/util/rbac" ) @@ -63,7 +62,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...any) bool return false } - subject := claimsutil.GetUserIdentifier(mapClaims) + subject := jwtutil.GetUserIdentifier(mapClaims) // Check if the request is for an application resource. We have special enforcement which takes // into consideration the project's token and group bindings var runtimePolicy string diff --git a/util/claims/claims.go b/util/claims/claims.go deleted file mode 100644 index e418b3d08e693..0000000000000 --- a/util/claims/claims.go +++ /dev/null @@ -1,33 +0,0 @@ -package claims - -import ( - "github.com/golang-jwt/jwt/v5" - - jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" -) - -// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use -func GetUserIdentifier(c jwt.MapClaims) string { - if c == nil { - return "" - } - - // Fallback to sub if federated_claims.user_id is not set. - fallback := jwtutil.StringField(c, "sub") - - f := c["federated_claims"] - if f == nil { - return fallback - } - federatedClaims, ok := f.(map[string]any) - if !ok { - return fallback - } - - userId, ok := federatedClaims["user_id"].(string) - if !ok || userId == "" { - return fallback - } - - return userId -} diff --git a/util/claims/claims_test.go b/util/claims/claims_test.go deleted file mode 100644 index 7b28b526f25ac..0000000000000 --- a/util/claims/claims_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package claims - -import ( - "testing" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" -) - -func TestGetUserIdentifier(t *testing.T) { - tests := []struct { - name string - claims jwt.MapClaims - want string - }{ - { - name: "when both dex and sub defined - prefer dex user_id", - claims: jwt.MapClaims{ - "sub": "ignored:login", - "federated_claims": map[string]any{ - "user_id": "dex-user", - }, - }, - want: "dex-user", - }, - { - name: "when both dex and sub defined but dex user_id empty - fallback to sub", - claims: jwt.MapClaims{ - "sub": "test:apiKey", - "federated_claims": map[string]any{ - "user_id": "", - }, - }, - want: "test:apiKey", - }, - { - name: "when only sub is defined (no dex) - use sub", - claims: jwt.MapClaims{ - "sub": "admin:login", - }, - want: "admin:login", - }, - { - name: "when neither dex nor sub defined - return empty", - claims: jwt.MapClaims{}, - want: "", - }, - { - name: "nil claims", - claims: nil, - want: "", - }, - { - name: "invalid subject", - claims: jwt.MapClaims{ - "sub": nil, - }, - want: "", - }, - { - name: "invalid federated_claims", - claims: jwt.MapClaims{ - "sub": "test:apiKey", - "federated_claims": "invalid", - }, - want: "test:apiKey", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := GetUserIdentifier(tt.claims) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/util/jwt/jwt.go b/util/jwt/jwt.go index bb6770da333a1..c9caa3794d06f 100644 --- a/util/jwt/jwt.go +++ b/util/jwt/jwt.go @@ -139,3 +139,29 @@ func GetGroups(mapClaims jwtgo.MapClaims, scopes []string) []string { func IsValid(token string) bool { return len(strings.SplitN(token, ".", 3)) == 3 } + +// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use +func GetUserIdentifier(c jwtgo.MapClaims) string { + if c == nil { + return "" + } + + // Fallback to sub if federated_claims.user_id is not set. + fallback := StringField(c, "sub") + + f := c["federated_claims"] + if f == nil { + return fallback + } + federatedClaims, ok := f.(map[string]any) + if !ok { + return fallback + } + + userId, ok := federatedClaims["user_id"].(string) + if !ok || userId == "" { + return fallback + } + + return userId +} diff --git a/util/jwt/jwt_test.go b/util/jwt/jwt_test.go index 5ef645abb81f9..3dd43b3658142 100644 --- a/util/jwt/jwt_test.go +++ b/util/jwt/jwt_test.go @@ -68,3 +68,71 @@ func TestIsValid(t *testing.T) { assert.False(t, IsValid("foo")) assert.False(t, IsValid("")) } + +func TestGetUserIdentifier(t *testing.T) { + tests := []struct { + name string + claims jwt.MapClaims + want string + }{ + { + name: "when both dex and sub defined - prefer dex user_id", + claims: jwt.MapClaims{ + "sub": "ignored:login", + "federated_claims": map[string]any{ + "user_id": "dex-user", + }, + }, + want: "dex-user", + }, + { + name: "when both dex and sub defined but dex user_id empty - fallback to sub", + claims: jwt.MapClaims{ + "sub": "test:apiKey", + "federated_claims": map[string]any{ + "user_id": "", + }, + }, + want: "test:apiKey", + }, + { + name: "when only sub is defined (no dex) - use sub", + claims: jwt.MapClaims{ + "sub": "admin:login", + }, + want: "admin:login", + }, + { + name: "when neither dex nor sub defined - return empty", + claims: jwt.MapClaims{}, + want: "", + }, + { + name: "nil claims", + claims: nil, + want: "", + }, + { + name: "invalid subject", + claims: jwt.MapClaims{ + "sub": nil, + }, + want: "", + }, + { + name: "invalid federated_claims", + claims: jwt.MapClaims{ + "sub": "test:apiKey", + "federated_claims": "invalid", + }, + want: "test:apiKey", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetUserIdentifier(tt.claims) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/util/rbac/rbac.go b/util/rbac/rbac.go index 3f99bbe11da49..29781fb26476c 100644 --- a/util/rbac/rbac.go +++ b/util/rbac/rbac.go @@ -11,7 +11,6 @@ import ( "time" "github.com/argoproj/argo-cd/v3/util/assets" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" "github.com/argoproj/argo-cd/v3/util/glob" jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" @@ -341,7 +340,7 @@ func (e *Enforcer) EnforceErr(rvals ...any) error { if s, ok := rvals[0].(jwt.Claims); ok { claims, err := jwtutil.MapClaims(s) if err == nil { - userId := claimsutil.GetUserIdentifier(claims) + userId := jwtutil.GetUserIdentifier(claims) if userId != "" { rvalsStrs = append(rvalsStrs, "sub: "+userId) } diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index a0f8e90737018..2d5c44e8768cf 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -25,7 +25,6 @@ import ( "github.com/argoproj/argo-cd/v3/common" "github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v3/util/cache/appstate" - claimsutil "github.com/argoproj/argo-cd/v3/util/claims" "github.com/argoproj/argo-cd/v3/util/dex" "github.com/argoproj/argo-cd/v3/util/env" httputil "github.com/argoproj/argo-cd/v3/util/http" @@ -226,7 +225,7 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error) return nil, "", err } - subject := claimsutil.GetUserIdentifier(claims) + subject := jwtutil.GetUserIdentifier(claims) id := jwtutil.StringField(claims, "jti") if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok { @@ -597,13 +596,13 @@ func Username(ctx context.Context) string { } switch jwtutil.StringField(mapClaims, "iss") { case SessionManagerClaimsIssuer: - return claimsutil.GetUserIdentifier(mapClaims) + return jwtutil.GetUserIdentifier(mapClaims) default: e := jwtutil.StringField(mapClaims, "email") if e != "" { return e } - return claimsutil.GetUserIdentifier(mapClaims) + return jwtutil.GetUserIdentifier(mapClaims) } } @@ -629,7 +628,7 @@ func GetUserIdentifier(ctx context.Context) string { if !ok { return "" } - return claimsutil.GetUserIdentifier(mapClaims) + return jwtutil.GetUserIdentifier(mapClaims) } func Groups(ctx context.Context, scopes []string) []string {