diff --git a/sdk/keyvault/azure-security-keyvault-jca/CHANGELOG.md b/sdk/keyvault/azure-security-keyvault-jca/CHANGELOG.md index bed1ca469f3e..2e75f9aa4416 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/CHANGELOG.md +++ b/sdk/keyvault/azure-security-keyvault-jca/CHANGELOG.md @@ -1,15 +1,16 @@ # Release History -## 2.9.0-beta.2 (Unreleased) +## 2.9.0-beta.2 (2024-07-01) ### Features Added +- Added the new system property `azure.keyvault.disable-challenge-resource-verification`, which can be set to `true` to disable challenge resource verification when authenticating against the Azure Key Vault service. For more information, please refer to [this link](https://devblogs.microsoft.com/azure-sdk/guidance-for-applications-using-the-key-vault-libraries/). ### Breaking Changes +- Removed support for providing a custom login URI to get access tokens from via the system property `azure.login.uri`. ### Bugs Fixed -- Fix bug: AccessTokenUtil does not urlencode its parameters when getting an access token. ([40616](https://github.com/Azure/azure-sdk-for-java/issues/40616)) - -### Other Changes +- Fix bug: AccessTokenUtil does not urlencode its parameters when getting an access token. ([#40616](https://github.com/Azure/azure-sdk-for-java/issues/40616)) +- Changed the authentication mechanism to allow for discovering the login URI for a given Azure Key Vault instance by requesting an authentication challenge from the service, as opposed to using a hard-coded list of URIs to choose from depending on a vault's URI. This should add support for customers using Azure Stack instances, for example. ## 2.9.0-beta.1 (2024-05-15) diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultKeyStore.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultKeyStore.java index 965f52365788..1b4d5fe634e5 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultKeyStore.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultKeyStore.java @@ -41,7 +41,6 @@ * @see KeyStoreSpi */ public final class KeyVaultKeyStore extends KeyStoreSpi { - /** * Stores the key-store name. */ @@ -98,13 +97,13 @@ public final class KeyVaultKeyStore extends KeyStoreSpi { * Store the path where the well know certificate is placed */ final String wellKnowPath = Optional.ofNullable(System.getProperty("azure.cert-path.well-known")) - .orElse("/etc/certs/well-known/"); + .orElse("/etc/certs/well-known/"); /** * Store the path where the custom certificate is placed */ final String customPath = Optional.ofNullable(System.getProperty("azure.cert-path.custom")) - .orElse("/etc/certs/custom/"); + .orElse("/etc/certs/custom/"); /** * Constructor. @@ -122,41 +121,50 @@ public final class KeyVaultKeyStore extends KeyStoreSpi { */ public KeyVaultKeyStore() { LOGGER.log(FINE, "Constructing KeyVaultKeyStore."); + creationDate = new Date(); String keyVaultUri = System.getProperty("azure.keyvault.uri"); - String loginUri = System.getProperty("azure.login.uri"); String tenantId = System.getProperty("azure.keyvault.tenant-id"); String clientId = System.getProperty("azure.keyvault.client-id"); String clientSecret = System.getProperty("azure.keyvault.client-secret"); String managedIdentity = System.getProperty("azure.keyvault.managed-identity"); + boolean disableChallengeResourceVerification = + Boolean.parseBoolean(System.getProperty("azure.keyvault.disable-challenge-resource-verification")); long refreshInterval = getRefreshInterval(); refreshCertificatesWhenHaveUnTrustCertificate = Optional.of("azure.keyvault.jca.refresh-certificates-when-have-un-trust-certificate") - .map(System::getProperty) - .map(Boolean::parseBoolean) - .orElse(false); + .map(System::getProperty) + .map(Boolean::parseBoolean) + .orElse(false); + jreCertificates = JreCertificates.getInstance(); LOGGER.log(FINE, String.format("Loaded jre certificates: %s.", jreCertificates.getAliases())); + wellKnowCertificates = SpecificPathCertificates.getSpecificPathCertificates(wellKnowPath); LOGGER.log(FINE, String.format("Loaded well known certificates: %s.", wellKnowCertificates.getAliases())); + customCertificates = SpecificPathCertificates.getSpecificPathCertificates(customPath); LOGGER.log(FINE, String.format("Loaded custom certificates: %s.", customCertificates.getAliases())); - keyVaultCertificates = new KeyVaultCertificates( - refreshInterval, keyVaultUri, loginUri, tenantId, clientId, clientSecret, managedIdentity); + + keyVaultCertificates = new KeyVaultCertificates(refreshInterval, keyVaultUri, tenantId, clientId, clientSecret, + managedIdentity, disableChallengeResourceVerification); LOGGER.log(FINE, String.format("Loaded Key Vault certificates: %s.", keyVaultCertificates.getAliases())); + classpathCertificates = new ClasspathCertificates(); LOGGER.log(FINE, String.format("Loaded classpath certificates: %s.", classpathCertificates.getAliases())); - allCertificates = Arrays.asList( - jreCertificates, wellKnowCertificates, customCertificates, keyVaultCertificates, classpathCertificates); + + allCertificates = Arrays.asList(jreCertificates, wellKnowCertificates, customCertificates, keyVaultCertificates, + classpathCertificates); } Long getRefreshInterval() { - return Stream.of("azure.keyvault.jca.certificates-refresh-interval-in-ms", "azure.keyvault.jca.certificates-refresh-interval") - .map(System::getProperty) - .filter(Objects::nonNull) - .map(Long::valueOf) - .findFirst() - .orElse(0L); + return Stream.of("azure.keyvault.jca.certificates-refresh-interval-in-ms", + "azure.keyvault.jca.certificates-refresh-interval") + .map(System::getProperty) + .filter(Objects::nonNull) + .map(Long::valueOf) + .findFirst() + .orElse(0L); } /** @@ -169,17 +177,23 @@ Long getRefreshInterval() { * @throws KeyStoreException when no Provider supports a KeyStoreSpi implementation for the specified type * @throws IOException when an I/O error occurs. */ - public static KeyStore getKeyVaultKeyStoreBySystemProperty() throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException { + public static KeyStore getKeyVaultKeyStoreBySystemProperty() throws CertificateException, NoSuchAlgorithmException, + KeyStoreException, IOException { + KeyStore keyStore = KeyStore.getInstance(KeyVaultJcaProvider.PROVIDER_NAME); - KeyVaultLoadStoreParameter parameter = new KeyVaultLoadStoreParameter( - System.getProperty("azure.keyvault.uri"), - System.getProperty("azure.login.uri"), - System.getProperty("azure.keyvault.tenant-id"), - System.getProperty("azure.keyvault.client-id"), - System.getProperty("azure.keyvault.client-secret"), - System.getProperty("azure.keyvault.managed-identity")); + KeyVaultLoadStoreParameter keyVaultLoadStoreParameter = + new KeyVaultLoadStoreParameter( + System.getProperty("azure.keyvault.uri"), + System.getProperty("azure.keyvault.tenant-id"), + System.getProperty("azure.keyvault.client-id"), + System.getProperty("azure.keyvault.client-secret"), + System.getProperty("azure.keyvault.managed-identity")); + + if (Boolean.parseBoolean(System.getProperty("azure.keyvault.disable-challenge-resource-verification"))) { + keyVaultLoadStoreParameter.disableChallengeResourceVerification(); + } - keyStore.load(parameter); + keyStore.load(keyVaultLoadStoreParameter); return keyStore; } @@ -241,16 +255,17 @@ public boolean engineEntryInstanceOf(String alias, Class a.containsKey(alias)) - .findFirst() - .map(certificates -> certificates.get(alias)) - .orElse(null); + .map(AzureCertificates::getCertificates) + .filter(a -> a.containsKey(alias)) + .findFirst() + .map(certificates -> certificates.get(alias)) + .orElse(null); if (refreshCertificatesWhenHaveUnTrustCertificate && certificate == null) { keyVaultCertificates.refreshCertificates(); certificate = keyVaultCertificates.getCertificates().get(alias); } + return certificate; } @@ -264,20 +279,25 @@ public Certificate engineGetCertificate(String alias) { @Override public String engineGetCertificateAlias(Certificate cert) { String alias = null; + if (cert != null) { List aliasList = getAllAliases(); for (String candidateAlias : aliasList) { Certificate certificate = engineGetCertificate(candidateAlias); + if (certificate.equals(cert)) { alias = candidateAlias; + break; } } } + if (refreshCertificatesWhenHaveUnTrustCertificate && alias == null) { alias = keyVaultCertificates.refreshAndGetAliasByCertificate(cert); } + return alias; } @@ -293,10 +313,12 @@ public String engineGetCertificateAlias(Certificate cert) { public Certificate[] engineGetCertificateChain(String alias) { Certificate[] chain = null; Certificate certificate = engineGetCertificate(alias); + if (certificate != null) { chain = new Certificate[1]; chain[0] = certificate; } + return chain; } @@ -322,7 +344,9 @@ public Date engineGetCreationDate(String alias) { * @exception UnrecoverableEntryException if the specified {@code protParam} were insufficient or invalid */ @Override - public KeyStore.Entry engineGetEntry(String alias, KeyStore.ProtectionParameter protParam) throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableEntryException { + public KeyStore.Entry engineGetEntry(String alias, KeyStore.ProtectionParameter protParam) throws KeyStoreException, + NoSuchAlgorithmException, UnrecoverableEntryException { + return super.engineGetEntry(alias, protParam); } @@ -336,11 +360,11 @@ public KeyStore.Entry engineGetEntry(String alias, KeyStore.ProtectionParameter @Override public Key engineGetKey(String alias, char[] password) { return allCertificates.stream() - .map(AzureCertificates::getCertificateKeys) - .filter(a -> a.containsKey(alias)) - .findFirst() - .map(certificateKeys -> certificateKeys.get(alias)) - .orElse(null); + .map(AzureCertificates::getCertificateKeys) + .filter(a -> a.containsKey(alias)) + .findFirst() + .map(certificateKeys -> certificateKeys.get(alias)) + .orElse(null); } /** @@ -368,17 +392,18 @@ public boolean engineIsKeyEntry(String alias) { /** * Loads the keystore using the given {@code KeyStore.LoadStoreParameter}. * - * @param param the {@code KeyStore.LoadStoreParameter} that specifies how to load the keystore, which may be - * {@code null}. + * @param param the {@code KeyStore.LoadStoreParameter} + * that specifies how to load the keystore, + * which may be {@code null} */ @Override public void engineLoad(KeyStore.LoadStoreParameter param) { if (param instanceof KeyVaultLoadStoreParameter) { KeyVaultLoadStoreParameter parameter = (KeyVaultLoadStoreParameter) param; - keyVaultCertificates.updateKeyVaultClient( - parameter.getUri(), parameter.getLoginUri(), parameter.getTenantId(), parameter.getClientId(), - parameter.getClientSecret(), parameter.getManagedIdentity()); + keyVaultCertificates.updateKeyVaultClient(parameter.getUri(), parameter.getTenantId(), + parameter.getClientId(), parameter.getClientSecret(), parameter.getManagedIdentity(), + parameter.isChallengeResourceVerificationDisabled()); } classpathCertificates.loadCertificatesFromClasspath(); @@ -398,6 +423,7 @@ public void engineLoad(InputStream stream, char[] password) { private List getAllAliases() { List allAliases = new ArrayList<>(jreCertificates.getAliases()); Map> aliasLists = new HashMap<>(); + aliasLists.put("well known certificates", wellKnowCertificates.getAliases()); aliasLists.put("custom certificates", customCertificates.getAliases()); aliasLists.put("key vault certificates", keyVaultCertificates.getAliases()); @@ -410,6 +436,7 @@ private List getAllAliases() { allAliases.add(alias); } })); + return allAliases; } @@ -423,8 +450,10 @@ private List getAllAliases() { public void engineSetCertificateEntry(String alias, Certificate certificate) { if (getAllAliases().contains(alias)) { LOGGER.log(WARNING, "Cannot overwrite own certificate"); + return; } + classpathCertificates.setCertificateEntry(alias, certificate); } @@ -439,7 +468,9 @@ public void engineSetCertificateEntry(String alias, Certificate certificate) { * @throws KeyStoreException if this operation fails */ @Override - public void engineSetEntry(String alias, KeyStore.Entry entry, KeyStore.ProtectionParameter protParam) throws KeyStoreException { + public void engineSetEntry(String alias, KeyStore.Entry entry, KeyStore.ProtectionParameter protParam) + throws KeyStoreException { + super.engineSetEntry(alias, entry, protParam); } @@ -494,6 +525,4 @@ public void engineStore(OutputStream stream, char[] password) { @Override public void engineStore(KeyStore.LoadStoreParameter param) { } - - } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultLoadStoreParameter.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultLoadStoreParameter.java index e70a14be3f5f..3424b826c354 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultLoadStoreParameter.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/KeyVaultLoadStoreParameter.java @@ -11,17 +11,11 @@ * @see KeyStore.LoadStoreParameter */ public final class KeyVaultLoadStoreParameter implements KeyStore.LoadStoreParameter { - /** * Stores the Key Vault URI. */ private final String keyVaultUri; - /** - * Stores the Azure login URI. - */ - private final String loginUri; - /** * Stores the tenant id. */ @@ -42,63 +36,55 @@ public final class KeyVaultLoadStoreParameter implements KeyStore.LoadStoreParam */ private final String managedIdentity; + /** + * Stores a flag indicating if challenge resource verification shall be disabled. + */ + private boolean disableChallengeResourceVerification; + /** * Constructor. * * @param keyVaultUri The Azure Key Vault URI. */ public KeyVaultLoadStoreParameter(String keyVaultUri) { - this(keyVaultUri, null, null, null, null, null); + this(keyVaultUri, null, null, null, null); } /** * Constructor. * - * @param keyVaultUri the Azure Key Vault URI. - * @param managedIdentity The Managed Identity. + * @param keyVaultUri The Azure Key Vault URI. + * @param managedIdentity The managed identity. */ public KeyVaultLoadStoreParameter(String keyVaultUri, String managedIdentity) { - this(keyVaultUri, null, null, null, null, managedIdentity); + this(keyVaultUri, null, null, null, managedIdentity); } /** * Constructor. * - * @param keyVaultUri the Azure Key Vault URI. + * @param keyVaultUri The Azure Key Vault URI. * @param tenantId The tenant id. * @param clientId The client id. * @param clientSecret The client secret. */ public KeyVaultLoadStoreParameter(String keyVaultUri, String tenantId, String clientId, String clientSecret) { - this(keyVaultUri, null, tenantId, clientId, clientSecret, null); + this(keyVaultUri, tenantId, clientId, clientSecret, null); } /** * Constructor. * - * @param keyVaultUri the Azure Key Vault URI. + * @param keyVaultUri The Azure Key Vault URI. * @param tenantId The tenant id. * @param clientId The client id. * @param clientSecret The client secret. - * @param managedIdentity The Managed Identity. + * @param managedIdentity The managed identity. */ - public KeyVaultLoadStoreParameter(String keyVaultUri, String tenantId, String clientId, String clientSecret, String managedIdentity) { - this(keyVaultUri, null, tenantId, clientId, clientSecret, managedIdentity); - } + public KeyVaultLoadStoreParameter(String keyVaultUri, String tenantId, String clientId, String clientSecret, + String managedIdentity) { - /** - * Constructor. - * - * @param keyVaultUri the Azure Key Vault URI. - * @param loginUri The Azure login URI. - * @param tenantId The tenant id. - * @param clientId The client id. - * @param clientSecret The client secret. - * @param managedIdentity The Managed Identity. - */ - public KeyVaultLoadStoreParameter(String keyVaultUri, String loginUri, String tenantId, String clientId, String clientSecret, String managedIdentity) { this.keyVaultUri = keyVaultUri; - this.loginUri = loginUri; this.tenantId = tenantId; this.clientId = clientId; this.clientSecret = clientSecret; @@ -134,7 +120,7 @@ public String getClientSecret() { } /** - * Get the managed identity. + * Get the Managed Identity. * * @return The Managed Identity. */ @@ -145,7 +131,7 @@ public String getManagedIdentity() { /** * Get the tenant id. * - * @return The tenant id. + * @return the tenant id. */ public String getTenantId() { return tenantId; @@ -161,11 +147,20 @@ public String getUri() { } /** - * Get the Azure login URI. + * Get a value indicating a check verifying if the authentication challenge resource matches the Key Vault or + * Managed HSM domain will be performed. This verification is performed by default. * - * @return The Azure login URI. + * @return A value indicating if challenge resource verification is disabled. + */ + public boolean isChallengeResourceVerificationDisabled() { + return disableChallengeResourceVerification; + } + + /** + * Disables verifying if the authentication challenge resource matches the Key Vault or Managed HSM domain. This + * verification is performed by default. */ - public String getLoginUri() { - return loginUri; + public void disableChallengeResourceVerification() { + disableChallengeResourceVerification = true; } } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/KeyVaultClient.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/KeyVaultClient.java index 3f93ebbb5223..890029c32fb1 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/KeyVaultClient.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/KeyVaultClient.java @@ -18,8 +18,6 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.StringReader; -import java.net.URI; -import java.net.URISyntaxException; import java.net.URLEncoder; import java.security.Key; import java.security.KeyFactory; @@ -42,6 +40,11 @@ import java.util.Optional; import java.util.logging.Logger; +import static com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil.getLoginUri; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.API_VERSION_POSTFIX; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.HTTPS_PREFIX; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.addTrailingSlashIfRequired; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.validateUri; import static java.util.logging.Level.INFO; import static java.util.logging.Level.WARNING; @@ -49,40 +52,7 @@ * The REST client specific to Azure Key Vault. */ public class KeyVaultClient { - - public static final String KEY_VAULT_BASE_URI_GLOBAL = "https://vault.azure.net"; - public static final String KEY_VAULT_BASE_URI_CN = "https://vault.azure.cn"; - public static final String KEY_VAULT_BASE_URI_US = "https://vault.usgovcloudapi.net"; - public static final String KEY_VAULT_BASE_URI_DE = "https://vault.microsoftazure.de"; - public static final String AAD_LOGIN_URI_GLOBAL = "https://login.microsoftonline.com/"; - public static final String AAD_LOGIN_URI_CN = "https://login.partner.microsoftonline.cn/"; - public static final String AAD_LOGIN_URI_US = "https://login.microsoftonline.us/"; - public static final String AAD_LOGIN_URI_DE = "https://login.microsoftonline.de/"; - private static final Logger LOGGER = Logger.getLogger(KeyVaultClient.class.getName()); - private static final String HTTPS_PREFIX = "https://"; - private static final String API_VERSION_POSTFIX = "?api-version=7.1"; - - public static String getAADLoginURIByKeyVaultBaseUri(String keyVaultBaseUri) { - String aadAuthenticationUrl; - switch (keyVaultBaseUri) { - case KEY_VAULT_BASE_URI_GLOBAL : - aadAuthenticationUrl = AAD_LOGIN_URI_GLOBAL; - break; - case KEY_VAULT_BASE_URI_CN : - aadAuthenticationUrl = AAD_LOGIN_URI_CN; - break; - case KEY_VAULT_BASE_URI_US : - aadAuthenticationUrl = AAD_LOGIN_URI_US; - break; - case KEY_VAULT_BASE_URI_DE: - aadAuthenticationUrl = AAD_LOGIN_URI_DE; - break; - default: - throw new IllegalArgumentException("Property of azure.keyvault.uri is illegal."); - } - return aadAuthenticationUrl; - } /** * Stores the Key Vault cloud URI. @@ -94,11 +64,6 @@ public static String getAADLoginURIByKeyVaultBaseUri(String keyVaultBaseUri) { */ private final String keyVaultUri; - /** - * Stores the AAD authentication URL (or null to default to Azure Public Cloud). - */ - private final String aadAuthenticationUri; - /** * Stores the tenant ID. */ @@ -115,7 +80,7 @@ public static String getAADLoginURIByKeyVaultBaseUri(String keyVaultBaseUri) { private final String clientSecret; /** - * Stores the managed identity (either the user-assigned managed identity object ID or null if system-assigned) + * Stores the managed identity (either the user-assigned managed identity object ID or null if system-assigned). */ private String managedIdentity; @@ -124,292 +89,295 @@ public static String getAADLoginURIByKeyVaultBaseUri(String keyVaultBaseUri) { */ private AccessToken accessToken; + /** + * Stores a flag indicating if challenge resource verification shall be disabled. + */ + private final boolean disableChallengeResourceVerification; + + /** * Constructor for authentication with user-assigned managed identity. * - * @param keyVaultUri the Azure Key Vault URI. - * @param managedIdentity the user-assigned managed identity object ID. + * @param keyVaultUri The Azure Key Vault URI. + * @param managedIdentity The user-assigned managed identity object ID. */ KeyVaultClient(String keyVaultUri, String managedIdentity) { - this(keyVaultUri, null, null, null, null, managedIdentity); + this(keyVaultUri, null, null, null, managedIdentity, false); } /** * Constructor for authentication with service principal. * - * @param keyVaultUri the Azure Key Vault URI. - * @param tenantId the tenant ID. - * @param clientId the client ID. - * @param clientSecret the client secret. + * @param keyVaultUri The Azure Key Vault URI. + * @param tenantId The tenant ID. + * @param clientId The client ID. + * @param clientSecret The client secret. */ public KeyVaultClient(String keyVaultUri, String tenantId, String clientId, String clientSecret) { - this(keyVaultUri, null, tenantId, clientId, clientSecret, null); + this(keyVaultUri, tenantId, clientId, clientSecret, null, false); } - /** * Constructor. * - * @param keyVaultUri the Azure Key Vault URI. - * @param tenantId the tenant ID. - * @param clientId the client ID. - * @param clientSecret the client secret. - * @param managedIdentity the user-assigned managed identity object ID. + * @param keyVaultUri The Azure Key Vault URI. + * @param tenantId The tenant ID. + * @param clientId The client ID. + * @param clientSecret The client secret. + * @param managedIdentity The user-assigned managed identity object ID. + * @param disableChallengeResourceVerification Indicates if the challenge resource verification should be disabled. */ - public KeyVaultClient(String keyVaultUri, String loginUri, String tenantId, String clientId, String clientSecret, - String managedIdentity) { - LOGGER.log(INFO, "Using Azure Key Vault: {0}", keyVaultUri); - - if (!keyVaultUri.endsWith("/")) { - keyVaultUri = keyVaultUri + "/"; - } + public KeyVaultClient(String keyVaultUri, String tenantId, String clientId, String clientSecret, + String managedIdentity, boolean disableChallengeResourceVerification) { - this.keyVaultUri = keyVaultUri; - // Base Uri shouldn't end with a slash. - String domainNameSuffix = Optional.of(keyVaultUri) - .map(uri -> uri.split("\\.", 2)[1]) - .map(suffix -> suffix.substring(0, suffix.length() - 1)) - .orElse(null); - this.keyVaultBaseUri = validateUri(HTTPS_PREFIX + domainNameSuffix, "Key Vault URI"); - this.aadAuthenticationUri = addTrailingSlashIfRequired( - loginUri != null - ? validateUri(loginUri, "Login URI") // Validate any user-provided login URI. - : getAADLoginURIByKeyVaultBaseUri(keyVaultBaseUri)); // These are all valid URIs. + LOGGER.log(INFO, "Using Azure Key Vault: {0}", keyVaultUri); + this.keyVaultUri = addTrailingSlashIfRequired(validateUri(keyVaultUri, "Azure Key Vault URI")); + // Base URI shouldn't end with a slash. + String domainNameSuffix = Optional.of(this.keyVaultUri) + .map(uri -> uri.split("\\.", 2)[1]) + .map(suffix -> suffix.substring(0, suffix.length() - 1)) + .orElse(null); + this.keyVaultBaseUri = HTTPS_PREFIX + domainNameSuffix; this.tenantId = tenantId; this.clientId = clientId; this.clientSecret = clientSecret; this.managedIdentity = managedIdentity; + this.disableChallengeResourceVerification = disableChallengeResourceVerification; } public static KeyVaultClient createKeyVaultClientBySystemProperty() { String keyVaultUri = System.getProperty("azure.keyvault.uri"); - String loginUri = System.getProperty("azure.login.uri"); String tenantId = System.getProperty("azure.keyvault.tenant-id"); String clientId = System.getProperty("azure.keyvault.client-id"); String clientSecret = System.getProperty("azure.keyvault.client-secret"); String managedIdentity = System.getProperty("azure.keyvault.managed-identity"); + boolean disableChallengeResourceVerification = + Boolean.parseBoolean(System.getProperty("azure.keyvault.disable-challenge-resource-verification")); - return new KeyVaultClient(keyVaultUri, loginUri, tenantId, clientId, clientSecret, managedIdentity); - } - - private String validateUri(String uri, String propertyName) { - if (uri == null) { // Should the login URI be allowed to be null to default to Azure Public Cloud? - StringBuilder messageBuilder = new StringBuilder(); - - if (propertyName != null) { - messageBuilder.append(propertyName); - } else { - messageBuilder.append("Provided URI "); - } - - messageBuilder.append("cannot be null."); - - throw new NullPointerException(messageBuilder.toString()); - } - - if (!uri.startsWith(HTTPS_PREFIX)) { - throw new IllegalArgumentException("Provided URI '" + uri + "' must start with 'https://'."); - } - - try { - new URI(uri); - } catch (URISyntaxException e) { - throw new IllegalArgumentException("Provided URI '" + uri + "' is not a valid URI."); - } - - return uri; - } - - private String addTrailingSlashIfRequired(String uri) { - if (!uri.endsWith("/")) { - return uri + "/"; - } - - return uri; + return new KeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret, managedIdentity, + disableChallengeResourceVerification); } /** * Get the access token. * - * @return the access token. + * @return The access token. */ private String getAccessToken() { if (accessToken != null && !accessToken.isExpired()) { return accessToken.getAccessToken(); } + accessToken = getAccessTokenByHttpRequest(); + return accessToken.getAccessToken(); } /** * Get the access token. * - * @return the access token. + * @return The access token. */ private AccessToken getAccessTokenByHttpRequest() { - LOGGER.entering("KeyVaultClient", "getAccessToken"); + LOGGER.entering("KeyVaultClient", "getAccessTokenByHttpRequest"); + AccessToken accessToken = null; + try { String resource = URLEncoder.encode(keyVaultBaseUri, "UTF-8"); + if (managedIdentity != null) { managedIdentity = URLEncoder.encode(managedIdentity, "UTF-8"); } if (tenantId != null && clientId != null && clientSecret != null) { - accessToken = AccessTokenUtil.getAccessToken(resource, aadAuthenticationUri, tenantId, clientId, - clientSecret); + String aadAuthenticationUri = getLoginUri(keyVaultUri + "certificates" + API_VERSION_POSTFIX, + disableChallengeResourceVerification); + accessToken = + AccessTokenUtil.getAccessToken(resource, aadAuthenticationUri, tenantId, clientId, clientSecret); } else { accessToken = AccessTokenUtil.getAccessToken(resource, managedIdentity); } - } catch (Throwable throwable) { - LOGGER.log(WARNING, "Unsupported encoding or missing Httpclient", throwable); + } catch (Throwable t) { + LOGGER.log(WARNING, "Could not obtain access token to authenticate with.", t); } - LOGGER.exiting("KeyVaultClient", "getAccessToken", accessToken); + + LOGGER.exiting("KeyVaultClient", "getAccessTokenByHttpRequest", accessToken); + return accessToken; } /** * Get the list of aliases. * - * @return the list of aliases. + * @return The list of aliases. */ public List getAliases() { ArrayList result = new ArrayList<>(); HashMap headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + getAccessToken()); - String url = String.format("%scertificates%s", keyVaultUri, API_VERSION_POSTFIX); - while (url != null && url.length() != 0) { - String response = HttpUtil.get(url, headers); + String uri = keyVaultUri + "certificates" + API_VERSION_POSTFIX; + + while (uri != null && !uri.isEmpty()) { + String response = HttpUtil.get(uri, headers); CertificateListResult certificateListResult = null; + if (response != null) { - certificateListResult = (CertificateListResult) JsonConverterUtil.fromJson(response, - CertificateListResult.class); + certificateListResult = + (CertificateListResult) JsonConverterUtil.fromJson(response, CertificateListResult.class); } + if (certificateListResult != null) { - url = certificateListResult.getNextLink(); + uri = certificateListResult.getNextLink(); + for (CertificateItem certificateItem : certificateListResult.getValue()) { String id = certificateItem.getId(); String alias = id.substring(id.indexOf("certificates") + "certificates".length() + 1); + result.add(alias); } } else { - url = null; + uri = null; } } + return result; } /** * Get the certificate bundle. * - * @param alias the alias. - * @return the certificate bundle. + * @param alias The alias. + * @return The certificate bundle. */ private CertificateBundle getCertificateBundle(String alias) { CertificateBundle result = null; HashMap headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + getAccessToken()); - String url = String.format("%scertificates/%s%s", keyVaultUri, alias, API_VERSION_POSTFIX); - String response = HttpUtil.get(url, headers); + + String uri = keyVaultUri + "certificates/" + alias + API_VERSION_POSTFIX; + String response = HttpUtil.get(uri, headers); + if (response != null) { result = (CertificateBundle) JsonConverterUtil.fromJson(response, CertificateBundle.class); } + return result; } /** * Get the certificate. * - * @param alias the alias. - * @return the certificate, or null if not found. + * @param alias The alias. + * + * @return The certificate, or null if not found. */ public Certificate getCertificate(String alias) { LOGGER.entering("KeyVaultClient", "getCertificate", alias); LOGGER.log(INFO, "Getting certificate for alias: {0}", alias); + X509Certificate certificate = null; CertificateBundle certificateBundle = getCertificateBundle(alias); + if (certificateBundle != null) { String certificateString = certificateBundle.getCer(); + if (certificateString != null) { try { CertificateFactory cf = CertificateFactory.getInstance("X.509"); certificate = (X509Certificate) cf.generateCertificate( - new ByteArrayInputStream(Base64.getDecoder().decode(certificateString)) - ); + new ByteArrayInputStream(Base64.getDecoder().decode(certificateString))); } catch (CertificateException ce) { LOGGER.log(WARNING, "Certificate error", ce); } } } + LOGGER.exiting("KeyVaultClient", "getCertificate", certificate); + return certificate; } /** * Get the key. * - * @param alias the alias. - * @param password the password. - * @return the key. + * @param alias The alias. + * @param password The password. + * + * @return The key. */ public Key getKey(String alias, char[] password) { LOGGER.entering("KeyVaultClient", "getKey", new Object[] { alias, password }); LOGGER.log(INFO, "Getting key for alias: {0}", alias); + CertificateBundle certificateBundle = getCertificateBundle(alias); boolean isExportable = Optional.ofNullable(certificateBundle) - .map(CertificateBundle::getPolicy) - .map(CertificatePolicy::getKeyProperties) - .map(KeyProperties::isExportable) - .orElse(false); + .map(CertificateBundle::getPolicy) + .map(CertificatePolicy::getKeyProperties) + .map(KeyProperties::isExportable) + .orElse(false); String keyType = Optional.ofNullable(certificateBundle) - .map(CertificateBundle::getPolicy) - .map(CertificatePolicy::getKeyProperties) - .map(KeyProperties::getKty) - .orElse(null); + .map(CertificateBundle::getPolicy) + .map(CertificatePolicy::getKeyProperties) + .map(KeyProperties::getKty) + .orElse(null); + if (!isExportable) { - // return KeyVaultPrivateKey if certificate is not exportable because - // if the service needs to obtain the private key for authentication, - // and we can't access private key(which is not exportable), we will use + // Return KeyVaultPrivateKey if certificate is not exportable because if the service needs to obtain the + // private key for authentication, and we can't access private key(which is not exportable), we will use // the Azure Key Vault Secrets API to obtain the private key (keyless). LOGGER.exiting("KeyVaultClient", "getKey", null); String keyType2 = keyType.contains("-HSM") ? keyType.substring(0, keyType.indexOf("-HSM")) : keyType; + return Optional.ofNullable(certificateBundle) - .map(CertificateBundle::getKid) - .map(kid -> new KeyVaultPrivateKey(keyType2, kid, this)) - .orElse(null); + .map(CertificateBundle::getKid) + .map(kid -> new KeyVaultPrivateKey(keyType2, kid, this)) + .orElse(null); } + String certificateSecretUri = certificateBundle.getSid(); Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + getAccessToken()); + String body = HttpUtil.get(certificateSecretUri + API_VERSION_POSTFIX, headers); + if (body == null) { // If the private key is not available the certificate cannot be used for server side certificates or mTLS. // Then we do not know the intent of the usage at this stage we skip this key. LOGGER.exiting("KeyVaultClient", "getKey", null); + // We return null because it is really not needed. // The private key is only used for identity authentication. - // If we are unable to obtain the private key, it proves that the client - // does not own the private key (maybe due to lack of authority or other reasons). + // If we are unable to obtain the private key, it proves that the client does not own the private key + // (maybe due to lack of authority or other reasons). return null; } - // The certificate is exportable the private key is available. - // So We'll store the private key for authentication instead of - // obtaining a digital signature through the API(without keyless). + + // If the certificate is exportable the private key is available, so we'll store the private key for + // authentication instead of obtaining a digital signature through the API (without keyless). Key key = null; SecretBundle secretBundle = (SecretBundle) JsonConverterUtil.fromJson(body, SecretBundle.class); String contentType = secretBundle.getContentType(); + if ("application/x-pkcs12".equals(contentType)) { try { KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load( new ByteArrayInputStream(Base64.getDecoder().decode(secretBundle.getValue())), "".toCharArray()); + alias = keyStore.aliases().nextElement(); key = keyStore.getKey(alias, "".toCharArray()); - } catch (IOException | KeyStoreException | NoSuchAlgorithmException | UnrecoverableKeyException | CertificateException ex) { - LOGGER.log(WARNING, "Unable to decode key", ex); + } catch (IOException | KeyStoreException | NoSuchAlgorithmException | UnrecoverableKeyException + | CertificateException e) { + + LOGGER.log(WARNING, "Unable to decode key", e); } } else if ("application/x-pem-file".equals(contentType)) { try { @@ -419,76 +387,84 @@ public Key getKey(String alias, char[] password) { } } - // // If the private key is not available the certificate cannot be // used for server side certificates or mTLS. Then we do not know // the intent of the usage at this stage we skip this key. - // LOGGER.exiting("KeyVaultClient", "getKey", key); + return key; } /** - * get signature by key vault - * @param digestName digestName - * @param digestValue digestValue - * @param keyId The key id - * @return signature + * Get signature by Key Vault. + * + * @param digestName Digest name. + * @param digestValue Digest value. + * @param keyId The key ID. + * + * @return Signature. */ public byte[] getSignedWithPrivateKey(String digestName, String digestValue, String keyId) { SignResult result = null; String bodyString = String.format("{\"alg\": \"" + digestName + "\", \"value\": \"%s\"}", digestValue); Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + getAccessToken()); - String url = String.format("%s/sign%s", keyId, API_VERSION_POSTFIX); - String response = HttpUtil.post(url, headers, bodyString, "application/json"); + + String uri = keyId + "/sign" + API_VERSION_POSTFIX; + String response = HttpUtil.post(uri, headers, bodyString, "application/json"); + if (response != null) { result = (SignResult) JsonConverterUtil.fromJson(response, SignResult.class); } + if (result != null) { return Base64.getUrlDecoder().decode(result.getValue()); } + return new byte[0]; } /** * Get the private key from the PEM string. * - * @param pemString the PEM file in string format. - * @param keyType the private key type in string format. - * @return the private key + * @param pemString The PEM file in string format. + * @param keyType The private key type in string format. + * + * @return The private key. + * * @throws IOException when an I/O error occurs. * @throws NoSuchAlgorithmException when algorithm is unavailable. * @throws InvalidKeySpecException when the private key cannot be generated. */ private PrivateKey createPrivateKeyFromPem(String pemString, String keyType) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException { + StringBuilder builder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader(new StringReader(pemString))) { String line = reader.readLine(); + if (line == null || !line.contains("BEGIN PRIVATE KEY")) { throw new IllegalArgumentException("No PRIVATE KEY found"); } + line = ""; + while (line != null) { if (line.contains("END PRIVATE KEY")) { break; } + builder.append(line); line = reader.readLine(); } } + byte[] bytes = Base64.getDecoder().decode(builder.toString()); PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(bytes); KeyFactory factory = KeyFactory.getInstance(keyType); - return factory.generatePrivate(spec); - } - String getKeyVaultBaseUri() { - return keyVaultBaseUri; - } - - String getAadAuthenticationUri() { - return aadAuthenticationUri; + return factory.generatePrivate(spec); } } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/certificates/KeyVaultCertificates.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/certificates/KeyVaultCertificates.java index 5555d2b39620..bed0ac548998 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/certificates/KeyVaultCertificates.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/certificates/KeyVaultCertificates.java @@ -4,6 +4,7 @@ package com.azure.security.keyvault.jca.implementation.certificates; import com.azure.security.keyvault.jca.implementation.KeyVaultClient; + import java.security.Key; import java.security.cert.Certificate; import java.util.ArrayList; @@ -19,7 +20,6 @@ * Store certificates loaded from KeyVault. */ public final class KeyVaultCertificates implements AzureCertificates { - /** * Stores the list of aliases. */ @@ -36,7 +36,7 @@ public final class KeyVaultCertificates implements AzureCertificates { private final Map certificateKeys = new HashMap<>(); /** - * Stores the last time refresh certificates and alias + * Stores the last time refresh certificates and alias. */ private Date lastRefreshTime; @@ -44,15 +44,12 @@ public final class KeyVaultCertificates implements AzureCertificates { private final long refreshInterval; - public KeyVaultCertificates(long refreshInterval, - String keyVaultUri, - String loginUri, - String tenantId, - String clientId, - String clientSecret, - String managedIdentity) { + public KeyVaultCertificates(long refreshInterval, String keyVaultUri, String tenantId, String clientId, + String clientSecret, String managedIdentity, boolean disableChallengeResourceVerification) { + this.refreshInterval = refreshInterval; - updateKeyVaultClient(keyVaultUri, loginUri, tenantId, clientId, clientSecret, managedIdentity); + + updateKeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret, managedIdentity, disableChallengeResourceVerification); } public KeyVaultCertificates(long refreshInterval, KeyVaultClient keyVaultClient) { @@ -61,23 +58,21 @@ public KeyVaultCertificates(long refreshInterval, KeyVaultClient keyVaultClient) } /** - * Update KeyVaultClient + * Update KeyVaultClient. * - * @param keyVaultUri keyVault uri - * @param tenantId tenant id - * @param clientId client id - * @param clientSecret client secret - * @param managedIdentity managed identity + * @param keyVaultUri Key Vault URI. + * @param tenantId Tenant ID. + * @param clientId Client ID. + * @param clientSecret Client secret. + * @param managedIdentity Managed identity. + * @param disableChallengeResourceVerification Indicates if the challenge resource verification should be disabled. */ - public void updateKeyVaultClient(String keyVaultUri, - String loginUri, - String tenantId, - String clientId, - String clientSecret, - String managedIdentity) { + public void updateKeyVaultClient(String keyVaultUri, String tenantId, String clientId, String clientSecret, + String managedIdentity, boolean disableChallengeResourceVerification) { + if (keyVaultUri != null) { - keyVaultClient = - new KeyVaultClient(keyVaultUri, loginUri, tenantId, clientId, clientSecret, managedIdentity); + keyVaultClient = new KeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret, managedIdentity, + disableChallengeResourceVerification); } else { keyVaultClient = null; } @@ -90,39 +85,43 @@ boolean certificatesNeedRefresh() { if (lastRefreshTime == null) { return true; } + return refreshInterval > 0 && lastRefreshTime.getTime() + refreshInterval < new Date().getTime(); } /** * Get certificate aliases. * - * @return certificate aliases + * @return Certificate aliases. */ @Override public List getAliases() { refreshCertificatesIfNeeded(); + return aliases; } /** * Get certificates. * - * @return certificates + * @return Certificates. */ @Override public Map getCertificates() { refreshCertificatesIfNeeded(); + return certificates; } /** * Get certificates. * - * @return certificate keys + * @return Certificate keys. */ @Override public Map getCertificateKeys() { refreshCertificatesIfNeeded(); + return certificateKeys; } @@ -138,57 +137,63 @@ private void refreshCertificatesIfNeeded() { /** * Refresh certificates. Including certificates, aliases, certificate keys. - * */ public synchronized void refreshCertificates() { // When refreshing certificates, the update of the 3 variables should be an atomic operation. aliases = keyVaultClient.getAliases(); certificateKeys.clear(); certificates.clear(); + Optional.ofNullable(aliases) - .orElse(Collections.emptyList()) - .forEach(alias -> { - Key key = keyVaultClient.getKey(alias, null); - if (!Objects.isNull(key)) { - certificateKeys.put(alias, key); - } - Certificate certificate = keyVaultClient.getCertificate(alias); - if (!Objects.isNull(certificate)) { - certificates.put(alias, certificate); - } - }); + .orElse(Collections.emptyList()) + .forEach(alias -> { + Key key = keyVaultClient.getKey(alias, null); + + if (!Objects.isNull(key)) { + certificateKeys.put(alias, key); + } + + Certificate certificate = keyVaultClient.getCertificate(alias); + + if (!Objects.isNull(certificate)) { + certificates.put(alias, certificate); + } + }); + lastRefreshTime = new Date(); } /** - * Get latest alias by certificate which in portal + * Get latest alias by certificate. + * + * @param certificate Certificate to get alias with. * - * @param certificate certificate got - * @return certificate' alias if exist. + * @return Certificate alias if it exists. */ public String refreshAndGetAliasByCertificate(Certificate certificate) { refreshCertificates(); + return getCertificates().entrySet() - .stream() - .filter(entry -> certificate.equals(entry.getValue())) - .findFirst() - .map(Map.Entry::getKey) - .orElse(null); + .stream() + .filter(entry -> certificate.equals(entry.getValue())) + .findFirst() + .map(Map.Entry::getKey) + .orElse(null); } /** - * Delete certificate info by alias if exits + * Delete certificate info by alias if exists. * - * @param alias deleted certificate + * @param alias Deleted certificate. */ @Override public void deleteEntry(String alias) { if (aliases != null) { aliases.remove(alias); } + certificates.remove(alias); certificateKeys.remove(alias); } - } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/AccessTokenUtil.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/AccessTokenUtil.java index ce22adaabebd..78c97b1bb4ca 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/AccessTokenUtil.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/AccessTokenUtil.java @@ -2,21 +2,27 @@ // Licensed under the MIT License. package com.azure.security.keyvault.jca.implementation.utils; -import static java.util.logging.Level.FINER; -import static java.util.logging.Level.INFO; - import com.azure.security.keyvault.jca.implementation.model.AccessToken; +import org.apache.http.HttpResponse; import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; import java.net.URLEncoder; +import java.util.Collections; import java.util.HashMap; +import java.util.Locale; +import java.util.Map; import java.util.logging.Logger; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.addTrailingSlashIfRequired; +import static java.util.logging.Level.FINER; +import static java.util.logging.Level.INFO; + /** * The REST client specific to getting an access token for Azure REST APIs. */ public final class AccessTokenUtil { - /** * Stores the Client ID fragment. */ @@ -45,7 +51,7 @@ public final class AccessTokenUtil { /** * Stores the OAuth2 token postfix. */ - private static final String OAUTH2_TOKEN_POSTFIX = "/oauth2/token"; + private static final String OAUTH2_TOKEN_POSTFIX = "oauth2/token"; /** * Stores the OAuth2 managed identity URL. @@ -53,6 +59,16 @@ public final class AccessTokenUtil { private static final String OAUTH2_MANAGED_IDENTITY_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01"; + /** + * A prefix to use on the bearer token header. + */ + private static final String BEARER_TOKEN_PREFIX = "Bearer "; + + /** + * The WWW-Authenticate header name. + */ + private static final String WWW_AUTHENTICATE = "WWW-Authenticate"; + /** * Stores our logger. */ @@ -61,128 +77,276 @@ public final class AccessTokenUtil { /** * Get an access token for a managed identity. * - * @param resource the resource. - * @param identity the user-assigned identity (null if system-assigned) - * @return the authorization token. + * @param resource The resource. + * @param identity The user-assigned identity (null if system-assigned). + * + * @return The authorization token. */ public static AccessToken getAccessToken(String resource, String identity) { AccessToken result; - if (System.getenv("WEBSITE_SITE_NAME") != null - && !System.getenv("WEBSITE_SITE_NAME").isEmpty()) { + if (System.getenv("WEBSITE_SITE_NAME") != null && !System.getenv("WEBSITE_SITE_NAME").isEmpty()) { result = getAccessTokenOnAppService(resource, identity); } else { result = getAccessTokenOnOthers(resource, identity); } + return result; } /** * Get an access token. * - * @param resource the resource. - * @param tenantId the tenant ID. - * @param aadAuthenticationUrl the AAD authentication url - * @param clientId the client ID. - * @param clientSecret the client secret. - * @return the authorization token. - */ - public static AccessToken getAccessToken(String resource, String aadAuthenticationUrl, - String tenantId, String clientId, String clientSecret) { - - LOGGER.entering("AccessTokenUtil", "getAccessToken", new Object[]{ - resource, tenantId, clientId, clientSecret}); + * @param resource The resource. + * @param tenantId The tenant ID. + * @param aadAuthenticationUrl The AAD authentication url. + * @param clientId The client ID. + * @param clientSecret The client secret. + * + * @return The authorization token. + */ + public static AccessToken getAccessToken(String resource, String aadAuthenticationUrl, String tenantId, + String clientId, String clientSecret) { + + LOGGER.entering("AccessTokenUtil", "getAccessToken", + new Object[] { resource, tenantId, clientId, clientSecret }); LOGGER.info("Getting access token using client ID / client secret"); + AccessToken result = null; StringBuilder oauth2Url = new StringBuilder(); - oauth2Url.append(aadAuthenticationUrl == null ? OAUTH2_TOKEN_BASE_URL : aadAuthenticationUrl) - .append(tenantId) - .append(OAUTH2_TOKEN_POSTFIX); + + if (aadAuthenticationUrl == null) { + oauth2Url.append(OAUTH2_TOKEN_BASE_URL).append(tenantId).append("/"); + } else { + oauth2Url.append(addTrailingSlashIfRequired(aadAuthenticationUrl)); + } + + oauth2Url.append(OAUTH2_TOKEN_POSTFIX); String encodedClientSecret = ""; + try { encodedClientSecret = URLEncoder.encode(clientSecret, "UTF-8"); } catch (UnsupportedEncodingException e) { - LOGGER.warning("Failed to construct encodedClientSecret"); + LOGGER.warning("Failed to encode client secret for access token request"); } + StringBuilder requestBody = new StringBuilder(); + requestBody.append(GRANT_TYPE_FRAGMENT) - .append(CLIENT_ID_FRAGMENT).append(clientId) - .append(CLIENT_SECRET_FRAGMENT).append(encodedClientSecret) - .append(RESOURCE_FRAGMENT).append(resource); + .append(CLIENT_ID_FRAGMENT).append(clientId) + .append(CLIENT_SECRET_FRAGMENT).append(encodedClientSecret) + .append(RESOURCE_FRAGMENT).append(resource); + + String body = + HttpUtil.post(oauth2Url.toString(), requestBody.toString(), "application/x-www-form-urlencoded"); - String body = HttpUtil - .post(oauth2Url.toString(), requestBody.toString(), "application/x-www-form-urlencoded"); if (body != null) { result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class); } + LOGGER.log(FINER, "Access token: {0}", result); + return result; } /** * Get the access token on Azure App Service. * - * @param resource the resource. - * @param clientId the user-assigned managed identity (null if system-assigned). - * @return the authorization token. + * @param resource The resource. + * @param clientId The user-assigned managed identity (null if system-assigned). + * @return The authorization token. */ private static AccessToken getAccessTokenOnAppService(String resource, String clientId) { LOGGER.entering("AccessTokenUtil", "getAccessTokenOnAppService", resource); LOGGER.info("Getting access token using managed identity based on MSI_SECRET"); + AccessToken result = null; StringBuilder url = new StringBuilder(); + url.append(System.getenv("MSI_ENDPOINT")) - .append("?api-version=2017-09-01") - .append(RESOURCE_FRAGMENT).append(resource); + .append("?api-version=2017-09-01") + .append(RESOURCE_FRAGMENT).append(resource); + if (clientId != null) { url.append("&clientid=").append(clientId); + LOGGER.log(INFO, "Using managed identity with client ID: {0}", clientId); } HashMap headers = new HashMap<>(); + headers.put("Metadata", "true"); headers.put("Secret", System.getenv("MSI_SECRET")); + String body = HttpUtil.get(url.toString(), headers); if (body != null) { result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class); } + LOGGER.exiting("AccessTokenUtil", "getAccessTokenOnAppService", result); + return result; } /** * Get the authorization token on everything else but Azure App Service. * - * @param resource the resource. - * @param identity the user-assigned identity (null if system-assigned). - * @return the authorization token. + * @param resource The resource. + * @param identity The user-assigned identity (null if system-assigned). + * @return The authorization token. */ private static AccessToken getAccessTokenOnOthers(String resource, String identity) { LOGGER.entering("AccessTokenUtil", "getAccessTokenOnOthers", resource); LOGGER.info("Getting access token using managed identity"); + if (identity != null) { LOGGER.log(INFO, "Using managed identity with object ID: {0}", identity); } + AccessToken result = null; StringBuilder url = new StringBuilder(); + url.append(OAUTH2_MANAGED_IDENTITY_TOKEN_URL) .append(RESOURCE_FRAGMENT).append(resource); + if (identity != null) { url.append("&object_id=").append(identity); } HashMap headers = new HashMap<>(); + headers.put("Metadata", "true"); + String body = HttpUtil.get(url.toString(), headers); if (body != null) { result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class); } + LOGGER.exiting("AccessTokenUtil", "getAccessTokenOnOthers", result); + return result; } + + public static String getLoginUri(String resourceUri, boolean disableChallengeResourceVerification) { + LOGGER.entering("AccessTokenUtil", "getLoginUri", resourceUri); + LOGGER.log(INFO, "Getting login URI using: {0}", resourceUri); + + HttpResponse response = HttpUtil.getWithResponse(resourceUri, null); + + if (response == null) { + throw new IllegalStateException("Could not obtain login URI to retrieve access token from."); + } + + Map challengeAttributes = + extractChallengeAttributes(response.getFirstHeader(WWW_AUTHENTICATE).getValue()); + String scope = challengeAttributes.get("resource"); + + if (scope != null) { + scope = scope + "/.default"; + } else { + scope = challengeAttributes.get("scope"); + } + + if (scope == null) { + return null; + } else { + if (!disableChallengeResourceVerification && !isChallengeResourceValid(resourceUri, scope)) { + throw new IllegalStateException("The challenge resource " + scope + " does not match the requested " + + "domain. If you wish to disable this check, set the environment property " + + "'azure.keyvault.disable-challenge-resource-verification' to 'true'. See " + + "https://aka.ms/azsdk/blog/vault-uri for more information."); + } + + String authorization = challengeAttributes.get("authorization"); + + if (authorization == null) { + authorization = challengeAttributes.get("authorization_uri"); + } + + try { + new URI(authorization); + + LOGGER.log(INFO, "Obtained login URI: {0}", authorization); + LOGGER.exiting("AccessTokenUtil", "getLoginUri", authorization); + + return authorization; + } catch (URISyntaxException e) { + throw new IllegalStateException("The challenge authorization URI " + authorization + " is invalid.", e); + } + } + } + + /** + * Extracts attributes off the bearer challenge in the authentication header. + * + * @param authenticateHeader The authentication header containing the challenge. + * + * @return A challenge attributes map. + */ + private static Map extractChallengeAttributes(String authenticateHeader) { + if (!isBearerChallenge(authenticateHeader)) { + return Collections.emptyMap(); + } + + authenticateHeader = + authenticateHeader.toLowerCase(Locale.ROOT).replace(BEARER_TOKEN_PREFIX.toLowerCase(Locale.ROOT), ""); + + String[] attributes = authenticateHeader.split(", "); + Map attributeMap = new HashMap<>(); + + for (String pair : attributes) { + String[] keyValue = pair.split("="); + + attributeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", "")); + } + + return attributeMap; + } + + /** + * Verifies whether a challenge is bearer or not. + * + * @param authenticateHeader The authentication header containing all the challenges. + * + * @return A boolean indicating if the challenge is a bearer challenge or not. + */ + private static boolean isBearerChallenge(String authenticateHeader) { + return authenticateHeader != null && !authenticateHeader.isEmpty() + && authenticateHeader.toLowerCase(Locale.ROOT).startsWith(BEARER_TOKEN_PREFIX.toLowerCase(Locale.ROOT)); + } + + /** + * Verifies whether a challenge resource is valid or not. + * + * @param resource The URI to validate the challenge against. + * @param scope The scope of the challenge. + * + * @return A boolean indicating if the resource URI is valid or not. + */ + private static boolean isChallengeResourceValid(String resource, String scope) { + final URI resourceUri; + + try { + resourceUri = new URI(resource); + } catch (URISyntaxException e) { + throw new IllegalStateException("The provided resource " + resource + " is not a valid URI.", e); + } + + final URI scopeUri; + + try { + scopeUri = new URI(scope); + } catch (URISyntaxException e) { + throw new IllegalStateException("The challenge scope " + scope + " is not a valid URI.", e); + } + + // Returns false if the host specified in the scope does not match the requested domain. + return resourceUri.getHost().toLowerCase(Locale.ROOT) + .endsWith("." + scopeUri.getHost().toLowerCase(Locale.ROOT)); + } } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/HttpUtil.java b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/HttpUtil.java index 95eb5fb142d0..e84793526cda 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/HttpUtil.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/HttpUtil.java @@ -25,6 +25,8 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; +import java.net.URI; +import java.net.URISyntaxException; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.KeyStoreException; @@ -40,99 +42,184 @@ * The RestClient that uses the Apache HttpClient class. */ public final class HttpUtil { - - static final String USER_AGENT_KEY = "User-Agent"; - static final String DEFAULT_USER_AGENT_VALUE_PREFIX = "az-se-kv-jca/"; public static final String DEFAULT_VERSION = "unknown"; public static final String VERSION = Optional.of(HttpUtil.class) - .map(Class::getPackage) - .map(Package::getImplementationVersion) - .orElse(DEFAULT_VERSION); + .map(Class::getPackage) + .map(Package::getImplementationVersion) + .orElse(DEFAULT_VERSION); + + public static final String HTTPS_PREFIX = "https://"; + public static final String API_VERSION_POSTFIX = "?api-version=7.1"; public static final String USER_AGENT_VALUE = getUserAgentPrefix() + VERSION; + + static final String USER_AGENT_KEY = "User-Agent"; + static final String DEFAULT_USER_AGENT_VALUE_PREFIX = "az-se-kv-jca/"; + private static final Logger LOGGER = Logger.getLogger(HttpUtil.class.getName()); - public static String get(String url, Map headers) { + public static String get(String uri, Map headers) { String result = null; + try (CloseableHttpClient client = buildClient()) { - HttpGet httpGet = new HttpGet(url); + HttpGet httpGet = new HttpGet(uri); + if (headers != null) { headers.forEach(httpGet::addHeader); } + httpGet.addHeader(USER_AGENT_KEY, USER_AGENT_VALUE); + result = client.execute(httpGet, createResponseHandler()); } catch (IOException ioe) { - LOGGER.log(WARNING, "Unable to finish the http get request.", ioe); + LOGGER.log(WARNING, "Unable to finish the HTTP GET request.", ioe); } + return result; } - public static String post(String url, String body, String contentType) { - return post(url, null, body, contentType); + public static String post(String uri, String body, String contentType) { + return post(uri, null, body, contentType); } public static String getUserAgentPrefix() { return Optional.of(HttpUtil.class) - .map(Class::getClassLoader) - .map(c -> c.getResourceAsStream("azure-security-keyvault-jca-user-agent-value-prefix.txt")) - .map(InputStreamReader::new) - .map(BufferedReader::new) - .map(BufferedReader::lines) - .orElseGet(Stream::empty) - .findFirst() - .orElse(DEFAULT_USER_AGENT_VALUE_PREFIX); + .map(Class::getClassLoader) + .map(c -> c.getResourceAsStream("azure-security-keyvault-jca-user-agent-value-prefix.txt")) + .map(InputStreamReader::new) + .map(BufferedReader::new) + .map(BufferedReader::lines) + .orElseGet(Stream::empty) + .findFirst() + .orElse(DEFAULT_USER_AGENT_VALUE_PREFIX); } - public static String post(String url, Map headers, String body, String contentType) { + public static String post(String uri, Map headers, String body, String contentType) { String result = null; + try (CloseableHttpClient client = buildClient()) { - HttpPost httpPost = new HttpPost(url); + HttpPost httpPost = new HttpPost(uri); + httpPost.addHeader(USER_AGENT_KEY, USER_AGENT_VALUE); + if (headers != null) { headers.forEach(httpPost::addHeader); httpPost.addHeader("Content-Type", contentType); } + httpPost.setEntity(new StringEntity(body, ContentType.create(contentType))); + result = client.execute(httpPost, createResponseHandler()); } catch (IOException ioe) { - LOGGER.log(WARNING, "Unable to finish the http post request.", ioe); + LOGGER.log(WARNING, "Unable to finish the HTTP POST request.", ioe); } + return result; } + public static HttpResponse getWithResponse(String uri, Map headers) { + HttpResponse result = null; + + try (CloseableHttpClient client = buildClient()) { + HttpGet httpGet = new HttpGet(uri); + + if (headers != null) { + headers.forEach(httpGet::addHeader); + } + + httpGet.addHeader(USER_AGENT_KEY, USER_AGENT_VALUE); + + result = client.execute(httpGet, createResponseHandlerForAuthChallenge()); + } catch (IOException ioe) { + LOGGER.log(WARNING, "Unable to finish the HTTP GET request.", ioe); + } + + return result; + } private static ResponseHandler createResponseHandler() { return (HttpResponse response) -> { int status = response.getStatusLine().getStatusCode(); String result = null; + if (status >= 200 && status < 300) { HttpEntity entity = response.getEntity(); result = entity != null ? EntityUtils.toString(entity) : null; } + return result; }; } + private static ResponseHandler createResponseHandlerForAuthChallenge() { + return (HttpResponse response) -> { + int status = response.getStatusLine().getStatusCode(); + + return status == 401 ? response : null; + }; + } + private static CloseableHttpClient buildClient() { KeyStore keyStore = JreKeyStoreFactory.getDefaultKeyStore(); SSLContext sslContext = null; + try { - sslContext = SSLContexts - .custom() + sslContext = SSLContexts.custom() .loadTrustMaterial(keyStore, null) .build(); } catch (NoSuchAlgorithmException | KeyManagementException | KeyStoreException e) { - LOGGER.log(WARNING, "Unable to build the ssl context.", e); + LOGGER.log(WARNING, "Unable to build the SSL context.", e); } - SSLConnectionSocketFactory sslConnectionSocketFactory = new SSLConnectionSocketFactory( - sslContext, (HostnameVerifier) null); + SSLConnectionSocketFactory sslConnectionSocketFactory = + new SSLConnectionSocketFactory(sslContext, (HostnameVerifier) null); - PoolingHttpClientConnectionManager manager = new PoolingHttpClientConnectionManager( - RegistryBuilder.create() + PoolingHttpClientConnectionManager manager = + new PoolingHttpClientConnectionManager(RegistryBuilder.create() .register("http", PlainConnectionSocketFactory.getSocketFactory()) .register("https", sslConnectionSocketFactory) .build()); + return HttpClients.custom().setConnectionManager(manager).build(); } + + public static String validateUri(String uri, String propertyName) { + if (uri == null) { + StringBuilder messageBuilder = new StringBuilder(); + + if (propertyName != null) { + messageBuilder.append(propertyName); + } else { + messageBuilder.append("Provided URI "); + } + + messageBuilder.append("cannot be null."); + + throw new NullPointerException(messageBuilder.toString()); + } + + if (!uri.startsWith(HTTPS_PREFIX)) { + throw new IllegalArgumentException("Provided URI '" + uri + "' must start with 'https://'."); + } + + try { + new URI(uri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Provided URI '" + uri + "' is not a valid URI."); + } + + return uri; + } + + public static String addTrailingSlashIfRequired(String uri) { + if (uri == null) { + return null; + } + + if (!uri.endsWith("/")) { + return uri + "/"; + } + + return uri; + } } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/AccessTokenUtilTest.java b/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/AccessTokenUtilTest.java index 3be4b46e3532..169a069d8037 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/AccessTokenUtilTest.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/AccessTokenUtilTest.java @@ -8,6 +8,12 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import java.net.URI; + +import static com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil.getLoginUri; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.API_VERSION_POSTFIX; +import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.addTrailingSlashIfRequired; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertNotNull; /** @@ -15,7 +21,6 @@ */ @EnabledIfEnvironmentVariable(named = "AZURE_KEYVAULT_CERTIFICATE_NAME", matches = "myalias") public class AccessTokenUtilTest { - /** * Test getAuthorizationToken method. */ @@ -24,51 +29,21 @@ public void testGetAuthorizationToken() { String tenantId = PropertyConvertorUtils.getPropertyValue("AZURE_KEYVAULT_TENANT_ID"); String clientId = PropertyConvertorUtils.getPropertyValue("AZURE_KEYVAULT_CLIENT_ID"); String clientSecret = PropertyConvertorUtils.getPropertyValue("AZURE_KEYVAULT_CLIENT_SECRET"); - String keyVaultEndPointSuffix = PropertyConvertorUtils.getPropertyValue("KEY_VAULT_ENDPOINT_SUFFIX", ".vault.azure.net"); - CloudType cloudType = getCloudTypeByKeyVaultEndPoint(keyVaultEndPointSuffix); - String resourceUrl = getResourceUrl(cloudType); - String aadAuthenticationUrl = getAadAuthenticationUrl(cloudType); - AccessToken result = AccessTokenUtil.getAccessToken( - resourceUrl, - aadAuthenticationUrl, - tenantId, - clientId, - clientSecret - ); - assertNotNull(result); - } - - private String getResourceUrl(CloudType cloudType) { - if (CloudType.UsGov.equals(cloudType)) { - return "https://management.usgovcloudapi.net/"; - } else if (CloudType.China.equals(cloudType)) { - return "https://management.chinacloudapi.cn/"; - } - return "https://management.azure.com/"; - } + String keyVaultEndpoint = + addTrailingSlashIfRequired(PropertyConvertorUtils.getPropertyValue("AZURE_KEYVAULT_ENDPOINT")); + String aadAuthenticationUri = getLoginUri(keyVaultEndpoint + "certificates" + API_VERSION_POSTFIX, false); + AccessToken result = + AccessTokenUtil.getAccessToken(keyVaultEndpoint, aadAuthenticationUri, tenantId, clientId, clientSecret); - private String getAadAuthenticationUrl(CloudType cloudType) { - if (CloudType.UsGov.equals(cloudType)) { - return "https://login.microsoftonline.us/"; - } else if (CloudType.China.equals(cloudType)) { - return "https://login.partner.microsoftonline.cn/"; - } - return "https://login.microsoftonline.com/"; + assertNotNull(result); } - private CloudType getCloudTypeByKeyVaultEndPoint(String keyVaultEndPointSuffix) { - if (".vault.usgovcloudapi.net".equals(keyVaultEndPointSuffix)) { - return CloudType.UsGov; - } else if (".vault.azure.cn".equals(keyVaultEndPointSuffix)) { - return CloudType.China; - } - return CloudType.Public; - } + @Test + public void testGetLoginUri() { + String keyVaultEndpoint = PropertyConvertorUtils.getPropertyValue("AZURE_KEYVAULT_ENDPOINT"); + String result = getLoginUri(keyVaultEndpoint + "certificates" + API_VERSION_POSTFIX, false); - private enum CloudType { - Public, - UsGov, - China, - UNKNOWN + assertNotNull(result); + assertDoesNotThrow(() -> new URI(result)); } } diff --git a/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/implementation/KeyVaultClientTest.java b/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/implementation/KeyVaultClientTest.java index eb73b05ee495..cc41e9e16f29 100644 --- a/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/implementation/KeyVaultClientTest.java +++ b/sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/implementation/KeyVaultClientTest.java @@ -9,7 +9,6 @@ import com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil; import com.azure.security.keyvault.jca.implementation.utils.HttpUtil; import com.azure.security.keyvault.jca.implementation.utils.JsonConverterUtil; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.mockito.MockedStatic; @@ -19,81 +18,28 @@ import java.util.Arrays; import java.util.List; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.AAD_LOGIN_URI_CN; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.AAD_LOGIN_URI_DE; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.AAD_LOGIN_URI_GLOBAL; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.AAD_LOGIN_URI_US; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.KEY_VAULT_BASE_URI_CN; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.KEY_VAULT_BASE_URI_DE; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.KEY_VAULT_BASE_URI_GLOBAL; -import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.KEY_VAULT_BASE_URI_US; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; public class KeyVaultClientTest { - private static final String KEY_VAULT_TEST_URI_GLOBAL = "https://fake.vault.azure.net/"; - private static final String KEY_VAULT_TEST_URI_CN = "https://fake.vault.azure.cn/"; - private static final String KEY_VAULT_TEST_URI_US = "https://fake.vault.usgovcloudapi.net/"; - private static final String KEY_VAULT_TEST_URI_DE = "https://fake.vault.microsoftazure.de/"; - private static final String KEY_VAULT_TEST_URI_CUSTOM = "https://fake.vault.contoso.net/"; - private static final String KEY_VAULT_TEST_URI_BASE = "https://vault.contoso.net"; - private static final String LOGIN_TEST_URI = "https://fake.login.com"; - - private KeyVaultClient keyVaultClient; - - /** - * Test initialization of keyVaultBaseUri and aadAuthenticationUrl. - */ - @Test - public void testInitializationOfGlobalURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_BASE_URI_GLOBAL); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), AAD_LOGIN_URI_GLOBAL); - } - - @Test - public void testInitializationOfCNURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_CN, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_BASE_URI_CN); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), AAD_LOGIN_URI_CN); - } - - @Test - public void testInitializationOfUSURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_US, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_BASE_URI_US); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), AAD_LOGIN_URI_US); - } - - @Test - public void testInitializationOfDEURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_DE, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_BASE_URI_DE); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), AAD_LOGIN_URI_DE); - } - - @Test - public void testInitializationOfLoginURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, LOGIN_TEST_URI, null, null, null, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_BASE_URI_GLOBAL); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), LOGIN_TEST_URI + "/"); // We add a trailing slash to the login URI if missing. - } - - @Test - public void testInitializationOfLoginURIWithCustomKeyVaultURI() { - keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_CUSTOM, LOGIN_TEST_URI, null, null, null, null); - Assertions.assertEquals(keyVaultClient.getKeyVaultBaseUri(), KEY_VAULT_TEST_URI_BASE); - Assertions.assertEquals(keyVaultClient.getAadAuthenticationUri(), LOGIN_TEST_URI + "/"); // We add a trailing slash to the login URI if missing. - } @Test public void testGetAliasWithCertificateInfoWith0Page() { try (MockedStatic utilities = Mockito.mockStatic(HttpUtil.class)) { utilities.when(() -> HttpUtil.get(anyString(), anyMap())).thenReturn("fakeValue"); + KeyVaultClient keyVaultClient = mock(KeyVaultClient.class); List result = keyVaultClient.getAliases(); + assertEquals(result.size(), 0); } } @@ -101,15 +47,23 @@ public void testGetAliasWithCertificateInfoWith0Page() { @Test public void testGetAliasWithCertificateInfoWith1Page() { try (MockedStatic utilities = Mockito.mockStatic(HttpUtil.class)) { - // create fake certificates + utilities.when(() -> HttpUtil.validateUri(anyString(), anyString())).thenCallRealMethod(); + utilities.when(() -> HttpUtil.addTrailingSlashIfRequired(anyString())).thenCallRealMethod(); + + // Create fake certificates. CertificateItem fakeCertificateItem1 = new CertificateItem(); fakeCertificateItem1.setId("certificates/fakeCertificateItem1"); + CertificateListResult certificateListResult = new CertificateListResult(); certificateListResult.setValue(Arrays.asList(fakeCertificateItem1)); + String certificateListResultString = JsonConverterUtil.toJson(certificateListResult); + utilities.when(() -> HttpUtil.get(notNull(), anyMap())).thenReturn(certificateListResultString); + KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, null); List result = keyVaultClient.getAliases(); + assertEquals(result.size(), 1); assertTrue(result.contains("fakeCertificateItem1")); } @@ -118,19 +72,25 @@ public void testGetAliasWithCertificateInfoWith1Page() { @Test public void testGetAliasWithCertificateInfoWith2Pages() { try (MockedStatic utilities = Mockito.mockStatic(HttpUtil.class)) { + utilities.when(() -> HttpUtil.validateUri(anyString(), anyString())).thenCallRealMethod(); + utilities.when(() -> HttpUtil.addTrailingSlashIfRequired(anyString())).thenCallRealMethod(); + // create fake certificates CertificateItem fakeCertificateItem1 = new CertificateItem(); fakeCertificateItem1.setId("certificates/fakeCertificateItem1"); + CertificateItem fakeCertificateItem2 = new CertificateItem(); fakeCertificateItem2.setId("certificates/fakeCertificateItem2"); + CertificateItem fakeCertificateItem3 = new CertificateItem(); fakeCertificateItem3.setId("certificates/fakeCertificateItem3"); - // create first page certificate result + // Create first page certificate result. CertificateListResult certificateListResult = new CertificateListResult(); certificateListResult.setNextLink("fakeNextLint"); certificateListResult.setValue(Arrays.asList(fakeCertificateItem1)); - // create next page certificate result + + // Create next page certificate result. CertificateListResult certificateListResultNext = new CertificateListResult(); certificateListResultNext.setValue(Arrays.asList(fakeCertificateItem2, fakeCertificateItem3)); @@ -142,6 +102,7 @@ public void testGetAliasWithCertificateInfoWith2Pages() { KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, null); List result = keyVaultClient.getAliases(); + assertEquals(result.size(), 3); assertTrue(result.containsAll(Arrays.asList("fakeCertificateItem1", "fakeCertificateItem2", "fakeCertificateItem3"))); } @@ -172,45 +133,73 @@ private KeyVaultClient getKeyVaultClient() { String tenantId = System.getProperty("azure.keyvault.tenant-id"); String clientId = System.getProperty("azure.keyvault.client-id"); String clientSecret = System.getProperty("azure.keyvault.client-secret"); - return new KeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret); + boolean disableChallengeResourceVerification = + Boolean.parseBoolean(System.getProperty("azure.keyvault.disable-challenge-resource-verification")); + + return new KeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret, null, + disableChallengeResourceVerification); } @Test public void testCacheToken() { - try (MockedStatic tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); MockedStatic httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) { + try (MockedStatic tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); + MockedStatic httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) { + + httpUtilMockedStatic.when(() -> HttpUtil.validateUri(anyString(), anyString())).thenCallRealMethod(); + httpUtilMockedStatic.when(() -> HttpUtil.addTrailingSlashIfRequired(anyString())).thenCallRealMethod(); + AccessToken cacheToken = new AccessToken(); cacheToken.setExpiresIn(300); // 300 seconds. + tokenUtilMockedStatic.when(() -> AccessTokenUtil.getAccessToken(anyString(), anyString())).thenReturn(cacheToken); + CertificateItem fakeCertificateItem = new CertificateItem(); fakeCertificateItem.setId("certificates/fakeCertificateItem"); + CertificateListResult certificateListResult = new CertificateListResult(); certificateListResult.setValue(Arrays.asList(fakeCertificateItem)); + String certificateListResultString = JsonConverterUtil.toJson(certificateListResult); httpUtilMockedStatic.when(() -> HttpUtil.get(anyString(), anyMap())).thenReturn(certificateListResultString); + KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, ""); keyVaultClient.getAliases(); - keyVaultClient.getAliases(); // get aliases the second time. - tokenUtilMockedStatic.verify(() -> AccessTokenUtil.getAccessToken(anyString(), anyString()), times(1)); + keyVaultClient.getAliases(); // Get aliases the second time. + + tokenUtilMockedStatic.verify(() -> + AccessTokenUtil.getAccessToken(anyString(), anyString()), times(1)); } } @Test public void testCacheTokenExpired() { - try (MockedStatic tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); MockedStatic httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) { + try (MockedStatic tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); + MockedStatic httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) { + + httpUtilMockedStatic.when(() -> HttpUtil.validateUri(anyString(), anyString())).thenCallRealMethod(); + httpUtilMockedStatic.when(() -> HttpUtil.addTrailingSlashIfRequired(anyString())).thenCallRealMethod(); + AccessToken cacheToken = new AccessToken(); cacheToken.setExpiresIn(50); // 50 seconds. + tokenUtilMockedStatic.when(() -> AccessTokenUtil.getAccessToken(anyString(), anyString())).thenReturn(cacheToken); + CertificateItem fakeCertificateItem = new CertificateItem(); fakeCertificateItem.setId("certificates/fakeCertificateItem"); + CertificateListResult certificateListResult = new CertificateListResult(); certificateListResult.setValue(Arrays.asList(fakeCertificateItem)); + String certificateListResultString = JsonConverterUtil.toJson(certificateListResult); httpUtilMockedStatic.when(() -> HttpUtil.get(anyString(), anyMap())).thenReturn(certificateListResultString); + KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, ""); keyVaultClient.getAliases(); - keyVaultClient.getAliases(); // get aliases the second time. - tokenUtilMockedStatic.verify(() -> AccessTokenUtil.getAccessToken(anyString(), anyString()), times(2)); + keyVaultClient.getAliases(); // Get aliases the second time. + + tokenUtilMockedStatic.verify(() -> + AccessTokenUtil.getAccessToken(anyString(), anyString()), times(2)); } } }