Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,15 @@ public abstract class JWTBroker implements TokenBroker {
private static final String CLAIM_KEY_SCOPE = "scope";

private final PolarisMetaStoreManager metaStoreManager;
private final PolarisCallContext polarisCallContext;
private final int maxTokenGenerationInSeconds;

JWTBroker(PolarisMetaStoreManager metaStoreManager, int maxTokenGenerationInSeconds) {
JWTBroker(
PolarisMetaStoreManager metaStoreManager,
PolarisCallContext polarisCallContext,
int maxTokenGenerationInSeconds) {
this.metaStoreManager = metaStoreManager;
this.polarisCallContext = polarisCallContext;
this.maxTokenGenerationInSeconds = maxTokenGenerationInSeconds;
}

Expand Down Expand Up @@ -86,7 +91,6 @@ public TokenResponse generateFromToken(
String subjectToken,
String grantType,
String scope,
PolarisCallContext polarisCallContext,
TokenType requestedTokenType) {
if (requestedTokenType != null && !TokenType.ACCESS_TOKEN.equals(requestedTokenType)) {
return TokenResponse.of(OAuthError.invalid_request);
Expand Down Expand Up @@ -125,7 +129,6 @@ public TokenResponse generateFromClientSecrets(
String clientSecret,
String grantType,
String scope,
PolarisCallContext polarisCallContext,
TokenType requestedTokenType) {
// Initial sanity checks
TokenRequestValidator validator = new TokenRequestValidator();
Expand All @@ -135,8 +138,7 @@ public TokenResponse generateFromClientSecrets(
return TokenResponse.of(initialValidationResponse.get());
}

Optional<PrincipalEntity> principal =
findPrincipalEntity(clientId, clientSecret, polarisCallContext);
Optional<PrincipalEntity> principal = findPrincipalEntity(clientId, clientSecret);
if (principal.isEmpty()) {
return TokenResponse.of(OAuthError.unauthorized_client);
}
Expand Down Expand Up @@ -176,8 +178,7 @@ private String scopes(String scope) {
return scope == null || scope.isBlank() ? DefaultAuthenticator.PRINCIPAL_ROLE_ALL : scope;
}

private Optional<PrincipalEntity> findPrincipalEntity(
String clientId, String clientSecret, PolarisCallContext polarisCallContext) {
private Optional<PrincipalEntity> findPrincipalEntity(String clientId, String clientSecret) {
// Validate the principal is present and secrets match
PrincipalSecretsResult principalSecrets =
metaStoreManager.loadPrincipalSecrets(polarisCallContext, clientId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.auth0.jwt.algorithms.Algorithm;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;

/** Generates a JWT using a Public/Private RSA Key */
Expand All @@ -30,9 +31,10 @@ public class RSAKeyPairJWTBroker extends JWTBroker {

RSAKeyPairJWTBroker(
PolarisMetaStoreManager metaStoreManager,
PolarisCallContext polarisCallContext,
int maxTokenGenerationInSeconds,
KeyProvider keyProvider) {
super(metaStoreManager, maxTokenGenerationInSeconds);
super(metaStoreManager, polarisCallContext, maxTokenGenerationInSeconds);
this.keyProvider = keyProvider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.context.RealmContext;
import org.apache.polaris.core.persistence.MetaStoreManagerFactory;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;
import org.apache.polaris.service.auth.AuthenticationConfiguration;
import org.apache.polaris.service.auth.AuthenticationRealmConfiguration;
Expand All @@ -36,38 +36,32 @@
@Identifier("rsa-key-pair")
public class RSAKeyPairJWTBrokerFactory implements TokenBrokerFactory {

private final MetaStoreManagerFactory metaStoreManagerFactory;
private final AuthenticationConfiguration authenticationConfiguration;

private final ConcurrentMap<String, RSAKeyPairJWTBroker> tokenBrokers = new ConcurrentHashMap<>();
private final ConcurrentMap<String, KeyProvider> keyProviders = new ConcurrentHashMap<>();

@Inject
public RSAKeyPairJWTBrokerFactory(
MetaStoreManagerFactory metaStoreManagerFactory,
AuthenticationConfiguration authenticationConfiguration) {
this.metaStoreManagerFactory = metaStoreManagerFactory;
public RSAKeyPairJWTBrokerFactory(AuthenticationConfiguration authenticationConfiguration) {
this.authenticationConfiguration = authenticationConfiguration;
}

@Override
public TokenBroker apply(RealmContext realmContext) {
return tokenBrokers.computeIfAbsent(
realmContext.getRealmIdentifier(), k -> createTokenBroker(realmContext));
}

private RSAKeyPairJWTBroker createTokenBroker(RealmContext realmContext) {
public TokenBroker create(
PolarisMetaStoreManager metaStoreManager, PolarisCallContext polarisCallContext) {
RealmContext realmContext = polarisCallContext.getRealmContext();
AuthenticationRealmConfiguration config = authenticationConfiguration.forRealm(realmContext);
Duration maxTokenGeneration = config.tokenBroker().maxTokenGeneration();
KeyProvider keyProvider =
config
.tokenBroker()
.rsaKeyPair()
.map(this::fileSystemKeyPair)
.orElseGet(this::generateEphemeralKeyPair);
PolarisMetaStoreManager metaStoreManager =
metaStoreManagerFactory.getOrCreateMetaStoreManager(realmContext);
keyProviders.computeIfAbsent(
realmContext.getRealmIdentifier(),
k ->
config
.tokenBroker()
.rsaKeyPair()
.map(this::fileSystemKeyPair)
.orElseGet(this::generateEphemeralKeyPair));
return new RSAKeyPairJWTBroker(
metaStoreManager, (int) maxTokenGeneration.toSeconds(), keyProvider);
metaStoreManager, polarisCallContext, (int) maxTokenGeneration.toSeconds(), keyProvider);
}

private KeyProvider fileSystemKeyPair(RSAKeyPairConfiguration config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.auth0.jwt.algorithms.Algorithm;
import java.util.function.Supplier;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;

/** Generates a JWT using a Symmetric Key. */
Expand All @@ -28,9 +29,10 @@ public class SymmetricKeyJWTBroker extends JWTBroker {

public SymmetricKeyJWTBroker(
PolarisMetaStoreManager metaStoreManager,
PolarisCallContext polarisCallContext,
int maxTokenGenerationInSeconds,
Supplier<String> secretSupplier) {
super(metaStoreManager, maxTokenGenerationInSeconds);
super(metaStoreManager, polarisCallContext, maxTokenGenerationInSeconds);
this.secretSupplier = secretSupplier;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.context.RealmContext;
import org.apache.polaris.core.persistence.MetaStoreManagerFactory;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;
import org.apache.polaris.service.auth.AuthenticationConfiguration;
import org.apache.polaris.service.auth.AuthenticationRealmConfiguration;
import org.apache.polaris.service.auth.AuthenticationRealmConfiguration.TokenBrokerConfiguration.SymmetricKeyConfiguration;
Expand All @@ -40,51 +42,46 @@
@Identifier("symmetric-key")
public class SymmetricKeyJWTBrokerFactory implements TokenBrokerFactory {

private final MetaStoreManagerFactory metaStoreManagerFactory;
private final AuthenticationConfiguration authenticationConfiguration;

private final ConcurrentMap<String, SymmetricKeyJWTBroker> tokenBrokers =
new ConcurrentHashMap<>();
private final ConcurrentMap<String, Supplier<String>> secretSuppliers = new ConcurrentHashMap<>();

@Inject
public SymmetricKeyJWTBrokerFactory(
MetaStoreManagerFactory metaStoreManagerFactory,
AuthenticationConfiguration authenticationConfiguration) {
this.metaStoreManagerFactory = metaStoreManagerFactory;
public SymmetricKeyJWTBrokerFactory(AuthenticationConfiguration authenticationConfiguration) {
this.authenticationConfiguration = authenticationConfiguration;
}

@Override
public TokenBroker apply(RealmContext realmContext) {
return tokenBrokers.computeIfAbsent(
realmContext.getRealmIdentifier(), k -> createTokenBroker(realmContext));
}

private SymmetricKeyJWTBroker createTokenBroker(RealmContext realmContext) {
public TokenBroker create(
PolarisMetaStoreManager metaStoreManager, PolarisCallContext polarisCallContext) {
RealmContext realmContext = polarisCallContext.getRealmContext();
AuthenticationRealmConfiguration config = authenticationConfiguration.forRealm(realmContext);
Duration maxTokenGeneration = config.tokenBroker().maxTokenGeneration();
SymmetricKeyConfiguration symmetricKeyConfiguration =
config
.tokenBroker()
.symmetricKey()
.orElseThrow(() -> new IllegalStateException("Symmetric key configuration is missing"));
String secret = symmetricKeyConfiguration.secret().orElse(null);
Path file = symmetricKeyConfiguration.file().orElse(null);
checkState(secret != null || file != null, "Either file or secret must be set");
Supplier<String> secretSupplier = secret != null ? () -> secret : readSecretFromDisk(file);
Supplier<String> secretSupplier =
secretSuppliers.computeIfAbsent(
realmContext.getRealmIdentifier(),
k -> {
SymmetricKeyConfiguration symmetricKeyConfiguration =
config
.tokenBroker()
.symmetricKey()
.orElseThrow(
() ->
new IllegalStateException("Symmetric key configuration is missing"));
String secret = symmetricKeyConfiguration.secret().orElse(null);
Path file = symmetricKeyConfiguration.file().orElse(null);
checkState(secret != null || file != null, "Either file or secret must be set");
return () -> Objects.requireNonNullElseGet(secret, () -> readSecretFromDisk(file));
});
return new SymmetricKeyJWTBroker(
metaStoreManagerFactory.getOrCreateMetaStoreManager(realmContext),
(int) maxTokenGeneration.toSeconds(),
secretSupplier);
metaStoreManager, polarisCallContext, (int) maxTokenGeneration.toSeconds(), secretSupplier);
}

private static Supplier<String> readSecretFromDisk(Path file) {
return () -> {
try {
return Files.readString(file);
} catch (IOException e) {
throw new RuntimeException("Failed to read secret from file: " + file, e);
}
};
private static String readSecretFromDisk(Path file) {
try {
return Files.readString(file);
} catch (IOException e) {
throw new RuntimeException("Failed to read secret from file: " + file, e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.polaris.service.auth.internal.broker;

import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.service.auth.PolarisCredential;
import org.apache.polaris.service.types.TokenType;

Expand All @@ -39,7 +38,6 @@ TokenResponse generateFromClientSecrets(
final String clientSecret,
final String grantType,
final String scope,
PolarisCallContext polarisCallContext,
TokenType requestedTokenType);

/**
Expand All @@ -52,7 +50,6 @@ TokenResponse generateFromToken(
String subjectToken,
final String grantType,
final String scope,
PolarisCallContext polarisCallContext,
TokenType requestedTokenType);

/** Decodes and verifies the token, then returns the associated {@link PolarisCredential}. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
*/
package org.apache.polaris.service.auth.internal.broker;

import java.util.function.Function;
import org.apache.polaris.core.context.RealmContext;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;

/**
* Factory that creates a {@link TokenBroker} for generating and parsing. The {@link TokenBroker} is
* created based on the realm context.
*/
public interface TokenBrokerFactory extends Function<RealmContext, TokenBroker> {}
public interface TokenBrokerFactory {
TokenBroker create(
PolarisMetaStoreManager metaStoreManager, PolarisCallContext polarisCallContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import jakarta.ws.rs.core.SecurityContext;
import java.util.Base64;
import org.apache.iceberg.rest.responses.OAuthTokenResponse;
import org.apache.polaris.core.context.CallContext;
import org.apache.polaris.core.context.RealmContext;
import org.apache.polaris.service.auth.internal.broker.TokenBroker;
import org.apache.polaris.service.auth.internal.broker.TokenResponse;
Expand All @@ -49,12 +48,10 @@ public class DefaultOAuth2ApiService implements IcebergRestOAuth2ApiService {
private static final String BEARER = "bearer";

private final TokenBroker tokenBroker;
private final CallContext callContext;

@Inject
public DefaultOAuth2ApiService(TokenBroker tokenBroker, CallContext callContext) {
public DefaultOAuth2ApiService(TokenBroker tokenBroker) {
this.tokenBroker = tokenBroker;
this.callContext = callContext;
}

@Override
Expand Down Expand Up @@ -104,21 +101,11 @@ public Response getToken(
if (clientSecret != null) {
tokenResponse =
tokenBroker.generateFromClientSecrets(
clientId,
clientSecret,
grantType,
scope,
callContext.getPolarisCallContext(),
requestedTokenType);
clientId, clientSecret, grantType, scope, requestedTokenType);
} else if (subjectToken != null) {
tokenResponse =
tokenBroker.generateFromToken(
subjectTokenType,
subjectToken,
grantType,
scope,
callContext.getPolarisCallContext(),
requestedTokenType);
subjectTokenType, subjectToken, grantType, scope, requestedTokenType);
} else {
return OAuthUtils.getResponseFromError(OAuthError.invalid_request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,14 @@ public IcebergRestOAuth2ApiService icebergRestOAuth2ApiService(
@RequestScoped
public TokenBroker tokenBroker(
AuthenticationRealmConfiguration config,
RealmContext realmContext,
@Any Instance<TokenBrokerFactory> tokenBrokerFactories) {
@Any Instance<TokenBrokerFactory> tokenBrokerFactories,
PolarisMetaStoreManager polarisMetaStoreManager,
CallContext callContext) {
String type =
config.type() == AuthenticationType.EXTERNAL ? "none" : config.tokenBroker().type();
TokenBrokerFactory tokenBrokerFactory =
tokenBrokerFactories.select(Identifier.Literal.of(type)).get();
return tokenBrokerFactory.apply(realmContext);
return tokenBrokerFactory.create(polarisMetaStoreManager, callContext.getPolarisCallContext());
}

// other beans
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ public void testJWTSymmetricKeyGenerator() {
new PrincipalEntity.Builder().setId(principalId).setName("principal").build();
Mockito.when(metastoreManager.findPrincipalById(polarisCallContext, principalId))
.thenReturn(Optional.of(principal));
TokenBroker generator = new SymmetricKeyJWTBroker(metastoreManager, 666, () -> "polaris");
TokenBroker generator =
new SymmetricKeyJWTBroker(metastoreManager, polarisCallContext, 666, () -> "polaris");
TokenResponse token =
generator.generateFromClientSecrets(
clientId,
mainSecret,
TokenRequestValidator.CLIENT_CREDENTIALS,
"PRINCIPAL_ROLE:TEST",
polarisCallContext,
TokenType.ACCESS_TOKEN);
assertThat(token).isNotNull();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ public void testSuccessfulTokenGeneration() throws Exception {
Mockito.when(metastoreManager.findPrincipalById(polarisCallContext, principalId))
.thenReturn(Optional.of(principal));
KeyProvider provider = new LocalRSAKeyProvider(keyPair);
TokenBroker tokenBroker = new RSAKeyPairJWTBroker(metastoreManager, 420, provider);
TokenBroker tokenBroker =
new RSAKeyPairJWTBroker(metastoreManager, polarisCallContext, 420, provider);
TokenResponse token =
tokenBroker.generateFromClientSecrets(
clientId,
mainSecret,
TokenRequestValidator.CLIENT_CREDENTIALS,
scope,
polarisCallContext,
TokenType.ACCESS_TOKEN);
assertThat(token).isNotNull();
assertThat(token.getExpiresIn()).isEqualTo(420);
Expand Down
Loading