diff --git a/service/common/src/main/java/org/apache/polaris/service/storage/StorageConfiguration.java b/service/common/src/main/java/org/apache/polaris/service/storage/StorageConfiguration.java index 711cf032e9..188a2b85c1 100644 --- a/service/common/src/main/java/org/apache/polaris/service/storage/StorageConfiguration.java +++ b/service/common/src/main/java/org/apache/polaris/service/storage/StorageConfiguration.java @@ -20,6 +20,7 @@ import com.google.auth.oauth2.AccessToken; import com.google.auth.oauth2.GoogleCredentials; +import com.google.common.base.Suppliers; import java.io.IOException; import java.time.Duration; import java.time.Instant; @@ -60,38 +61,40 @@ public interface StorageConfiguration { Optional gcpAccessTokenLifespan(); default Supplier stsClientSupplier() { - return () -> { - StsClientBuilder stsClientBuilder = StsClient.builder(); - if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) { - LoggerFactory.getLogger(StorageConfiguration.class) - .warn("Using hard-coded AWS credentials - this is not recommended for production"); - StaticCredentialsProvider awsCredentialsProvider = - StaticCredentialsProvider.create( - AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get())); - stsClientBuilder.credentialsProvider(awsCredentialsProvider); - } - return stsClientBuilder.build(); - }; + return Suppliers.memoize( + () -> { + StsClientBuilder stsClientBuilder = StsClient.builder(); + if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) { + LoggerFactory.getLogger(StorageConfiguration.class) + .warn("Using hard-coded AWS credentials - this is not recommended for production"); + StaticCredentialsProvider awsCredentialsProvider = + StaticCredentialsProvider.create( + AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get())); + stsClientBuilder.credentialsProvider(awsCredentialsProvider); + } + return stsClientBuilder.build(); + }); } default Supplier gcpCredentialsSupplier() { - return () -> { - if (gcpAccessToken().isEmpty()) { - try { - return GoogleCredentials.getApplicationDefault(); - } catch (IOException e) { - throw new RuntimeException("Failed to get GCP credentials", e); - } - } else { - AccessToken accessToken = - new AccessToken( - gcpAccessToken().get(), - new Date( - Instant.now() - .plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN)) - .toEpochMilli())); - return GoogleCredentials.create(accessToken); - } - }; + return Suppliers.memoize( + () -> { + if (gcpAccessToken().isEmpty()) { + try { + return GoogleCredentials.getApplicationDefault(); + } catch (IOException e) { + throw new RuntimeException("Failed to get GCP credentials", e); + } + } else { + AccessToken accessToken = + new AccessToken( + gcpAccessToken().get(), + new Date( + Instant.now() + .plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN)) + .toEpochMilli())); + return GoogleCredentials.create(accessToken); + } + }); } } diff --git a/service/common/src/test/java/org/apache/polaris/service/storage/StorageConfigurationTest.java b/service/common/src/test/java/org/apache/polaris/service/storage/StorageConfigurationTest.java new file mode 100644 index 0000000000..61eb174250 --- /dev/null +++ b/service/common/src/test/java/org/apache/polaris/service/storage/StorageConfigurationTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.polaris.service.storage; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.GoogleCredentials; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; +import java.util.function.Supplier; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; + +public class StorageConfigurationTest { + + private static final String TEST_ACCESS_KEY = "test-access-key"; + private static final String TEST_GCP_TOKEN = "ya29.test-token"; + private static final String TEST_SECRET_KEY = "test-secret-key"; + private static final Duration TEST_TOKEN_LIFESPAN = Duration.ofMinutes(20); + + private StorageConfiguration configWithAwsCredentialsAndGcpToken() { + return new StorageConfiguration() { + @Override + public Optional awsAccessKey() { + return Optional.of(TEST_ACCESS_KEY); + } + + @Override + public Optional awsSecretKey() { + return Optional.of(TEST_SECRET_KEY); + } + + @Override + public Optional gcpAccessToken() { + return Optional.of(TEST_GCP_TOKEN); + } + + @Override + public Optional gcpAccessTokenLifespan() { + return Optional.of(TEST_TOKEN_LIFESPAN); + } + }; + } + + private StorageConfiguration configWithoutGcpToken() { + return new StorageConfiguration() { + @Override + public Optional awsAccessKey() { + return Optional.empty(); + } + + @Override + public Optional awsSecretKey() { + return Optional.empty(); + } + + @Override + public Optional gcpAccessToken() { + return Optional.empty(); + } + + @Override + public Optional gcpAccessTokenLifespan() { + return Optional.empty(); + } + }; + } + + @Test + public void testSingletonStsClientWithStaticCredentials() { + StsClientBuilder mockBuilder = mock(StsClientBuilder.class); + StsClient mockStsClient = mock(StsClient.class); + ArgumentCaptor credsCaptor = + ArgumentCaptor.forClass(StaticCredentialsProvider.class); + + when(mockBuilder.credentialsProvider(credsCaptor.capture())).thenReturn(mockBuilder); + when(mockBuilder.region(any())).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockStsClient); + + try (MockedStatic staticMock = Mockito.mockStatic(StsClient.class)) { + staticMock.when(StsClient::builder).thenReturn(mockBuilder); + + StorageConfiguration config = configWithAwsCredentialsAndGcpToken(); + Supplier supplier = config.stsClientSupplier(); + StsClient client1 = supplier.get(); + StsClient client2 = supplier.get(); + + assertThat(client1).isSameAs(client2); + assertThat(client1).isNotNull(); + + StaticCredentialsProvider credentialsProvider = credsCaptor.getValue(); + assertThat(credentialsProvider.resolveCredentials().accessKeyId()).isEqualTo(TEST_ACCESS_KEY); + assertThat(credentialsProvider.resolveCredentials().secretAccessKey()) + .isEqualTo(TEST_SECRET_KEY); + } + } + + @Test + public void testCreateGcpCredentialsFromStaticToken() { + Supplier supplier = + configWithAwsCredentialsAndGcpToken().gcpCredentialsSupplier(); + + GoogleCredentials credentials = supplier.get(); + assertThat(credentials).isNotNull(); + + AccessToken accessToken = credentials.getAccessToken(); + assertThat(accessToken).isNotNull(); + assertThat(accessToken.getTokenValue()).isEqualTo(TEST_GCP_TOKEN); + long expectedExpiry = Instant.now().plus(Duration.ofMinutes(20)).toEpochMilli(); + long actualExpiry = accessToken.getExpirationTime().getTime(); + assertThat(actualExpiry).isBetween(expectedExpiry - 500, expectedExpiry + 500); + } + + @Test + public void testGcpCredentialsFromDefault() { + GoogleCredentials mockDefaultCreds = mock(GoogleCredentials.class); + + try (MockedStatic mockedStatic = + Mockito.mockStatic(GoogleCredentials.class)) { + + mockedStatic.when(GoogleCredentials::getApplicationDefault).thenReturn(mockDefaultCreds); + + Supplier supplier = configWithoutGcpToken().gcpCredentialsSupplier(); + GoogleCredentials result = supplier.get(); + + assertThat(result).isSameAs(mockDefaultCreds); + mockedStatic.verify(GoogleCredentials::getApplicationDefault, times(1)); + } + } +}