diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/DecodedToken.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/DecodedToken.java deleted file mode 100644 index a66a607d65..0000000000 --- a/runtime/service/src/main/java/org/apache/polaris/service/auth/DecodedToken.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.polaris.service.auth; - -import java.util.Arrays; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A specialized {@link PolarisCredential} used for internal authentication, when Polaris is the - * identity provider. - */ -public interface DecodedToken extends PolarisCredential { - - String getClientId(); - - String getSub(); - - String getScope(); - - @Override - default String getPrincipalName() { - // Polaris stores the principal ID in the "sub" claim as a string, - // and in the "principal_id" claim as a numeric value. It doesn't store - // the principal name in the token, so we return null here. - return null; - } - - @Override - default Set getPrincipalRoles() { - // Polaris stores the principal roles in the "scope" claim - return Arrays.stream(getScope().split(" ")).collect(Collectors.toSet()); - } -} diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/InternalPolarisToken.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/InternalPolarisToken.java new file mode 100644 index 0000000000..00586db872 --- /dev/null +++ b/runtime/service/src/main/java/org/apache/polaris/service/auth/InternalPolarisToken.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.polaris.service.auth; + +import com.google.common.base.Splitter; +import jakarta.annotation.Nonnull; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.polaris.immutables.PolarisImmutable; +import org.immutables.value.Value; + +/** + * A specialized {@link PolarisCredential} used for internal authentication, when Polaris is the + * identity provider. + * + *

Such credentials are created by the Polaris service itself, from a JWT token previously issued + * by Polaris itself. + * + * @see JWTBroker + */ +@PolarisImmutable +abstract class InternalPolarisToken implements PolarisCredential { + + private static final Splitter SCOPE_SPLITTER = Splitter.on(' ').omitEmptyStrings().trimResults(); + + static InternalPolarisToken of( + String principalName, Long principalId, String clientId, String scope) { + return ImmutableInternalPolarisToken.builder() + .principalName(principalName) + .principalId(principalId) + .clientId(clientId) + .scope(scope) + .build(); + } + + @Nonnull // switch from nullable to non-nullable + @Override + @SuppressWarnings("NullableProblems") + public abstract String getPrincipalName(); + + @Nonnull // switch from nullable to non-nullable + @Override + @SuppressWarnings("NullableProblems") + public abstract Long getPrincipalId(); + + @Value.Lazy + @Override + public Set getPrincipalRoles() { + // Polaris stores roles in the scope claim + return SCOPE_SPLITTER.splitToStream(getScope()).collect(Collectors.toSet()); + } + + abstract String getClientId(); + + abstract String getScope(); +} diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java index 559ba2bc57..a5e845d74b 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java @@ -20,12 +20,10 @@ import com.auth0.jwt.JWT; import com.auth0.jwt.algorithms.Algorithm; -import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.JWTVerifier; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Objects; import java.util.Optional; import java.util.UUID; import org.apache.iceberg.exceptions.NotAuthorizedException; @@ -60,34 +58,22 @@ public abstract class JWTBroker implements TokenBroker { public abstract Algorithm getAlgorithm(); @Override - public DecodedToken verify(String token) { + public PolarisCredential verify(String token) { + return verifyInternal(token); + } + + private InternalPolarisToken verifyInternal(String token) { JWTVerifier verifier = JWT.require(getAlgorithm()).withClaim(CLAIM_KEY_ACTIVE, true).build(); try { DecodedJWT decodedJWT = verifier.verify(token); - return new DecodedToken() { - @Override - public Long getPrincipalId() { - return decodedJWT.getClaim("principalId").asLong(); - } - - @Override - public String getClientId() { - return decodedJWT.getClaim("client_id").asString(); - } - - @Override - public String getSub() { - return decodedJWT.getSubject(); - } - - @Override - public String getScope() { - return decodedJWT.getClaim("scope").asString(); - } - }; - - } catch (JWTVerificationException e) { + return InternalPolarisToken.of( + decodedJWT.getSubject(), + decodedJWT.getClaim(CLAIM_KEY_PRINCIPAL_ID).asLong(), + decodedJWT.getClaim(CLAIM_KEY_CLIENT_ID).asString(), + decodedJWT.getClaim(CLAIM_KEY_SCOPE).asString()); + + } catch (Exception e) { throw (NotAuthorizedException) new NotAuthorizedException("Failed to verify the token").initCause(e); } @@ -110,26 +96,26 @@ public TokenResponse generateFromToken( if (subjectToken == null || subjectToken.isBlank()) { return new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request); } - DecodedToken decodedToken; + InternalPolarisToken decodedToken; try { - decodedToken = verify(subjectToken); + decodedToken = verifyInternal(subjectToken); } catch (NotAuthorizedException e) { LOGGER.error("Failed to verify the token", e.getCause()); return new TokenResponse(Error.invalid_client); } EntityResult principalLookup = metaStoreManager.loadEntity( - polarisCallContext, - 0L, - Objects.requireNonNull(decodedToken.getPrincipalId()), - PolarisEntityType.PRINCIPAL); + polarisCallContext, 0L, decodedToken.getPrincipalId(), PolarisEntityType.PRINCIPAL); if (!principalLookup.isSuccess() || principalLookup.getEntity().getType() != PolarisEntityType.PRINCIPAL) { return new TokenResponse(OAuthTokenErrorResponse.Error.unauthorized_client); } String tokenString = generateTokenString( - decodedToken.getClientId(), decodedToken.getScope(), decodedToken.getPrincipalId()); + decodedToken.getPrincipalName(), + decodedToken.getPrincipalId(), + decodedToken.getClientId(), + decodedToken.getScope()); return new TokenResponse( tokenString, TokenType.ACCESS_TOKEN.getValue(), maxTokenGenerationInSeconds); } @@ -156,16 +142,18 @@ public TokenResponse generateFromClientSecrets( if (principal.isEmpty()) { return new TokenResponse(OAuthTokenErrorResponse.Error.unauthorized_client); } - String tokenString = generateTokenString(clientId, scope, principal.get().getId()); + String tokenString = + generateTokenString(principal.get().getName(), principal.get().getId(), clientId, scope); return new TokenResponse( tokenString, TokenType.ACCESS_TOKEN.getValue(), maxTokenGenerationInSeconds); } - private String generateTokenString(String clientId, String scope, Long principalId) { + private String generateTokenString( + String principalName, long principalId, String clientId, String scope) { Instant now = Instant.now(); return JWT.create() .withIssuer(ISSUER_KEY) - .withSubject(String.valueOf(principalId)) + .withSubject(principalName) .withIssuedAt(now) .withExpiresAt(now.plus(maxTokenGenerationInSeconds, ChronoUnit.SECONDS)) .withJWTId(UUID.randomUUID().toString()) diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java index 5744cef2e9..4a94908abd 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java @@ -64,7 +64,7 @@ public TokenResponse generateFromToken( } @Override - public DecodedToken verify(String token) { + public PolarisCredential verify(String token) { return null; } }; diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/TokenBroker.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/TokenBroker.java index 010490dc0e..e424dff1aa 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/auth/TokenBroker.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/auth/TokenBroker.java @@ -61,7 +61,8 @@ TokenResponse generateFromToken( PolarisCallContext polarisCallContext, TokenType requestedTokenType); - DecodedToken verify(String token); + /** Decodes and verifies the token, then returns the associated {@link PolarisCredential}. */ + PolarisCredential verify(String token); static @Nonnull Optional findPrincipalEntity( PolarisMetaStoreManager metaStoreManager, diff --git a/runtime/service/src/main/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanism.java b/runtime/service/src/main/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanism.java index f0e63efb1a..657c4810a2 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanism.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanism.java @@ -38,7 +38,7 @@ import java.util.Set; import org.apache.polaris.service.auth.AuthenticationRealmConfiguration; import org.apache.polaris.service.auth.AuthenticationType; -import org.apache.polaris.service.auth.DecodedToken; +import org.apache.polaris.service.auth.PolarisCredential; import org.apache.polaris.service.auth.TokenBroker; /** @@ -90,7 +90,7 @@ public Uni authenticate( String credential = authHeader.substring(spaceIdx + 1); - DecodedToken token; + PolarisCredential token; try { token = tokenBroker.verify(credential); } catch (Exception e) { diff --git a/runtime/service/src/test/java/org/apache/polaris/service/auth/DefaultAuthenticatorTest.java b/runtime/service/src/test/java/org/apache/polaris/service/auth/DefaultAuthenticatorTest.java index 85d98ece55..e121fb75bb 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/auth/DefaultAuthenticatorTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/auth/DefaultAuthenticatorTest.java @@ -52,7 +52,7 @@ public void setUp() { @Test public void testFetchPrincipalThrowsServiceExceptionOnMetastoreException() { - DecodedToken token = Mockito.mock(DecodedToken.class); + PolarisCredential token = Mockito.mock(PolarisCredential.class); long principalId = 100L; when(token.getPrincipalId()).thenReturn(principalId); when(metaStoreManager.loadEntity( @@ -69,10 +69,9 @@ public void testFetchPrincipalThrowsServiceExceptionOnMetastoreException() { @Test public void testFetchPrincipalThrowsNotAuthorizedWhenNotFound() { - DecodedToken token = Mockito.mock(DecodedToken.class); + PolarisCredential token = Mockito.mock(PolarisCredential.class); long principalId = 100L; when(token.getPrincipalId()).thenReturn(principalId); - when(token.getClientId()).thenReturn("abc"); when(metaStoreManager.loadEntity( authenticator.callContext.getPolarisCallContext(), 0L, diff --git a/runtime/service/src/test/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanismTest.java b/runtime/service/src/test/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanismTest.java index 1a87b38539..a7c3308bdd 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanismTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/auth/internal/InternalAuthenticationMechanismTest.java @@ -34,7 +34,7 @@ import org.apache.iceberg.exceptions.NotAuthorizedException; import org.apache.polaris.service.auth.AuthenticationRealmConfiguration; import org.apache.polaris.service.auth.AuthenticationType; -import org.apache.polaris.service.auth.DecodedToken; +import org.apache.polaris.service.auth.PolarisCredential; import org.apache.polaris.service.auth.TokenBroker; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -156,7 +156,7 @@ public void testAuthenticateWithValidToken() { when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer validToken"); - DecodedToken decodedToken = mock(DecodedToken.class); + PolarisCredential decodedToken = mock(PolarisCredential.class); when(tokenBroker.verify("validToken")).thenReturn(decodedToken); SecurityIdentity securityIdentity = mock(SecurityIdentity.class);