diff --git a/azure/src/main/java/org/apache/iceberg/azure/AzureProperties.java b/azure/src/main/java/org/apache/iceberg/azure/AzureProperties.java index 8313000ab35e..f03c28ab6a75 100644 --- a/azure/src/main/java/org/apache/iceberg/azure/AzureProperties.java +++ b/azure/src/main/java/org/apache/iceberg/azure/AzureProperties.java @@ -25,23 +25,42 @@ import java.util.Collections; import java.util.Map; import java.util.Optional; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.azure.adlsv2.VendedAdlsCredentialProvider; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Strings; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.RESTUtil; import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SerializableMap; public class AzureProperties implements Serializable { public static final String ADLS_SAS_TOKEN_PREFIX = "adls.sas-token."; + public static final String ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX = "adls.sas-token-expires-at-ms."; public static final String ADLS_CONNECTION_STRING_PREFIX = "adls.connection-string."; public static final String ADLS_READ_BLOCK_SIZE = "adls.read.block-size-bytes"; public static final String ADLS_WRITE_BLOCK_SIZE = "adls.write.block-size-bytes"; public static final String ADLS_SHARED_KEY_ACCOUNT_NAME = "adls.auth.shared-key.account.name"; public static final String ADLS_SHARED_KEY_ACCOUNT_KEY = "adls.auth.shared-key.account.key"; + /** + * When set, the {@link VendedAdlsCredentialProvider} will be used to fetch and refresh vended + * credentials from this endpoint. + */ + public static final String ADLS_REFRESH_CREDENTIALS_ENDPOINT = + "adls.refresh-credentials-endpoint"; + + /** Controls whether vended credentials should be refreshed or not. Defaults to true. */ + public static final String ADLS_REFRESH_CREDENTIALS_ENABLED = "adls.refresh-credentials-enabled"; + private Map adlsSasTokens = Collections.emptyMap(); private Map adlsConnectionStrings = Collections.emptyMap(); private Map.Entry namedKeyCreds; private Integer adlsReadBlockSize; private Long adlsWriteBlockSize; + private String adlsRefreshCredentialsEndpoint; + private boolean adlsRefreshCredentialsEnabled; + private Map allProperties; public AzureProperties() {} @@ -67,6 +86,13 @@ public AzureProperties(Map properties) { if (properties.containsKey(ADLS_WRITE_BLOCK_SIZE)) { this.adlsWriteBlockSize = Long.parseLong(properties.get(ADLS_WRITE_BLOCK_SIZE)); } + this.adlsRefreshCredentialsEndpoint = + RESTUtil.resolveEndpoint( + properties.get(CatalogProperties.URI), + properties.get(ADLS_REFRESH_CREDENTIALS_ENDPOINT)); + this.adlsRefreshCredentialsEnabled = + PropertyUtil.propertyAsBoolean(properties, ADLS_REFRESH_CREDENTIALS_ENABLED, true); + this.allProperties = SerializableMap.copyOf(properties); } public Optional adlsReadBlockSize() { @@ -77,6 +103,17 @@ public Optional adlsWriteBlockSize() { return Optional.ofNullable(adlsWriteBlockSize); } + public Optional vendedAdlsCredentialProvider() { + if (adlsRefreshCredentialsEnabled && !Strings.isNullOrEmpty(adlsRefreshCredentialsEndpoint)) { + Map credentialProviderProperties = Maps.newHashMap(allProperties); + credentialProviderProperties.put( + VendedAdlsCredentialProvider.URI, adlsRefreshCredentialsEndpoint); + return Optional.of(new VendedAdlsCredentialProvider(credentialProviderProperties)); + } else { + return Optional.empty(); + } + } + /** * Applies configuration to the {@link DataLakeFileSystemClientBuilder} to provide the endpoint * and credentials required to create an instance of the client. @@ -87,14 +124,16 @@ public Optional adlsWriteBlockSize() { * @param builder the builder instance */ public void applyClientConfiguration(String account, DataLakeFileSystemClientBuilder builder) { - String sasToken = adlsSasTokens.get(account); - if (sasToken != null && !sasToken.isEmpty()) { - builder.sasToken(sasToken); - } else if (namedKeyCreds != null) { - builder.credential( - new StorageSharedKeyCredential(namedKeyCreds.getKey(), namedKeyCreds.getValue())); - } else { - builder.credential(new DefaultAzureCredentialBuilder().build()); + if (!adlsRefreshCredentialsEnabled || Strings.isNullOrEmpty(adlsRefreshCredentialsEndpoint)) { + String sasToken = adlsSasTokens.get(account); + if (sasToken != null && !sasToken.isEmpty()) { + builder.sasToken(sasToken); + } else if (namedKeyCreds != null) { + builder.credential( + new StorageSharedKeyCredential(namedKeyCreds.getKey(), namedKeyCreds.getValue())); + } else { + builder.credential(new DefaultAzureCredentialBuilder().build()); + } } // apply connection string last so its parameters take precedence, e.g. SAS token diff --git a/azure/src/main/java/org/apache/iceberg/azure/adlsv2/ADLSFileIO.java b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/ADLSFileIO.java index e1bf21f69dc8..2e7f77aca90e 100644 --- a/azure/src/main/java/org/apache/iceberg/azure/adlsv2/ADLSFileIO.java +++ b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/ADLSFileIO.java @@ -27,6 +27,7 @@ import com.azure.storage.file.datalake.models.ListPathsOptions; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.apache.iceberg.azure.AzureProperties; import org.apache.iceberg.common.DynConstructors; @@ -55,6 +56,7 @@ public class ADLSFileIO implements DelegateFileIO { private AzureProperties azureProperties; private MetricsContext metrics = MetricsContext.nullMetrics(); private SerializableMap properties; + private VendedAdlsCredentialProvider vendedAdlsCredentialProvider; /** * No-arg constructor to load the FileIO dynamically. @@ -111,6 +113,9 @@ DataLakeFileSystemClient client(ADLSLocation location) { new DataLakeFileSystemClientBuilder().httpClient(HTTP); location.container().ifPresent(clientBuilder::fileSystemName); + Optional.ofNullable(vendedAdlsCredentialProvider) + .map(p -> new VendedAzureSasCredentialPolicy(location.host(), p)) + .ifPresent(clientBuilder::addPolicy); azureProperties.applyClientConfiguration(location.host(), clientBuilder); return clientBuilder.buildClient(); @@ -126,6 +131,9 @@ public void initialize(Map props) { this.properties = SerializableMap.copyOf(props); this.azureProperties = new AzureProperties(properties); initMetrics(properties); + this.azureProperties + .vendedAdlsCredentialProvider() + .ifPresent((provider -> this.vendedAdlsCredentialProvider = provider)); } @SuppressWarnings("CatchBlockLogException") @@ -212,4 +220,13 @@ public void deletePrefix(String prefix) { } } } + + @Override + public void close() { + if (vendedAdlsCredentialProvider != null) { + vendedAdlsCredentialProvider.close(); + } + + DelegateFileIO.super.close(); + } } diff --git a/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProvider.java b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProvider.java new file mode 100644 index 000000000000..3a03a5824c58 --- /dev/null +++ b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProvider.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.azure.adlsv2; + +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.SimpleTokenCache; +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.azure.AzureProperties; +import org.apache.iceberg.io.CloseableGroup; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.ErrorHandlers; +import org.apache.iceberg.rest.HTTPClient; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.auth.AuthManager; +import org.apache.iceberg.rest.auth.AuthManagers; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.credentials.Credential; +import org.apache.iceberg.rest.responses.LoadCredentialsResponse; +import org.apache.iceberg.util.SerializableMap; +import reactor.core.publisher.Mono; + +public class VendedAdlsCredentialProvider implements Serializable, AutoCloseable { + + public static final String URI = "credentials.uri"; + + private final SerializableMap properties; + private final String credentialsEndpoint; + private final String catalogEndpoint; + private transient volatile Map sasCredentialByAccount; + private transient volatile HTTPClient client; + private transient AuthManager authManager; + private transient AuthSession authSession; + + public VendedAdlsCredentialProvider(Map properties) { + Preconditions.checkArgument(null != properties, "Invalid properties: null"); + Preconditions.checkArgument(null != properties.get(URI), "Invalid credentials endpoint: null"); + Preconditions.checkArgument( + null != properties.get(CatalogProperties.URI), "Invalid catalog endpoint: null"); + this.properties = SerializableMap.copyOf(properties); + this.credentialsEndpoint = properties.get(URI); + this.catalogEndpoint = properties.get(CatalogProperties.URI); + } + + String credentialForAccount(String storageAccount) { + return sasCredentialByAccount() + .computeIfAbsent( + storageAccount, + ignored -> + new SimpleTokenCache( + () -> Mono.fromSupplier(() -> sasTokenForAccount(storageAccount)))) + .getToken() + .map(AccessToken::getToken) + .block(); + } + + private AccessToken sasTokenForAccount(String storageAccount) { + LoadCredentialsResponse response = fetchCredentials(); + List adlsCredentials = + response.credentials().stream() + .filter(c -> c.prefix().contains(storageAccount)) + .collect(Collectors.toList()); + Preconditions.checkState( + !adlsCredentials.isEmpty(), + String.format("Invalid ADLS Credentials for storage-account %s: empty", storageAccount)); + Preconditions.checkState( + adlsCredentials.size() == 1, + "Invalid ADLS Credentials: only one ADLS credential should exist per storage-account"); + + Credential adlsCredential = adlsCredentials.get(0); + checkCredential(adlsCredential, AzureProperties.ADLS_SAS_TOKEN_PREFIX + storageAccount); + checkCredential( + adlsCredential, AzureProperties.ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + storageAccount); + + String sasToken = + adlsCredential.config().get(AzureProperties.ADLS_SAS_TOKEN_PREFIX + storageAccount); + Instant tokenExpiresAt = + Instant.ofEpochMilli( + Long.parseLong( + adlsCredential + .config() + .get(AzureProperties.ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + storageAccount))); + + return new AccessToken(sasToken, tokenExpiresAt.atOffset(ZoneOffset.UTC)); + } + + private Map sasCredentialByAccount() { + if (this.sasCredentialByAccount == null) { + synchronized (this) { + if (this.sasCredentialByAccount == null) { + this.sasCredentialByAccount = Maps.newHashMap(); + } + } + } + return this.sasCredentialByAccount; + } + + private RESTClient httpClient() { + if (null == client) { + synchronized (this) { + if (null == client) { + authManager = AuthManagers.loadAuthManager("adls-credentials-refresh", properties); + HTTPClient httpClient = HTTPClient.builder(properties).uri(catalogEndpoint).build(); + authSession = authManager.catalogSession(httpClient, properties); + client = httpClient.withAuthSession(authSession); + } + } + } + + return client; + } + + private LoadCredentialsResponse fetchCredentials() { + return httpClient() + .get( + credentialsEndpoint, + null, + LoadCredentialsResponse.class, + Map.of(), + ErrorHandlers.defaultErrorHandler()); + } + + private void checkCredential(Credential credential, String property) { + Preconditions.checkState( + credential.config().containsKey(property), + "Invalid ADLS Credentials: %s not set", + property); + } + + @Override + public void close() { + CloseableGroup closeableGroup = new CloseableGroup(); + closeableGroup.addCloseable(authSession); + closeableGroup.addCloseable(authManager); + closeableGroup.addCloseable(client); + closeableGroup.setSuppressCloseFailure(true); + try { + closeableGroup.close(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close the VendedAdlsCredentialProvider", e); + } + } +} diff --git a/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicy.java b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicy.java new file mode 100644 index 000000000000..b5ac3d9c405d --- /dev/null +++ b/azure/src/main/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicy.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.azure.adlsv2; + +import com.azure.core.credential.AzureSasCredential; +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpPipelineNextSyncPolicy; +import com.azure.core.http.HttpResponse; +import com.azure.core.http.policy.AzureSasCredentialPolicy; +import com.azure.core.http.policy.HttpPipelinePolicy; +import reactor.core.publisher.Mono; + +class VendedAzureSasCredentialPolicy implements HttpPipelinePolicy { + private final String account; + private final VendedAdlsCredentialProvider vendedAdlsCredentialProvider; + private AzureSasCredential azureSasCredential; + private AzureSasCredentialPolicy azureSasCredentialPolicy; + + VendedAzureSasCredentialPolicy( + String account, VendedAdlsCredentialProvider vendedAdlsCredentialProvider) { + this.account = account; + this.vendedAdlsCredentialProvider = vendedAdlsCredentialProvider; + } + + @Override + public Mono process( + HttpPipelineCallContext httpPipelineCallContext, + HttpPipelineNextPolicy httpPipelineNextPolicy) { + maybeUpdateCredential(); + return azureSasCredentialPolicy.process(httpPipelineCallContext, httpPipelineNextPolicy); + } + + @Override + public HttpResponse processSync( + HttpPipelineCallContext context, HttpPipelineNextSyncPolicy next) { + maybeUpdateCredential(); + return azureSasCredentialPolicy.processSync(context, next); + } + + private void maybeUpdateCredential() { + String sasToken = vendedAdlsCredentialProvider.credentialForAccount(account); + if (azureSasCredential == null) { + this.azureSasCredential = new AzureSasCredential(sasToken); + this.azureSasCredentialPolicy = new AzureSasCredentialPolicy(azureSasCredential, false); + } else { + azureSasCredential.update(sasToken); + } + } +} diff --git a/azure/src/test/java/org/apache/iceberg/azure/AzurePropertiesTest.java b/azure/src/test/java/org/apache/iceberg/azure/AzurePropertiesTest.java index 6b8287c44e58..153c088c4f84 100644 --- a/azure/src/test/java/org/apache/iceberg/azure/AzurePropertiesTest.java +++ b/azure/src/test/java/org/apache/iceberg/azure/AzurePropertiesTest.java @@ -20,6 +20,8 @@ import static org.apache.iceberg.azure.AzureProperties.ADLS_CONNECTION_STRING_PREFIX; import static org.apache.iceberg.azure.AzureProperties.ADLS_READ_BLOCK_SIZE; +import static org.apache.iceberg.azure.AzureProperties.ADLS_REFRESH_CREDENTIALS_ENABLED; +import static org.apache.iceberg.azure.AzureProperties.ADLS_REFRESH_CREDENTIALS_ENDPOINT; import static org.apache.iceberg.azure.AzureProperties.ADLS_SAS_TOKEN_PREFIX; import static org.apache.iceberg.azure.AzureProperties.ADLS_SHARED_KEY_ACCOUNT_KEY; import static org.apache.iceberg.azure.AzureProperties.ADLS_SHARED_KEY_ACCOUNT_NAME; @@ -32,10 +34,15 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.azure.core.credential.AzureSasCredential; import com.azure.core.credential.TokenCredential; +import com.azure.identity.DefaultAzureCredential; import com.azure.storage.common.StorageSharedKeyCredential; import com.azure.storage.file.datalake.DataLakeFileSystemClientBuilder; +import java.util.Optional; +import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.azure.adlsv2.VendedAdlsCredentialProvider; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; @@ -71,6 +78,45 @@ public void testWithSasToken() { verify(clientBuilder, never()).credential(any(StorageSharedKeyCredential.class)); } + @Test + public void testWithRefreshCredentialsEndpoint() { + AzureProperties props = + new AzureProperties( + ImmutableMap.of( + ADLS_REFRESH_CREDENTIALS_ENDPOINT, + "endpoint", + CatalogProperties.URI, + "catalog-endpoint")); + + DataLakeFileSystemClientBuilder clientBuilder = mock(DataLakeFileSystemClientBuilder.class); + props.applyClientConfiguration("account1", clientBuilder); + Optional vendedAdlsCredentialProvider = + props.vendedAdlsCredentialProvider(); + + verify(clientBuilder, never()).credential(any(AzureSasCredential.class)); + verify(clientBuilder, never()).sasToken(any()); + verify(clientBuilder, never()).credential(any(StorageSharedKeyCredential.class)); + assertThat(vendedAdlsCredentialProvider).isPresent(); + } + + @Test + public void testWithRefreshCredentialsEndpointDisabled() { + AzureProperties props = + new AzureProperties( + ImmutableMap.of( + ADLS_REFRESH_CREDENTIALS_ENDPOINT, + "endpoint", + ADLS_REFRESH_CREDENTIALS_ENABLED, + "false")); + + DataLakeFileSystemClientBuilder clientBuilder = mock(DataLakeFileSystemClientBuilder.class); + props.applyClientConfiguration("account1", clientBuilder); + Optional vendedAdlsCredentialProvider = + props.vendedAdlsCredentialProvider(); + verify(clientBuilder).credential(any(DefaultAzureCredential.class)); + assertThat(vendedAdlsCredentialProvider).isEmpty(); + } + @Test public void testNoMatchingSasToken() { AzureProperties props = diff --git a/azure/src/test/java/org/apache/iceberg/azure/adlsv2/BaseVendedCredentialsTest.java b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/BaseVendedCredentialsTest.java new file mode 100644 index 000000000000..3fe969d00d9b --- /dev/null +++ b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/BaseVendedCredentialsTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.azure.adlsv2; + +import static org.mockserver.integration.ClientAndServer.startClientAndServer; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.mockserver.integration.ClientAndServer; + +public class BaseVendedCredentialsTest { + protected static String baseUri; + protected static ClientAndServer mockServer; + + @BeforeAll + public static void beforeAll() { + // Allocate port dynamically as there could be parallel test executions. + mockServer = startClientAndServer(0); + int mockServerPort = mockServer.getPort(); + baseUri = String.format("http://127.0.0.1:%d", mockServerPort); + } + + @AfterAll + public static void stopServer() { + mockServer.stop(); + } + + @BeforeEach + public void before() { + mockServer.reset(); + } +} diff --git a/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProviderTest.java b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProviderTest.java new file mode 100644 index 000000000000..f17b1ea5a685 --- /dev/null +++ b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAdlsCredentialProviderTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.azure.adlsv2; + +import static org.apache.iceberg.azure.AzureProperties.ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX; +import static org.apache.iceberg.azure.AzureProperties.ADLS_SAS_TOKEN_PREFIX; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +import com.azure.core.http.HttpMethod; +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.exceptions.RESTException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.credentials.Credential; +import org.apache.iceberg.rest.credentials.ImmutableCredential; +import org.apache.iceberg.rest.responses.ImmutableLoadCredentialsResponse; +import org.apache.iceberg.rest.responses.LoadCredentialsResponse; +import org.apache.iceberg.rest.responses.LoadCredentialsResponseParser; +import org.junit.jupiter.api.Test; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; +import org.mockserver.verify.VerificationTimes; + +public class VendedAdlsCredentialProviderTest extends BaseVendedCredentialsTest { + private static final String CREDENTIALS_URI = String.format("%s%s", baseUri, "/v1/credentials"); + private static final String CATALOG_URI = String.format("%s%s", baseUri, "/v1/"); + private static final String STORAGE_ACCOUNT = "account1"; + private static final String CREDENTIAL_PREFIX = + "abfs://container@account1.dfs.core.windows.net/dir"; + private static final String STORAGE_ACCOUNT_2 = "account2"; + private static final String CREDENTIAL_PREFIX_2 = + "abfs://container@account2.dfs.core.windows.net/dir"; + private static final Map PROPERTIES = + ImmutableMap.of( + VendedAdlsCredentialProvider.URI, CREDENTIALS_URI, CatalogProperties.URI, CATALOG_URI); + + @Test + public void invalidOrMissingUri() { + assertThatThrownBy(() -> new VendedAdlsCredentialProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid properties: null"); + assertThatThrownBy( + () -> + new VendedAdlsCredentialProvider( + ImmutableMap.of(CatalogProperties.URI, CATALOG_URI))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid credentials endpoint: null"); + assertThatThrownBy( + () -> + new VendedAdlsCredentialProvider( + ImmutableMap.of(VendedAdlsCredentialProvider.URI, CREDENTIALS_URI))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid catalog endpoint: null"); + + try (VendedAdlsCredentialProvider provider = + new VendedAdlsCredentialProvider( + ImmutableMap.of( + VendedAdlsCredentialProvider.URI, + "invalid uri", + CatalogProperties.URI, + CATALOG_URI))) { + assertThatThrownBy(() -> provider.credentialForAccount(STORAGE_ACCOUNT)) + .isInstanceOf(RESTException.class) + .hasMessageStartingWith( + "Failed to create request URI from base %sinvalid uri", CATALOG_URI); + } + } + + @Test + public void noADLSCredentials() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + + HttpResponse mockResponse = + response( + LoadCredentialsResponseParser.toJson( + ImmutableLoadCredentialsResponse.builder().build())) + .withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + assertThatThrownBy(() -> provider.credentialForAccount(STORAGE_ACCOUNT)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Invalid ADLS Credentials for storage-account account1: empty"); + } + } + + @Test + public void expirationNotSet() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder() + .addCredentials( + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, "randomSasToken")) + .build()) + .build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + assertThatThrownBy(() -> provider.credentialForAccount(STORAGE_ACCOUNT)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Invalid ADLS Credentials: adls.sas-token-expires-at-ms.account1 not set"); + } + } + + @Test + public void nonExpiredSasToken() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + Credential credential = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder().addCredentials(credential).build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + String azureSasCredential = provider.credentialForAccount(STORAGE_ACCOUNT); + assertThat(azureSasCredential) + .isEqualTo(credential.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT)); + + for (int i = 0; i < 5; i++) { + // resolving credentials multiple times should not hit the credentials endpoint again + assertThat(provider.credentialForAccount(STORAGE_ACCOUNT)).isSameAs(azureSasCredential); + } + } + + mockServer.verify(mockRequest, VerificationTimes.once()); + } + + @Test + public void expiredSasToken() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + Credential credential = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().minus(1, ChronoUnit.MINUTES).toEpochMilli()))) + .build(); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder().addCredentials(credential).build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + String azureSasCredential = provider.credentialForAccount(STORAGE_ACCOUNT); + assertThat(azureSasCredential) + .isEqualTo(credential.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT)); + + // resolving credentials multiple times should hit the credentials endpoint again + String refreshedAzureSasCredential = provider.credentialForAccount(STORAGE_ACCOUNT); + assertThat(refreshedAzureSasCredential) + .isEqualTo(credential.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT)); + } + + mockServer.verify(mockRequest, VerificationTimes.exactly(2)); + } + + @Test + public void multipleADLSCredentialsPerStorageAccount() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + Credential credential1 = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken1", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + Credential credential2 = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX + "/dir2") + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken2", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder().addCredentials(credential1, credential2).build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + assertThatThrownBy(() -> provider.credentialForAccount(STORAGE_ACCOUNT)) + .isInstanceOf(IllegalStateException.class) + .hasMessage( + "Invalid ADLS Credentials: only one ADLS credential should exist per storage-account"); + } + } + + @Test + public void multipleStorageAccounts() { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + Credential credential1 = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken1", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + Credential credential2 = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX_2) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT_2, + "randomSasToken2", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT_2, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder().addCredentials(credential1, credential2).build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + String azureSasCredential1 = provider.credentialForAccount(STORAGE_ACCOUNT); + String azureSasCredential2 = provider.credentialForAccount(STORAGE_ACCOUNT_2); + assertThat(azureSasCredential1).isNotSameAs(azureSasCredential2); + assertThat(azureSasCredential1) + .isEqualTo(credential1.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT)); + assertThat(azureSasCredential2) + .isEqualTo(credential2.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT_2)); + } + } + + @Test + public void serializableTest() throws IOException, ClassNotFoundException { + HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name()); + Credential credential = + ImmutableCredential.builder() + .prefix(CREDENTIAL_PREFIX) + .config( + ImmutableMap.of( + ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT, + "randomSasToken", + ADLS_SAS_TOKEN_EXPIRES_AT_MS_PREFIX + STORAGE_ACCOUNT, + Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli()))) + .build(); + LoadCredentialsResponse response = + ImmutableLoadCredentialsResponse.builder().addCredentials(credential).build(); + HttpResponse mockResponse = + response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); + mockServer.when(mockRequest).respond(mockResponse); + + try (VendedAdlsCredentialProvider provider = new VendedAdlsCredentialProvider(PROPERTIES)) { + String azureSasCredential = provider.credentialForAccount(STORAGE_ACCOUNT); + assertThat(azureSasCredential) + .isEqualTo(credential.config().get(ADLS_SAS_TOKEN_PREFIX + STORAGE_ACCOUNT)); + + VendedAdlsCredentialProvider deserializedProvider = TestHelpers.roundTripSerialize(provider); + String reGeneratedAzureSasCredential = + deserializedProvider.credentialForAccount(STORAGE_ACCOUNT); + + assertThat(azureSasCredential).isNotSameAs(reGeneratedAzureSasCredential); + } + + mockServer.verify(mockRequest, VerificationTimes.exactly(2)); + } +} diff --git a/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicyTest.java b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicyTest.java new file mode 100644 index 000000000000..8aa2369c8270 --- /dev/null +++ b/azure/src/test/java/org/apache/iceberg/azure/adlsv2/VendedAzureSasCredentialPolicyTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.azure.adlsv2; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockserver.model.HttpRequest.request; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpMethod; +import com.azure.storage.file.datalake.DataLakeFileSystemClient; +import com.azure.storage.file.datalake.DataLakeFileSystemClientBuilder; +import com.azure.storage.file.datalake.models.DataLakeStorageException; +import org.junit.jupiter.api.Test; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; +import org.mockserver.verify.VerificationTimes; + +public class VendedAzureSasCredentialPolicyTest extends BaseVendedCredentialsTest { + + private static final String STORAGE_ACCOUNT = "account1"; + private static final String ACCOUNT_ENDPOINT = String.format("%s/%s", baseUri, STORAGE_ACCOUNT); + + @Test + public void vendedSasTokenAsRequestQueryParameters() { + String filePath = "file1"; + String container = "container1"; + String validSasToken = "tokenInstance=1"; + String expiredSasToken = "tokenInstance=2"; + + VendedAdlsCredentialProvider vendedAdlsCredentialProvider = + mock(VendedAdlsCredentialProvider.class); + VendedAzureSasCredentialPolicy vendedAzureSasCredentialPolicy = + new VendedAzureSasCredentialPolicy(STORAGE_ACCOUNT, vendedAdlsCredentialProvider); + + DataLakeFileSystemClient client = + new DataLakeFileSystemClientBuilder() + .httpClient(HttpClient.createDefault()) + .addPolicy(vendedAzureSasCredentialPolicy) + .fileSystemName(container) + .endpoint(ACCOUNT_ENDPOINT) + .buildClient(); + + String requestPath = String.format("/%s/%s/%s", STORAGE_ACCOUNT, container, filePath); + + HttpRequest mockRequestWithValidSasToken = + request(requestPath) + .withMethod(HttpMethod.HEAD.name()) + .withQueryStringParameter("tokenInstance", "1"); + mockServer + .when(mockRequestWithValidSasToken) + .respond(HttpResponse.response().withStatusCode(200)); + + HttpRequest mockRequestWithExpiredSasToken = + request(requestPath) + .withMethod(HttpMethod.HEAD.name()) + .withQueryStringParameter("tokenInstance", "2"); + mockServer + .when(mockRequestWithExpiredSasToken) + .respond(HttpResponse.response().withStatusCode(403)); + + when(vendedAdlsCredentialProvider.credentialForAccount(STORAGE_ACCOUNT)) + .thenReturn(validSasToken); + assertThat(client.getFileClient(filePath).exists()).isTrue(); + mockServer.verify(mockRequestWithValidSasToken, VerificationTimes.exactly(1)); + + when(vendedAdlsCredentialProvider.credentialForAccount(STORAGE_ACCOUNT)) + .thenReturn(expiredSasToken); + + // Every new request of the same client fetches latest SasToken credentials from + // VendedAdlsCredentialProvider to build http request query parameters. + assertThatThrownBy(() -> client.getFileClient(filePath).exists()) + .isInstanceOf(DataLakeStorageException.class) + .hasMessageContaining( + "If you are using a SAS token, and the server returned an error message that says 'Signature did not match'"); + mockServer.verify(mockRequestWithExpiredSasToken, VerificationTimes.atLeast(1)); + } +} diff --git a/build.gradle b/build.gradle index 9131f863572b..0b8e9b7b347e 100644 --- a/build.gradle +++ b/build.gradle @@ -550,6 +550,8 @@ project(':iceberg-azure') { testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') testImplementation libs.esotericsoftware.kryo testImplementation libs.testcontainers + testImplementation libs.mockserver.netty + testImplementation libs.mockserver.client.java } sourceSets {