diff --git a/CHANGELOG.md b/CHANGELOG.md index b12f9c7a32..d54603bd2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Optimized wildcard matching runtime performance ([#5470](https://github.com/opensearch-project/security/pull/5470)) * Optimized performance for construction of internal action privileges data structure ([#5470](https://github.com/opensearch-project/security/pull/5470)) * Restricting query optimization via star tree index for users with queries on indices with DLS/FLS/FieldMasked restrictions ([#5492](https://github.com/opensearch-project/security/pull/5492)) +* Handle subject in nested claim for JWT auth backends ([#5467](https://github.com/opensearch-project/security/pull/5467)) ### Bug Fixes diff --git a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationNestedClaimsTests.java b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationNestedClaimsTests.java index 47f4a9c980..b51d21d586 100644 --- a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationNestedClaimsTests.java +++ b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationNestedClaimsTests.java @@ -10,7 +10,6 @@ package org.opensearch.security.http; import java.security.KeyPair; -import java.util.Arrays; import java.util.Base64; import java.util.HashMap; import java.util.List; @@ -47,31 +46,74 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class JwtAuthenticationNestedClaimsTests { - public static final String CLAIM_USERNAME = "preferred-username"; - public static final List CLAIM_ROLES = List.of("attributes", "roles"); + public static final List USERNAME_CLAIM = List.of("preferred-username"); + public static final List NESTED_ROLES = List.of("attributes", "roles"); + public static final List NESTED_SUBJECT = List.of("attributes_sub", "sub"); + public static final List NESTED_SUBJECT_ATTRIBUTES_ONLY = List.of("attributes", "sub"); + public static final List ROLES_CLAIM = List.of("all_access", "securitymanager"); public static final String USER_SUPERHERO = "superhero"; private static final KeyPair KEY_PAIR1 = Keys.keyPairFor(SignatureAlgorithm.RS256); private static final String PUBLIC_KEY1 = new String(Base64.getEncoder().encode(KEY_PAIR1.getPublic().getEncoded()), US_ASCII); private static final String JWT_AUTH_HEADER = "jwt-auth"; + // Token factory for regular subject + nested roles private static final JwtAuthorizationHeaderFactory tokenFactory1 = new JwtAuthorizationHeaderFactory( KEY_PAIR1.getPrivate(), - CLAIM_USERNAME, - CLAIM_ROLES, + USERNAME_CLAIM, + NESTED_ROLES, JWT_AUTH_HEADER ); + + // Token factory for nested subject + nested roles + private static final JwtAuthorizationHeaderFactory tokenFactoryNestedSubjectAndRole = new JwtAuthorizationHeaderFactory( + KEY_PAIR1.getPrivate(), + NESTED_SUBJECT, + NESTED_ROLES, + JWT_AUTH_HEADER + ); + + // Token factory for both subject and roles nested under same "attributes" only + private static final JwtAuthorizationHeaderFactory tokenFactoryAttributesOnly = new JwtAuthorizationHeaderFactory( + KEY_PAIR1.getPrivate(), + NESTED_SUBJECT_ATTRIBUTES_ONLY, + NESTED_ROLES, + JWT_AUTH_HEADER + ); + + // JWT domain for regular subject + nested roles public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain( "jwt", BASIC_AUTH_DOMAIN_ORDER - 1 ).jwtHttpAuthenticator( - new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(List.of(PUBLIC_KEY1)).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES) + new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(List.of(PUBLIC_KEY1)).subjectKey(USERNAME_CLAIM).rolesKey(NESTED_ROLES) + ).backend("noop"); + + // JWT domain for nested subject + nested roles + public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN_NESTED_SUBJECT = new TestSecurityConfig.AuthcDomain( + "jwt-nested", + BASIC_AUTH_DOMAIN_ORDER - 2 + ).jwtHttpAuthenticator( + new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(List.of(PUBLIC_KEY1)).subjectKey(NESTED_SUBJECT).rolesKey(NESTED_ROLES) + ).backend("noop"); + + // JWT domain for both subject and roles using "attributes" only + public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN_ATTRIBUTES_ONLY = new TestSecurityConfig.AuthcDomain( + "jwt-attributes-only", + BASIC_AUTH_DOMAIN_ORDER - 3 + ).jwtHttpAuthenticator( + new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER) + .signingKey(List.of(PUBLIC_KEY1)) + .subjectKey(NESTED_SUBJECT_ATTRIBUTES_ONLY) + .rolesKey(NESTED_ROLES) ).backend("noop"); @ClassRule public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) .anonymousAuth(false) .authc(JWT_AUTH_DOMAIN) + .authc(JWT_AUTH_DOMAIN_NESTED_SUBJECT) + .authc(JWT_AUTH_DOMAIN_ATTRIBUTES_ONLY) .build(); @Rule @@ -82,8 +124,7 @@ public class JwtAuthenticationNestedClaimsTests { public void shouldAuthenticateWithNestedRolesClaim() { // Create nested claims structure Map attributes = new HashMap<>(); - List rolesClaim = Arrays.asList("all_access", "securitymanager"); - attributes.put("roles", rolesClaim); + attributes.put("roles", ROLES_CLAIM); Map nestedClaims = new HashMap<>(); nestedClaims.put("attributes", attributes); @@ -124,4 +165,203 @@ public void shouldHandleMissingNestedRolesClaim() { assertThat(roles, hasSize(0)); } } + + @Test + public void shouldAuthenticateWithNestedSubjectAndNestedRoles() { + // Create nested subject structure - the key should match NESTED_SUBJECT path + Map attributesSub = new HashMap<>(); + attributesSub.put("sub", USER_SUPERHERO); + + // Create nested roles structure + Map attributes = new HashMap<>(); + attributes.put("roles", ROLES_CLAIM); + + // Combine both in the claims + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes_sub", attributesSub); + nestedClaims.put("attributes", attributes); + + // Use the token factory with nested subject configuration + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + response.assertStatusCode(200); + String username = response.getTextFromJsonBody(POINTER_USERNAME); + assertThat(username, equalTo(USER_SUPERHERO)); + + List roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES); + assertThat(roles, hasSize(2)); + assertThat(roles, containsInAnyOrder("all_access", "securitymanager")); + } + } + + @Test + public void shouldAuthenticateWithNestedSubjectAndSimpleRoles() { + // Create nested subject structure + Map attributesSub = new HashMap<>(); + attributesSub.put("sub", USER_SUPERHERO); + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes_sub", attributesSub); + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + response.assertStatusCode(200); + String username = response.getTextFromJsonBody(POINTER_USERNAME); + assertThat(username, equalTo(USER_SUPERHERO)); + + // Should have no roles since they're not in the expected nested location + List roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES); + assertThat(roles, hasSize(0)); + } + } + + // Negative test cases + + @Test + public void shouldFailAuthenticationWithMissingNestedSubject() { + // Create nested roles structure but missing nested subject + Map attributes = new HashMap<>(); + attributes.put("roles", ROLES_CLAIM); + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes", attributes); + // Missing attributes_sub structure + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + // Should fail authentication due to missing subject + response.assertStatusCode(401); + } + } + + @Test + public void shouldFailAuthenticationWithWrongNestedSubjectStructure() { + // Create wrong nested subject structure + Map attributesSub = new HashMap<>(); + attributesSub.put("wrong_key", USER_SUPERHERO); // Wrong key, should be "sub" + + Map attributes = new HashMap<>(); + attributes.put("roles", ROLES_CLAIM); + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes_sub", attributesSub); + nestedClaims.put("attributes", attributes); + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + // Should fail authentication due to wrong subject structure + response.assertStatusCode(401); + } + } + + @Test + public void shouldAuthenticateWithMissingRolesButValidSubject() { + // Create nested subject structure but missing roles + Map attributesSub = new HashMap<>(); + attributesSub.put("sub", USER_SUPERHERO); + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes_sub", attributesSub); + // Missing roles structure + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + // Should authenticate but with no roles + response.assertStatusCode(200); + String username = response.getTextFromJsonBody(POINTER_USERNAME); + assertThat(username, equalTo(USER_SUPERHERO)); + + List roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES); + assertThat(roles, hasSize(0)); + } + } + + @Test + public void shouldHandleWrongNestedRolesStructure() { + // Create nested subject structure with wrong roles structure + Map attributesSub = new HashMap<>(); + attributesSub.put("sub", USER_SUPERHERO); + + Map attributes = new HashMap<>(); + attributes.put("wrong_roles_key", ROLES_CLAIM); // Wrong key, should be "roles" + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes_sub", attributesSub); + nestedClaims.put("attributes", attributes); + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + // Should authenticate but with no roles due to wrong roles structure + response.assertStatusCode(200); + String username = response.getTextFromJsonBody(POINTER_USERNAME); + assertThat(username, equalTo(USER_SUPERHERO)); + + List roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES); + assertThat(roles, hasSize(0)); + } + } + + @Test + public void shouldFailAuthenticationWithCompletelyWrongTokenStructure() { + // Create completely wrong token structure + Map wrongClaims = new HashMap<>(); + wrongClaims.put("completely", "wrong"); + wrongClaims.put("structure", "invalid"); + + Header header = tokenFactoryNestedSubjectAndRole.generateValidTokenWithCustomClaims(null, null, wrongClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + // Should fail authentication due to completely wrong structure + response.assertStatusCode(401); + } + } + + @Test + public void shouldAuthenticateWithBothSubjectAndRolesInAttributesOnly() { + // Create nested structure where both subject and roles are under "attributes" + // Subject path: attributes -> sub + // Roles path: attributes -> roles + Map attributes = new HashMap<>(); + attributes.put("sub", USER_SUPERHERO); + attributes.put("roles", ROLES_CLAIM); + + Map nestedClaims = new HashMap<>(); + nestedClaims.put("attributes", attributes); + + // Use the token factory configured for attributes-only paths + Header header = tokenFactoryAttributesOnly.generateValidTokenWithCustomClaims(null, null, nestedClaims); + + try (TestRestClient client = cluster.getRestClient(header)) { + HttpResponse response = client.getAuthInfo(); + + response.assertStatusCode(200); + String username = response.getTextFromJsonBody(POINTER_USERNAME); + assertThat(username, equalTo(USER_SUPERHERO)); + + List roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES); + assertThat(roles, hasSize(2)); + assertThat(roles, containsInAnyOrder("all_access", "securitymanager")); + } + } + } diff --git a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationTests.java b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationTests.java index 6173dd7c55..3f2f8f7db5 100644 --- a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationTests.java +++ b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationTests.java @@ -68,7 +68,7 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class JwtAuthenticationTests { - public static final String CLAIM_USERNAME = "preferred-username"; + public static final List CLAIM_USERNAME = List.of("preferred-username"); public static final List CLAIM_ROLES = List.of("backend-user-roles"); public static final String USER_SUPERHERO = "superhero"; diff --git a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationWithUrlParamTests.java b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationWithUrlParamTests.java index 788abb1432..d30643f758 100644 --- a/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationWithUrlParamTests.java +++ b/src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationWithUrlParamTests.java @@ -50,7 +50,7 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class JwtAuthenticationWithUrlParamTests { - public static final String CLAIM_USERNAME = "preferred-username"; + public static final List CLAIM_USERNAME = List.of("preferred-username"); public static final List CLAIM_ROLES = List.of("backend-user-roles"); public static final String POINTER_USERNAME = "/user_name"; diff --git a/src/integrationTest/java/org/opensearch/security/http/JwtAuthorizationHeaderFactory.java b/src/integrationTest/java/org/opensearch/security/http/JwtAuthorizationHeaderFactory.java index d189d3062d..38cd1992c3 100644 --- a/src/integrationTest/java/org/opensearch/security/http/JwtAuthorizationHeaderFactory.java +++ b/src/integrationTest/java/org/opensearch/security/http/JwtAuthorizationHeaderFactory.java @@ -30,13 +30,18 @@ class JwtAuthorizationHeaderFactory { public static final String ISSUER = "test-code"; private final PrivateKey privateKey; - private final String usernameClaimName; + private final List usernameClaimName; private final List rolesClaimName; private final String headerName; - public JwtAuthorizationHeaderFactory(PrivateKey privateKey, String usernameClaimName, List rolesClaimName, String headerName) { + public JwtAuthorizationHeaderFactory( + PrivateKey privateKey, + List usernameClaimName, + List rolesClaimName, + String headerName + ) { this.privateKey = requireNonNull(privateKey, "Private key is required"); this.usernameClaimName = requireNonNull(usernameClaimName, "Username claim name is required"); this.rolesClaimName = requireNonNull(rolesClaimName, "Roles claim name is required."); @@ -60,9 +65,31 @@ Header generateValidToken(String username, String... roles) { private Map customClaimsMap(String username, String[] roles) { ImmutableMap.Builder builder = new ImmutableMap.Builder(); + // Handle username claim if (StringUtils.isNoneEmpty(username)) { - builder.put(usernameClaimName, username); + if (usernameClaimName.size() == 1) { + // Simple case - no nesting + builder.put(usernameClaimName.get(0), username); + } else { + // Handle nested claims + Map nestedMap = new HashMap<>(); + Map currentMap = nestedMap; + + // Build the nested structure + for (int i = 0; i < usernameClaimName.size() - 1; i++) { + Map nextMap = new HashMap<>(); + currentMap.put(usernameClaimName.get(i), nextMap); + currentMap = nextMap; + } + + // Add the username at the deepest level + currentMap.put(usernameClaimName.get(usernameClaimName.size() - 1), username); + + // Add the entire nested structure to the builder + builder.putAll(nestedMap); + } } + if (roles != null && roles.length > 0) { if (rolesClaimName.size() == 1) { // Simple case - no nesting @@ -90,7 +117,6 @@ private Map customClaimsMap(String username, String[] roles) { } Header generateValidTokenWithCustomClaims(String username, String[] roles, Map additionalClaims) { - requireNonNull(username, "Username is required"); requireNonNull(additionalClaims, "Custom claims are required"); Map claims = new HashMap<>(customClaimsMap(username, roles)); claims.putAll(additionalClaims); @@ -128,7 +154,7 @@ public Header generateExpiredToken(String username) { requireNonNull(username, "Username is required"); Date now = new Date(1000); String token = Jwts.builder() - .setClaims(Map.of(usernameClaimName, username)) + .setClaims(customClaimsMap(username, null)) .setIssuer(ISSUER) .setSubject(subject(username)) .setAudience(AUDIENCE) @@ -144,7 +170,7 @@ public Header generateTokenSignedWithKey(PrivateKey key, String username) { requireNonNull(username, "Username is required"); Date now = new Date(); String token = Jwts.builder() - .setClaims(Map.of(usernameClaimName, username)) + .setClaims(customClaimsMap(username, null)) .setIssuer(ISSUER) .setSubject(subject(username)) .setAudience(AUDIENCE) diff --git a/src/integrationTest/java/org/opensearch/test/framework/JwtConfigBuilder.java b/src/integrationTest/java/org/opensearch/test/framework/JwtConfigBuilder.java index 871419d2bf..0dedebd1dd 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/JwtConfigBuilder.java +++ b/src/integrationTest/java/org/opensearch/test/framework/JwtConfigBuilder.java @@ -21,7 +21,7 @@ public class JwtConfigBuilder { private String jwtHeader; private String jwtUrlParameter; private List signingKeys; - private String subjectKey; + private List subjectKey; private List rolesKey; public JwtConfigBuilder jwtHeader(String jwtHeader) { @@ -40,6 +40,11 @@ public JwtConfigBuilder signingKey(List signingKeys) { } public JwtConfigBuilder subjectKey(String subjectKey) { + this.subjectKey = List.of(subjectKey); + return this; + } + + public JwtConfigBuilder subjectKey(List subjectKey) { this.subjectKey = subjectKey; return this; } @@ -66,7 +71,7 @@ public Map build() { if (isNoneBlank(jwtUrlParameter)) { builder.put("jwt_url_parameter", jwtUrlParameter); } - if (isNoneBlank(subjectKey)) { + if (subjectKey != null && !subjectKey.isEmpty()) { builder.put("subject_key", subjectKey); } if (rolesKey != null && !rolesKey.isEmpty()) { diff --git a/src/main/java/org/opensearch/security/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/org/opensearch/security/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index f79f61d3e8..c38afebdf5 100644 --- a/src/main/java/org/opensearch/security/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -60,7 +60,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator private final String jwtHeaderName; private final boolean isDefaultAuthHeader; private final String jwtUrlParameter; - private final String subjectKey; + private final List subjectKey; private final List rolesKey; private final List requiredAudience; private final String requiredIssuer; @@ -73,7 +73,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.getAsList("roles_key"); - subjectKey = settings.get("subject_key"); + subjectKey = settings.getAsList("subject_key"); clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS); requiredAudience = settings.getAsList("required_audience"); requiredIssuer = settings.get("required_issuer"); @@ -183,12 +183,25 @@ protected String getJwtTokenString(SecurityRequest request) { return jwtToken; } + @SuppressWarnings("unchecked") @VisibleForTesting public String extractSubject(JWTClaimsSet claims) { String subject = claims.getSubject(); - if (subjectKey != null) { - Object subjectObject = claims.getClaim(subjectKey); + if (subjectKey != null && !subjectKey.isEmpty()) { + Object subjectObject = null; + Map claimsMap = claims.getClaims(); + // This loop is necessary for nested claim traversal + for (int i = 0; i < subjectKey.size(); i++) { + if (i == subjectKey.size() - 1) { + subjectObject = claimsMap.get(subjectKey.get(i)); + } else if (claimsMap.get(subjectKey.get(i)) instanceof Map) { + claimsMap = (Map) claimsMap.get(subjectKey.get(i)); + } else { + log.warn("Failed to get subject from JWT claims with subject_key '{}'.", subjectKey); + return null; + } + } if (subjectObject == null) { log.warn("Failed to get subject from JWT claims, check if subject_key '{}' is correct.", subjectKey); diff --git a/src/main/java/org/opensearch/security/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/org/opensearch/security/auth/http/jwt/HTTPJwtAuthenticator.java index 4be59a3c63..3bac360344 100644 --- a/src/main/java/org/opensearch/security/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/http/jwt/HTTPJwtAuthenticator.java @@ -65,7 +65,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator { private final boolean isDefaultAuthHeader; private final String jwtUrlParameter; private final List rolesKey; - private final String subjectKey; + private final List subjectKey; private final List requiredAudience; private final String requireIssuer; private final int clockSkewToleranceSeconds; @@ -80,7 +80,7 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.getAsList("roles_key"); - subjectKey = settings.get("subject_key"); + subjectKey = settings.getAsList("subject_key"); requiredAudience = settings.getAsList("required_audience"); requireIssuer = settings.get("required_issuer"); clockSkewToleranceSeconds = settings.getAsInt( @@ -185,7 +185,7 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) { assertValidAudienceClaim(claims); } - final String subject = extractSubject(claims, request); + final String subject = extractSubject(claims); if (subject == null) { log.error("No subject found in JWT token"); @@ -260,25 +260,41 @@ public String getType() { return "jwt"; } - protected String extractSubject(final Claims claims, final SecurityRequest request) { + protected String extractSubject(final Claims claims) { String subject = claims.getSubject(); - if (subjectKey != null) { - // try to get roles from claims, first as Object to avoid having to catch the ExpectedTypeException - Object subjectObject = claims.get(subjectKey, Object.class); - if (subjectObject == null) { - log.warn("Failed to get subject from JWT claims, check if subject_key '{}' is correct.", subjectKey); - return null; + if (subjectKey != null && !subjectKey.isEmpty()) { + // ── 1. Traverse the nested structure ─────────────────────────────────────── + Object node = claims; // start at root + for (String key : subjectKey) { + if (!(node instanceof Map map)) { // unexpected shape + log.warn( + "While following subject_key path {}, expected a JSON object before '{}', but found '{}' ({}).", + subjectKey, + key, + node, + node.getClass() + ); + return null; // Subject cannot be extracted from the configured path + } + node = map.get(key); + if (node == null) { // key missing + log.warn("Failed to find '{}' in JWT claims while following subject_key path {}.", key, subjectKey); + return null; // Subject cannot be extracted from the configured path + } } - // We expect a String. If we find something else, convert to String but issue a warning - if (!(subjectObject instanceof String)) { + // ── 2. Interpret the leaf value ──────────────────────────────────────────── + if (node instanceof String str) { + return str.trim(); + } else { // something odd log.warn( - "Expected type String for roles in the JWT for subject_key {}, but value was '{}' ({}). Will convert this value to String.", + "Expected a String at the end of subject_key path {}, but found '{}' ({}). Converting to String.", subjectKey, - subjectObject, - subjectObject.getClass() + node, + node.getClass() ); + return String.valueOf(node).trim(); } - subject = String.valueOf(subjectObject); + } return subject; } diff --git a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index debe581fa6..e2f43433bf 100644 --- a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -285,6 +285,56 @@ public void testRolesInNestedClaim() { assertThat(creds.getBackendRoles(), Matchers.is(TestJwts.TEST_ROLES)); } + @Test + public void testSubjectInNestedClaim() { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .putList("subject_key", TestJwts.NESTED_MCCOY_SUBJECT) + .put("roles_key", TestJwts.ROLES_CLAIM) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NESTED_SUBJECT_OCT_1), + new HashMap() + ).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), Matchers.is(TestJwts.MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles(), Matchers.is(TestJwts.TEST_ROLES)); + } + + @Test + public void testSubjectAndRolesInNestedClaim() { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .putList("subject_key", TestJwts.NESTED_ROLES_AND_SUBJECT_CLAIM) + .putList("roles_key", TestJwts.NESTED_ROLES_CLAIM) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NESTED_ROLES_AND_SUBJECT_OCT_1), + new HashMap() + ).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), Matchers.is(TestJwts.MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles(), Matchers.is(TestJwts.TEST_ROLES)); + } + @Test public void testExp() { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); diff --git a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java index c120bf7e45..9971ca4bc5 100644 --- a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java +++ b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java @@ -11,6 +11,7 @@ package org.opensearch.security.auth.http.jwt.keybyoidc; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -33,12 +34,14 @@ class TestJwts { static final String ROLES_CLAIM = "roles"; static final List NESTED_ROLES_CLAIM = List.of("attributes", "roles"); + static final List NESTED_ROLES_AND_SUBJECT_CLAIM = List.of("attributes", "sub"); static final Set TEST_ROLES = ImmutableSet.of("role1", "role2"); static final String TEST_ROLES_STRING = String.join(",", TEST_ROLES); static final String TEST_AUDIENCE = "TestAudience"; static final String MCCOY_SUBJECT = "Leonard McCoy"; + static final List NESTED_MCCOY_SUBJECT = List.of("attributes_sub", "sub"); static final String TEST_ISSUER = "TestIssuer"; @@ -46,6 +49,16 @@ class TestJwts { static final JWTClaimsSet MC_COY_2 = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_NESTED_SUBJECT = create( + null, + TEST_AUDIENCE, + TEST_ISSUER, + NESTED_MCCOY_SUBJECT, + MCCOY_SUBJECT, + ROLES_CLAIM, + TEST_ROLES_STRING + ); + static final JWTClaimsSet MC_COY_NESTED_ROLES = create( MCCOY_SUBJECT, TEST_AUDIENCE, @@ -54,6 +67,16 @@ class TestJwts { TEST_ROLES_STRING ); + static final JWTClaimsSet MC_COY_NESTED_ROLES_AND_SUBJECT = create( + null, + TEST_AUDIENCE, + TEST_ISSUER, + NESTED_ROLES_CLAIM, + TEST_ROLES_STRING, + NESTED_ROLES_AND_SUBJECT_CLAIM, + MCCOY_SUBJECT + ); + static final JWTClaimsSet MC_COY_NO_AUDIENCE = create(MCCOY_SUBJECT, null, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); static final JWTClaimsSet MC_COY_NO_ISSUER = create(MCCOY_SUBJECT, TEST_AUDIENCE, null, ROLES_CLAIM, TEST_ROLES_STRING); @@ -71,8 +94,9 @@ class TestJwts { static final String MC_COY_SIGNED_OCT_1 = createSigned(MC_COY, TestJwk.OCT_1); static final String MC_COY_SIGNED_OCT_2 = createSigned(MC_COY_2, TestJwk.OCT_2); - + static final String MC_COY_SIGNED_NESTED_SUBJECT_OCT_1 = createSigned(MC_COY_NESTED_SUBJECT, TestJwk.OCT_1); static final String MC_COY_SIGNED_NESTED_ROLES_OCT_1 = createSigned(MC_COY_NESTED_ROLES, TestJwk.OCT_1); + static final String MC_COY_SIGNED_NESTED_ROLES_AND_SUBJECT_OCT_1 = createSigned(MC_COY_NESTED_ROLES_AND_SUBJECT, TestJwk.OCT_1); static final String MC_COY_SIGNED_NO_AUDIENCE_OCT_1 = createSigned(MC_COY_NO_AUDIENCE, TestJwk.OCT_1); static final String MC_COY_SIGNED_NO_ISSUER_OCT_1 = createSigned(MC_COY_NO_ISSUER, TestJwk.OCT_1); @@ -94,10 +118,13 @@ static class PeculiarEscaping { static final String MC_COY_SIGNED_RSA_1 = createSignedWithPeculiarEscaping(MC_COY, TestJwk.RSA_1); } + @SuppressWarnings("unchecked") static JWTClaimsSet create(String subject, String audience, String issuer, Object... moreClaims) { JWTClaimsSet.Builder claimsBuilder = new JWTClaimsSet.Builder(); - claimsBuilder.subject(subject); + if (subject != null) { + claimsBuilder.subject(String.valueOf(subject)); + } if (audience != null) { claimsBuilder.audience(audience); } @@ -105,38 +132,57 @@ static JWTClaimsSet create(String subject, String audience, String issuer, Objec claimsBuilder.issuer(issuer); } + Map topLevelClaims = new HashMap<>(); + if (moreClaims != null) { for (int i = 0; i < moreClaims.length; i += 2) { Object claimPath = moreClaims[i]; Object claimValue = moreClaims[i + 1]; if (claimPath instanceof List pathParts) { - // Handle nested path specified as List if (!pathParts.isEmpty()) { - Map nestedMap = new HashMap<>(); - Map currentMap = nestedMap; + String topLevelKey = String.valueOf(pathParts.get(0)); + @SuppressWarnings("unchecked") + Map currentMap = topLevelClaims.containsKey(topLevelKey) + ? (Map) topLevelClaims.get(topLevelKey) + : new HashMap<>(); + + if (!topLevelClaims.containsKey(topLevelKey)) { + topLevelClaims.put(topLevelKey, currentMap); + } + + // Navigate to the correct nested level + for (int j = 1; j < pathParts.size() - 1; j++) { + String key = String.valueOf(pathParts.get(j)); + Map nextMap = currentMap.containsKey(key) + ? (Map) currentMap.get(key) + : new HashMap<>(); - // Build nested structure for all but last element - for (int j = 0; j < pathParts.size() - 1; j++) { - Map nextMap = new HashMap<>(); - currentMap.put(String.valueOf(pathParts.get(j)), nextMap); + if (!currentMap.containsKey(key)) { + currentMap.put(key, nextMap); + } currentMap = nextMap; } - // Set the final value at the deepest level - currentMap.put(String.valueOf(pathParts.get(pathParts.size() - 1)), claimValue); - - // Add the top-level claim - claimsBuilder.claim(String.valueOf(pathParts.get(0)), nestedMap.get(pathParts.get(0))); + // Set the final value + String lastKey = String.valueOf(pathParts.get(pathParts.size() - 1)); + if (claimValue instanceof String && lastKey.equals("roles")) { + // Handle roles as array + currentMap.put(lastKey, Arrays.asList(((String) claimValue).split(","))); + } else { + currentMap.put(lastKey, claimValue); + } } } else { // Handle simple claim - claimsBuilder.claim(String.valueOf(claimPath), claimValue); + topLevelClaims.put(String.valueOf(claimPath), claimValue); } } } - // JwtToken result = new JwtToken(claimsBuilder); + // Add all claims to the builder + topLevelClaims.forEach(claimsBuilder::claim); + return claimsBuilder.build(); }