Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions src/main/java/com/auth0/jwk/UrlJwkProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import com.fasterxml.jackson.databind.ObjectReader;

import java.io.IOException;
import java.io.InputStream;
import java.net.*;
import java.net.http.*;
import java.time.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -20,9 +21,10 @@ public class UrlJwkProvider implements JwkProvider {
@VisibleForTesting
static final String WELL_KNOWN_JWKS_PATH = "/.well-known/jwks.json";

final HttpClient client;
final HttpRequest request;
final URL url;
final Proxy proxy;
final Map<String, String> headers;
final Integer connectTimeout;
final Integer readTimeout;

Expand All @@ -33,7 +35,7 @@ public class UrlJwkProvider implements JwkProvider {
*
* @param url to load the jwks
*/
public UrlJwkProvider(URL url) {
public UrlJwkProvider(URL url) throws URISyntaxException {
this(url, null, null, null, null);
}

Expand All @@ -45,7 +47,7 @@ public UrlJwkProvider(URL url) {
* @param readTimeout read timeout in milliseconds (null for default)
* @param proxy proxy server to use when making the connection
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy) {
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy) throws URISyntaxException {
this(url, connectTimeout, readTimeout, proxy, null);
}

Expand All @@ -58,7 +60,7 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox
* @param proxy proxy server to use when making the connection (default is null)
* @param headers a map of request header keys to values to send on the request. Default is "Accept: application/json".
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map<String, String> headers) {
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map<String, String> headers) throws URISyntaxException {
Util.checkArgument(url != null, "A non-null url is required");
Util.checkArgument(connectTimeout == null || connectTimeout >= 0, "Invalid connect timeout value '" + connectTimeout + "'. Must be a non-negative integer.");
Util.checkArgument(readTimeout == null || readTimeout >= 0, "Invalid read timeout value '" + readTimeout + "'. Must be a non-negative integer.");
Expand All @@ -69,8 +71,21 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox
this.readTimeout = readTimeout;
this.reader = new ObjectMapper().readerFor(Map.class);

this.headers = (headers == null) ?
Collections.singletonMap("Accept", "application/json") : headers;
Map<String, String> mHeaders = (headers == null) ? Collections.singletonMap("Accept", "application/json") : headers;

this.client = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_2)
.connectTimeout(Duration.ofMillis(connectTimeout != null ? connectTimeout : 10000))
.build();

HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
.uri(url.toURI())
.GET();
for (Map.Entry<String, String> entry : mHeaders.entrySet()) {
requestBuilder.header(entry.getKey(), entry.getValue());
}

this.request = requestBuilder.build();
}

/**
Expand All @@ -80,7 +95,7 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox
* @param connectTimeout connection timeout in milliseconds (null for default)
* @param readTimeout read timeout in milliseconds (null for default)
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout) {
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout) throws URISyntaxException {
this(url, connectTimeout, readTimeout, null, null);
}

Expand Down Expand Up @@ -120,23 +135,15 @@ static URL urlForDomain(String domain) {

private Map<String, Object> getJwks() throws SigningKeyNotFoundException {
try {
final URLConnection c = (proxy == null) ? this.url.openConnection() : this.url.openConnection(proxy);
if (connectTimeout != null) {
c.setConnectTimeout(connectTimeout);
}
if (readTimeout != null) {
c.setReadTimeout(readTimeout);
}
HttpResponse<String> response = this.client.send(this.request, HttpResponse.BodyHandlers.ofString());

for (Map.Entry<String, String> entry : headers.entrySet()) {
c.setRequestProperty(entry.getKey(), entry.getValue());
if (response.statusCode() != 200) {
throw new IOException("Failed to fetch JWKS: HTTP " + response.statusCode());
}

try (InputStream inputStream = c.getInputStream()) {
return reader.readValue(inputStream);
}
} catch (IOException e) {
throw new NetworkException("Cannot obtain jwks from url " + url.toString(), e);
return reader.readValue(response.body());
} catch (IOException | InterruptedException e) {
throw new NetworkException("Cannot obtain JWKS from URL " + url, e);
}
}

Expand Down