Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
PradhanPrerak committed Jun 17, 2024
1 parent 7f58d07 commit 917e214
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -778,18 +778,18 @@ private Constants() {
/**
* List of custom headers to be set on the service client.
* Multiple parameters can be used to specify custom headers.
* fs.s3a.s3.custom.headers - headers to add on all the s3 requests.
* fs.s3a.sts.custom.headers - headers to add on all the sts requests.
* fs.s3a.client.s3.custom.headers - headers to add on all the s3 requests.
* fs.s3a.client.sts.custom.headers - headers to add on all the sts requests.
* Examples
* CustomHeader {@literal ->} 'Header1:Value1'
* CustomHeaders {@literal ->} 'Header1=Value1:Value2,Header2=Value1'
*/
public static final String CUSTOM_HEADERS_STS =
"fs.s3a." + Constants.AWS_SERVICE_IDENTIFIER_STS.toLowerCase()
"fs.s3a.client." + Constants.AWS_SERVICE_IDENTIFIER_STS.toLowerCase()
+ ".custom.headers";

public static final String CUSTOM_HEADERS_S3 =
"fs.s3a." + Constants.AWS_SERVICE_IDENTIFIER_S3.toLowerCase()
"fs.s3a.client." + Constants.AWS_SERVICE_IDENTIFIER_S3.toLowerCase()
+ ".custom.headers";

public static final String S3N_FOLDER_SUFFIX = "_$folder$";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
Expand Down Expand Up @@ -433,25 +434,12 @@ private static void initRequestHeaders(Configuration conf,
// Nothing to do. The original signer override is already setup
}
if (configKey != null) {
String[] customHeaders = conf.getTrimmedStrings(configKey);
if (customHeaders == null || customHeaders.length == 0) {
LOG.debug("No custom headers specified");
return;
}

for (String customHeader : customHeaders) {
String[] parts = customHeader.split("=");
if (parts.length != 2) {
String message = "Invalid format (Expected header1=value1:value2,header2=value1) for Header: ["
+ customHeader
+ "]";
LOG.error(message);
throw new IllegalArgumentException(message);
}

List<String> values = Arrays.asList(parts[1].split(":"));
clientConfig.putHeader(parts[0], values);
}
Map<String, String> awsClientCustomHeadersMap =
S3AUtils.getTrimmedStringCollectionSplitByEquals(conf, configKey);
awsClientCustomHeadersMap.forEach((header, valueString) -> {
List<String> headerValues = Arrays.asList(valueString.split(":"));
clientConfig.putHeader(header, headerValues);
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

package org.apache.hadoop.fs.s3a.impl;

import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;

import org.apache.hadoop.util.Lists;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Test;
Expand All @@ -30,10 +32,14 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.test.AbstractHadoopTestBase;

import static org.apache.hadoop.fs.s3a.Constants.AWS_SERVICE_IDENTIFIER_S3;
import static org.apache.hadoop.fs.s3a.Constants.AWS_SERVICE_IDENTIFIER_STS;
import static org.apache.hadoop.fs.s3a.Constants.CONNECTION_ACQUISITION_TIMEOUT;
import static org.apache.hadoop.fs.s3a.Constants.CONNECTION_IDLE_TIME;
import static org.apache.hadoop.fs.s3a.Constants.CONNECTION_KEEPALIVE;
import static org.apache.hadoop.fs.s3a.Constants.CONNECTION_TTL;
import static org.apache.hadoop.fs.s3a.Constants.CUSTOM_HEADERS_S3;
import static org.apache.hadoop.fs.s3a.Constants.CUSTOM_HEADERS_STS;
import static org.apache.hadoop.fs.s3a.Constants.DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_DURATION;
import static org.apache.hadoop.fs.s3a.Constants.DEFAULT_CONNECTION_IDLE_TIME_DURATION;
import static org.apache.hadoop.fs.s3a.Constants.DEFAULT_CONNECTION_KEEPALIVE;
Expand All @@ -48,6 +54,7 @@
import static org.apache.hadoop.fs.s3a.Constants.REQUEST_TIMEOUT;
import static org.apache.hadoop.fs.s3a.Constants.SOCKET_TIMEOUT;
import static org.apache.hadoop.fs.s3a.impl.AWSClientConfig.createApiConnectionSettings;
import static org.apache.hadoop.fs.s3a.impl.AWSClientConfig.createClientConfigBuilder;
import static org.apache.hadoop.fs.s3a.impl.AWSClientConfig.createConnectionSettings;
import static org.apache.hadoop.fs.s3a.impl.ConfigurationHelper.enforceMinimumDuration;

Expand Down Expand Up @@ -201,4 +208,56 @@ public void testCreateApiConnectionSettingsDefault() {
private void setOptionsToValue(String value, Configuration conf, String... keys) {
Arrays.stream(keys).forEach(key -> conf.set(key, value));
}

/**
* if {@link org.apache.hadoop.fs.s3a.Constants#CUSTOM_HEADERS_STS} is set,
* verify that returned client configuration has desired headers set.
*/
@Test
public void testInitRequestHeadersForSTS() throws IOException {
final Configuration conf = new Configuration();
conf.set(CUSTOM_HEADERS_STS, "foo=bar:baz,qux=quux");
Assertions.assertThat(conf.get(CUSTOM_HEADERS_S3))
.describedAs("Custom client headers for s3 %s", CUSTOM_HEADERS_S3)
.isNull();

Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_S3).headers().size())
.describedAs("Count of S3 client headers")
.isEqualTo(0);
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_STS).headers().size())
.describedAs("Count of STS client headers")
.isEqualTo(2);
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_STS).headers().get("foo"))
.describedAs("STS client 'foo' header value")
.isEqualTo(Lists.newArrayList("bar", "baz"));
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_STS).headers().get("qux"))
.describedAs("STS client 'qux' header value")
.isEqualTo(Lists.newArrayList("quux"));
}

/**
* if {@link org.apache.hadoop.fs.s3a.Constants#CUSTOM_HEADERS_S3} is set,
* verify that returned client configuration has desired headers set.
*/
@Test
public void testInitRequestHeadersForS3() throws IOException {
final Configuration conf = new Configuration();
conf.set(CUSTOM_HEADERS_S3, "foo=bar:baz,qux=quux");
Assertions.assertThat(conf.get(CUSTOM_HEADERS_STS))
.describedAs("Custom client headers for STS %s", CUSTOM_HEADERS_STS)
.isNull();

Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_STS).headers().size())
.describedAs("Count of STS client headers")
.isEqualTo(0);
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_S3).headers().size())
.describedAs("Count of S3 client headers")
.isEqualTo(2);
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_S3).headers().get("foo"))
.describedAs("S3 client 'foo' header value")
.isEqualTo(Lists.newArrayList("bar", "baz"));
Assertions.assertThat(createClientConfigBuilder(conf, AWS_SERVICE_IDENTIFIER_S3).headers().get("qux"))
.describedAs("S3 client 'qux' header value")
.isEqualTo(Lists.newArrayList("quux"));
}
}

0 comments on commit 917e214

Please sign in to comment.