diff --git a/src/main/java/com/auth0/jwk/UrlJwkProvider.java b/src/main/java/com/auth0/jwk/UrlJwkProvider.java index f3c1c9b..c71d588 100644 --- a/src/main/java/com/auth0/jwk/UrlJwkProvider.java +++ b/src/main/java/com/auth0/jwk/UrlJwkProvider.java @@ -4,8 +4,9 @@ import com.fasterxml.jackson.databind.ObjectReader; import java.io.IOException; -import java.io.InputStream; import java.net.*; +import java.net.http.*; +import java.time.*; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -20,9 +21,10 @@ public class UrlJwkProvider implements JwkProvider { @VisibleForTesting static final String WELL_KNOWN_JWKS_PATH = "/.well-known/jwks.json"; + final HttpClient client; + final HttpRequest request; final URL url; final Proxy proxy; - final Map headers; final Integer connectTimeout; final Integer readTimeout; @@ -33,7 +35,7 @@ public class UrlJwkProvider implements JwkProvider { * * @param url to load the jwks */ - public UrlJwkProvider(URL url) { + public UrlJwkProvider(URL url) throws URISyntaxException { this(url, null, null, null, null); } @@ -45,7 +47,7 @@ public UrlJwkProvider(URL url) { * @param readTimeout read timeout in milliseconds (null for default) * @param proxy proxy server to use when making the connection */ - public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy) { + public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy) throws URISyntaxException { this(url, connectTimeout, readTimeout, proxy, null); } @@ -58,7 +60,7 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox * @param proxy proxy server to use when making the connection (default is null) * @param headers a map of request header keys to values to send on the request. Default is "Accept: application/json". */ - public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map headers) { + public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map headers) throws URISyntaxException { Util.checkArgument(url != null, "A non-null url is required"); Util.checkArgument(connectTimeout == null || connectTimeout >= 0, "Invalid connect timeout value '" + connectTimeout + "'. Must be a non-negative integer."); Util.checkArgument(readTimeout == null || readTimeout >= 0, "Invalid read timeout value '" + readTimeout + "'. Must be a non-negative integer."); @@ -69,8 +71,21 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox this.readTimeout = readTimeout; this.reader = new ObjectMapper().readerFor(Map.class); - this.headers = (headers == null) ? - Collections.singletonMap("Accept", "application/json") : headers; + Map mHeaders = (headers == null) ? Collections.singletonMap("Accept", "application/json") : headers; + + this.client = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .connectTimeout(Duration.ofMillis(connectTimeout != null ? connectTimeout : 10000)) + .build(); + + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .uri(url.toURI()) + .GET(); + for (Map.Entry entry : mHeaders.entrySet()) { + requestBuilder.header(entry.getKey(), entry.getValue()); + } + + this.request = requestBuilder.build(); } /** @@ -80,7 +95,7 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox * @param connectTimeout connection timeout in milliseconds (null for default) * @param readTimeout read timeout in milliseconds (null for default) */ - public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout) { + public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout) throws URISyntaxException { this(url, connectTimeout, readTimeout, null, null); } @@ -120,23 +135,15 @@ static URL urlForDomain(String domain) { private Map getJwks() throws SigningKeyNotFoundException { try { - final URLConnection c = (proxy == null) ? this.url.openConnection() : this.url.openConnection(proxy); - if (connectTimeout != null) { - c.setConnectTimeout(connectTimeout); - } - if (readTimeout != null) { - c.setReadTimeout(readTimeout); - } + HttpResponse response = this.client.send(this.request, HttpResponse.BodyHandlers.ofString()); - for (Map.Entry entry : headers.entrySet()) { - c.setRequestProperty(entry.getKey(), entry.getValue()); + if (response.statusCode() != 200) { + throw new IOException("Failed to fetch JWKS: HTTP " + response.statusCode()); } - try (InputStream inputStream = c.getInputStream()) { - return reader.readValue(inputStream); - } - } catch (IOException e) { - throw new NetworkException("Cannot obtain jwks from url " + url.toString(), e); + return reader.readValue(response.body()); + } catch (IOException | InterruptedException e) { + throw new NetworkException("Cannot obtain JWKS from URL " + url, e); } }