Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.iceberg.rest.ErrorHandlers;
import org.apache.iceberg.rest.HTTPClient;
import org.apache.iceberg.rest.RESTClient;
import org.apache.iceberg.rest.ResourcePaths;
import org.apache.iceberg.rest.auth.OAuth2Properties;
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.auth.OAuth2Util.AuthSession;
Expand Down Expand Up @@ -111,6 +112,12 @@ public String credential() {
return properties().get(OAuth2Properties.CREDENTIAL);
}

/** Token endpoint URI to fetch token from if the Rest Catalog is not the authorization server. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Token endpoint URI" -> "OAuth Server URI" or just "Endpoint"?

@Value.Lazy
public String oauth2ServerUri() {
return properties().getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens());
}

/** A Bearer token supplier which will be used for interaction with the server. */
@Value.Default
public Supplier<String> token() {
Expand Down Expand Up @@ -199,7 +206,8 @@ private AuthSession authSession() {
tokenRefreshExecutor(),
token,
expiresAtMillis(properties()),
new AuthSession(ImmutableMap.of(), token, null, credential(), SCOPE)));
new AuthSession(
ImmutableMap.of(), token, null, credential(), SCOPE, oauth2ServerUri())));
}

if (credentialProvided()) {
Expand All @@ -208,10 +216,12 @@ private AuthSession authSession() {
credential(),
id -> {
AuthSession session =
new AuthSession(ImmutableMap.of(), null, null, credential(), SCOPE);
new AuthSession(
ImmutableMap.of(), null, null, credential(), SCOPE, oauth2ServerUri());
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse authResponse =
OAuth2Util.fetchToken(httpClient(), session.headers(), credential(), SCOPE);
OAuth2Util.fetchToken(
httpClient(), session.headers(), credential(), SCOPE, oauth2ServerUri());
return AuthSession.fromTokenResponse(
httpClient(), tokenRefreshExecutor(), authResponse, startTimeMillis, session);
});
Expand Down
29 changes: 16 additions & 13 deletions core/src/main/java/org/apache/iceberg/rest/HTTPClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,24 @@ private static void throwFailure(
}

private URI buildUri(String path, Map<String, String> params) {
String baseUri = String.format("%s/%s", uri, path);
// if full path is provided, use the input path as path
if (path.startsWith("/")) {
throw new RESTException(
"Received a malformed path for a REST request: %s. Paths should not start with /", path);
}
String fullPath =
(path.startsWith("https://") || path.startsWith("http://"))
? path
: String.format("%s/%s", uri, path);
try {
URIBuilder builder = new URIBuilder(baseUri);
URIBuilder builder = new URIBuilder(fullPath);
if (params != null) {
params.forEach(builder::addParameter);
}
return builder.build();
} catch (URISyntaxException e) {
throw new RESTException(
"Failed to create request URI from base %s, params %s", baseUri, params);
"Failed to create request URI from base %s, params %s", fullPath, params);
}
}

Expand All @@ -223,7 +231,7 @@ private URI buildUri(String path, Map<String, String> params) {
*
* @param method - HTTP method, such as GET, POST, HEAD, etc.
* @param queryParams - A map of query parameters
* @param path - URL path to send the request to
* @param path - URI to send the request to
* @param requestBody - Content to place in the request body
* @param responseType - Class of the Response type. Needs to have serializer registered with
* ObjectMapper
Expand All @@ -250,7 +258,7 @@ private <T> T execute(
*
* @param method - HTTP method, such as GET, POST, HEAD, etc.
* @param queryParams - A map of query parameters
* @param path - URL path to send the request to
* @param path - URL to send the request to
* @param requestBody - Content to place in the request body
* @param responseType - Class of the Response type. Needs to have serializer registered with
* ObjectMapper
Expand All @@ -270,11 +278,6 @@ private <T> T execute(
Map<String, String> headers,
Consumer<ErrorResponse> errorHandler,
Consumer<Map<String, String>> responseHeaders) {
if (path.startsWith("/")) {
throw new RESTException(
"Received a malformed path for a REST request: %s. Paths should not start with /", path);
}

HttpUriRequestBase request = new HttpUriRequestBase(method.name(), buildUri(path, queryParams));

if (requestBody instanceof Map) {
Expand Down Expand Up @@ -459,9 +462,9 @@ private Builder(Map<String, String> properties) {
this.properties = properties;
}

public Builder uri(String baseUri) {
Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null");
this.uri = RESTUtil.stripTrailingSlash(baseUri);
public Builder uri(String path) {
Preconditions.checkNotNull(path, "Invalid uri for http client: null");
this.uri = RESTUtil.stripTrailingSlash(path);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,14 @@ public void initialize(String name, Map<String, String> unresolved) {
OAuthTokenResponse authResponse;
String credential = props.get(OAuth2Properties.CREDENTIAL);
String scope = props.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE);
String oauth2ServerUri =
props.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens());
try (RESTClient initClient = clientBuilder.apply(props)) {
Map<String, String> initHeaders =
RESTUtil.merge(configHeaders(props), OAuth2Util.authHeaders(initToken));
if (credential != null && !credential.isEmpty()) {
authResponse = OAuth2Util.fetchToken(initClient, initHeaders, credential, scope);
authResponse =
OAuth2Util.fetchToken(initClient, initHeaders, credential, scope, oauth2ServerUri);
Map<String, String> authHeaders =
RESTUtil.merge(initHeaders, OAuth2Util.authHeaders(authResponse.token()));
config = fetchConfig(initClient, authHeaders, props);
Expand All @@ -209,7 +212,7 @@ public void initialize(String name, Map<String, String> unresolved) {
this.paths = ResourcePaths.forCatalogProperties(mergedProps);

String token = mergedProps.get(OAuth2Properties.TOKEN);
this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope);
this.catalogAuth = new AuthSession(baseHeaders, null, null, credential, scope, oauth2ServerUri);
if (authResponse != null) {
this.catalogAuth =
AuthSession.fromTokenResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ private OAuth2Properties() {}
/** A credential to exchange for a token in the OAuth2 client credentials flow. */
public static final String CREDENTIAL = "credential";

/** Token endpoint URI to fetch token from if the Rest Catalog is not the authorization server. */
public static final String OAUTH2_SERVER_URI = "oauth2-server-uri";

/**
* Interval in milliseconds to wait before attempting to exchange the configured catalog Bearer
* token. By default, token exchange will be attempted after 1 hour.
Expand Down
86 changes: 73 additions & 13 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ private static OAuthTokenResponse refreshToken(
Map<String, String> headers,
String subjectToken,
String subjectTokenType,
String scope) {
String scope,
String oauth2ServerUri) {
Map<String, String> request =
tokenExchangeRequest(
subjectToken,
Expand All @@ -143,7 +144,7 @@ private static OAuthTokenResponse refreshToken(

OAuthTokenResponse response =
client.postForm(
ResourcePaths.tokens(),
oauth2ServerUri,
request,
OAuthTokenResponse.class,
headers,
Expand All @@ -160,7 +161,8 @@ public static OAuthTokenResponse exchangeToken(
String subjectTokenType,
String actorToken,
String actorTokenType,
String scope) {
String scope,
String oauth2ServerUri) {
Map<String, String> request =
tokenExchangeRequest(
subjectToken,
Expand All @@ -171,7 +173,7 @@ public static OAuthTokenResponse exchangeToken(

OAuthTokenResponse response =
client.postForm(
ResourcePaths.tokens(),
oauth2ServerUri,
request,
OAuthTokenResponse.class,
headers,
Expand All @@ -181,15 +183,38 @@ public static OAuthTokenResponse exchangeToken(
return response;
}

public static OAuthTokenResponse exchangeToken(
RESTClient client,
Map<String, String> headers,
String subjectToken,
String subjectTokenType,
String actorToken,
String actorTokenType,
String scope) {
return exchangeToken(
client,
headers,
subjectToken,
subjectTokenType,
actorToken,
actorTokenType,
scope,
ResourcePaths.tokens());
}

public static OAuthTokenResponse fetchToken(
RESTClient client, Map<String, String> headers, String credential, String scope) {
RESTClient client,
Map<String, String> headers,
String credential,
String scope,
String oauth2ServerUri) {
Map<String, String> request =
clientCredentialsRequest(
credential, scope != null ? ImmutableList.of(scope) : ImmutableList.of());

OAuthTokenResponse response =
client.postForm(
ResourcePaths.tokens(),
oauth2ServerUri,
request,
OAuthTokenResponse.class,
headers,
Expand All @@ -199,6 +224,12 @@ public static OAuthTokenResponse fetchToken(
return response;
}

public static OAuthTokenResponse fetchToken(
RESTClient client, Map<String, String> headers, String credential, String scope) {

return fetchToken(client, headers, credential, scope, ResourcePaths.tokens());
}

private static Map<String, String> tokenExchangeRequest(
String subjectToken, String subjectTokenType, List<String> scopes) {
return tokenExchangeRequest(subjectToken, subjectTokenType, null, null, scopes);
Expand Down Expand Up @@ -361,7 +392,26 @@ public static class AuthSession {
private final String credential;
private final String scope;
private volatile boolean keepRefreshed = true;
private final String oauth2ServerUri;

public AuthSession(
Map<String, String> baseHeaders,
String token,
String tokenType,
String credential,
String scope,
String oauth2ServerUri) {
this.headers = RESTUtil.merge(baseHeaders, authHeaders(token));
this.token = token;
this.tokenType = tokenType;
this.expiresAtMillis = OAuth2Util.expiresAtMillis(token);
this.credential = credential;
this.scope = scope;
this.oauth2ServerUri = oauth2ServerUri;
}

/** @deprecated since 1.5.0, will be removed in 1.6.0 */
@Deprecated
public AuthSession(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it is not used anywhere. Should we deprecate it? If that's so, can we add deprecation comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that too. I lack the context to make that decision, so I will leave this comment thread for others to opine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't appear to be referenced anymore so you should add a deprecation warning and message (see deprecation docs)

Map<String, String> baseHeaders,
String token,
Expand All @@ -374,6 +424,7 @@ public AuthSession(
this.expiresAtMillis = OAuth2Util.expiresAtMillis(token);
this.credential = credential;
this.scope = scope;
this.oauth2ServerUri = ResourcePaths.tokens();
}

public Map<String, String> headers() {
Expand Down Expand Up @@ -404,6 +455,10 @@ public String credential() {
return credential;
}

public String oauth2ServerUri() {
return oauth2ServerUri;
}

@VisibleForTesting
static void setTokenRefreshNumRetries(int retries) {
tokenRefreshNumRetries = retries;
Expand All @@ -415,7 +470,8 @@ static void setTokenRefreshNumRetries(int retries) {
* @return A new {@link AuthSession} with empty headers.
*/
public static AuthSession empty() {
return new AuthSession(ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE);
return new AuthSession(
ImmutableMap.of(), null, null, null, OAuth2Properties.CATALOG_SCOPE, null);
}

/**
Expand Down Expand Up @@ -470,14 +526,14 @@ private OAuthTokenResponse refreshCurrentToken(RESTClient client) {
return refreshExpiredToken(client);
} else {
// attempt a normal refresh
return refreshToken(client, headers(), token, tokenType, scope);
return refreshToken(client, headers(), token, tokenType, scope, oauth2ServerUri);
}
}

private OAuthTokenResponse refreshExpiredToken(RESTClient client) {
if (credential != null) {
Map<String, String> basicHeaders = RESTUtil.merge(headers(), basicAuthHeaders(credential));
return refreshToken(client, basicHeaders, token, tokenType, scope);
return refreshToken(client, basicHeaders, token, tokenType, scope, oauth2ServerUri);
}

return null;
Expand Down Expand Up @@ -533,7 +589,8 @@ public static AuthSession fromAccessToken(
token,
OAuth2Properties.ACCESS_TOKEN_TYPE,
parent.credential(),
parent.scope());
parent.scope(),
parent.oauth2ServerUri());

long startTimeMillis = System.currentTimeMillis();
Long expiresAtMillis = session.expiresAtMillis();
Expand Down Expand Up @@ -571,7 +628,8 @@ public static AuthSession fromCredential(
AuthSession parent) {
long startTimeMillis = System.currentTimeMillis();
OAuthTokenResponse response =
fetchToken(client, parent.headers(), credential, parent.scope());
fetchToken(
client, parent.headers(), credential, parent.scope(), parent.oauth2ServerUri());
return fromTokenResponse(client, executor, response, startTimeMillis, parent, credential);
}

Expand All @@ -598,7 +656,8 @@ private static AuthSession fromTokenResponse(
response.token(),
response.issuedTokenType(),
credential,
parent.scope());
parent.scope(),
parent.oauth2ServerUri());

Long expiresAtMillis = session.expiresAtMillis();
if (null == expiresAtMillis && response.expiresInSeconds() != null) {
Expand Down Expand Up @@ -627,7 +686,8 @@ public static AuthSession fromTokenExchange(
tokenType,
parent.token(),
parent.tokenType(),
parent.scope());
parent.scope(),
parent.oauth2ServerUri());
return fromTokenResponse(client, executor, response, startTimeMillis, parent);
}
}
Expand Down
Loading