diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/AuthenticationJNI.java b/src/main/java/com/microsoft/sqlserver/jdbc/AuthenticationJNI.java index 0364a3a78d..bd4a04df98 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/AuthenticationJNI.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/AuthenticationJNI.java @@ -59,7 +59,6 @@ static boolean isDllLoaded() { enabled = true; } catch (UnsatisfiedLinkError e) { temp = e; - authLogger.warning("Failed to load the sqljdbc_auth.dll cause : " + e.getMessage()); // This is not re-thrown on purpose - the constructor will terminate the properly with the appropriate error // string } finally { diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index 056cf721ff..79dd16b279 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -103,6 +103,7 @@ final class TDS { static final int TDS_FEDAUTH_LIBRARY_RESERVED = 0x7F; static final byte ADALWORKFLOW_ACTIVEDIRECTORYPASSWORD = 0x01; static final byte ADALWORKFLOW_ACTIVEDIRECTORYINTEGRATED = 0x02; + static final byte ADALWORKFLOW_ACTIVEDIRECTORYMSI = 0x03; static final byte FEDAUTH_INFO_ID_STSURL = 0x01; // FedAuthInfoData is token endpoint URL from which to acquire fed // auth token static final byte FEDAUTH_INFO_ID_SPN = 0x02; // FedAuthInfoData is the SPN to use for acquiring fed auth token diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java index a0aedb4a56..af603ef7a6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java @@ -805,4 +805,19 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource { * indicates whether Bulk Copy API should be used for Batch Insert operations. */ public void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert); + + /** + * Sets the client id to be used to retrieve access token from MSI EndPoint. + * + * @param msiClientId + * Client ID of User Assigned Managed Identity + */ + public void setMSIClientId(String msiClientId); + + /** + * Returns the value for the connection property 'msiClientId'. + * + * @return msiClientId property value + */ + public String getMSIClientId(); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java index f0418a60c9..a94ca3cc5e 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java @@ -37,25 +37,26 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, user, password, null); AuthenticationResult authenticationResult = future.get(); - SqlFedAuthToken fedAuthToken = new SqlFedAuthToken(authenticationResult.getAccessToken(), - authenticationResult.getExpiresOnDate()); - return fedAuthToken; + return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate()); } catch (MalformedURLException | InterruptedException e) { throw new SQLServerException(e.getMessage(), e); } catch (ExecutionException e) { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_ADALExecution")); Object[] msgArgs = {user, authenticationString}; - // the cause error message uses \\n\\r which does not give correct format - // change it to \r\n to provide correct format + /* + * the cause error message uses \\n\\r which does not give correct format change it to \r\n to provide + * correct format + */ String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n"); AuthenticationException correctedAuthenticationException = new AuthenticationException( correctedErrorMessage); - // SQLServerException is caused by ExecutionException, which is caused by - // AuthenticationException - // to match the exception tree before error message correction + /* + * SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to match + * the exception tree before error message correction + */ ExecutionException correctedExecutionException = new ExecutionException(correctedAuthenticationException); throw new SQLServerException(form.format(msgArgs), null, 0, correctedExecutionException); @@ -69,8 +70,10 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, ExecutorService executorService = Executors.newFixedThreadPool(1); try { - // principal name does not matter, what matters is the realm name - // it gets the username in principal_name@realm_name format + /* + * principal name does not matter, what matters is the realm name it gets the username in + * principal_name@realm_name format + */ KerberosPrincipal kerberosPrincipal = new KerberosPrincipal("username"); String username = kerberosPrincipal.getName(); @@ -83,10 +86,8 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, username, null, null); AuthenticationResult authenticationResult = future.get(); - SqlFedAuthToken fedAuthToken = new SqlFedAuthToken(authenticationResult.getAccessToken(), - authenticationResult.getExpiresOnDate()); - return fedAuthToken; + return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate()); } catch (InterruptedException | IOException e) { throw new SQLServerException(e.getMessage(), e); } catch (ExecutionException e) { @@ -97,15 +98,18 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, // the case when Future's outcome has no AuthenticationResult but exception throw new SQLServerException(form.format(msgArgs), null); } else { - // the cause error message uses \\n\\r which does not give correct format - // change it to \r\n to provide correct format + /* + * the cause error message uses \\n\\r which does not give correct format change it to \r\n to provide + * correct format + */ String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n"); AuthenticationException correctedAuthenticationException = new AuthenticationException( correctedErrorMessage); - // SQLServerException is caused by ExecutionException, which is caused by - // AuthenticationException - // to match the exception tree before error message correction + /* + * SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to + * match the exception tree before error message correction + */ ExecutionException correctedExecutionException = new ExecutionException( correctedAuthenticationException); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 92355d90fe..77beaf6df1 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6,12 +6,18 @@ package com.microsoft.sqlserver.jdbc; import static java.nio.charset.StandardCharsets.UTF_16LE; +import static java.nio.charset.StandardCharsets.UTF_8; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.net.DatagramPacket; import java.net.DatagramSocket; +import java.net.HttpURLConnection; import java.net.InetAddress; import java.net.SocketException; +import java.net.URL; import java.net.UnknownHostException; import java.sql.CallableStatement; import java.sql.Connection; @@ -25,8 +31,13 @@ import java.sql.SQLXML; import java.sql.Savepoint; import java.sql.Statement; +import java.text.DateFormat; import java.text.MessageFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Calendar; +import java.util.Date; import java.util.Enumeration; import java.util.HashMap; import java.util.LinkedList; @@ -37,6 +48,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; @@ -324,10 +336,6 @@ private static int[] locateParams(String sql) { return parameterPositions.stream().mapToInt(Integer::valueOf).toArray(); } - SqlFedAuthToken getAuthenticationResult() { - return fedAuthToken; - } - /** * Encapsulates the data to be sent to the server as part of Federated Authentication Feature Extension. */ @@ -342,13 +350,16 @@ class FederatedAuthenticationFeatureExtensionData { this.libraryType = libraryType; this.fedAuthRequiredPreLoginResponse = fedAuthRequiredPreLoginResponse; - switch (authenticationString.toUpperCase(Locale.ENGLISH).trim()) { + switch (authenticationString.toUpperCase(Locale.ENGLISH)) { case "ACTIVEDIRECTORYPASSWORD": this.authentication = SqlAuthentication.ActiveDirectoryPassword; break; case "ACTIVEDIRECTORYINTEGRATED": this.authentication = SqlAuthentication.ActiveDirectoryIntegrated; break; + case "ACTIVEDIRECTORYMSI": + this.authentication = SqlAuthentication.ActiveDirectoryMSI; + break; default: assert (false); MessageFormat form = new MessageFormat( @@ -378,7 +389,12 @@ public String toString() { class ActiveDirectoryAuthentication { static final String JDBC_FEDAUTH_CLIENT_ID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db"; + static final String AZURE_REST_MSI_URL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01"; static final String ADAL_GET_ACCESS_TOKEN_FUNCTION_NAME = "ADALGetAccessToken"; + static final String ACCESS_TOKEN_IDENTIFIER = "\"access_token\":\""; + static final String ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER = "\"expires_in\":\""; + static final String ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER = "\"expires_on\":\""; + static final String ACCESS_TOKEN_EXPIRES_ON_DATE_FORMAT = "M/d/yyyy h:mm:ss a X"; static final int GET_ACCESS_TOKEN_SUCCESS = 0; static final int GET_ACCESS_TOKEN_INVALID_GRANT = 1; static final int GET_ACCESS_TOKEN_TANSISENT_ERROR = 2; @@ -1051,12 +1067,17 @@ void checkClosed() throws SQLServerException { SQLServerException.makeFromDriverError(null, null, SQLServerException.getErrString("R_connectionIsClosed"), null, false); } + } + protected boolean needsReconnect() throws SQLServerException { + // Check if federated Authentication is in use if (null != fedAuthToken) { - if (Util.checkIfNeedNewAccessToken(this)) { - connect(this.activeConnectionProperties, null); + // Check if access token is about to expire soon + if (Util.checkIfNeedNewAccessToken(this, fedAuthToken.expiresOn)) { + return true; } } + return false; } /** @@ -1527,7 +1548,7 @@ Connection connectInternal(Properties propsIn, if (sPropValue == null) { sPropValue = SQLServerDriverStringProperty.AUTHENTICATION.getDefaultValue(); } - authenticationString = SqlAuthentication.valueOfString(sPropValue).toString(); + authenticationString = SqlAuthentication.valueOfString(sPropValue).toString().trim(); if (integratedSecurity && !authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())) { @@ -1565,6 +1586,19 @@ Connection connectInternal(Properties propsIn, null); } + if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString()) + && ((!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()) + .isEmpty()) + || (!activeConnectionProperties + .getProperty(SQLServerDriverStringProperty.PASSWORD.toString()).isEmpty()))) { + if (connectionlogger.isLoggable(Level.SEVERE)) { + connectionlogger.severe( + toString() + " " + SQLServerException.getErrString("R_MSIAuthenticationWithUserPassword")); + } + throw new SQLServerException(SQLServerException.getErrString("R_MSIAuthenticationWithUserPassword"), + null); + } + if (authenticationString.equalsIgnoreCase(SqlAuthentication.SqlPassword.toString()) && ((activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()) .isEmpty()) @@ -1840,6 +1874,12 @@ else if (0 == requestedPacketSize) activeConnectionProperties.setProperty(sPropKey, SSLProtocol.valueOfString(sPropValue).toString()); } + sPropKey = SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + activeConnectionProperties.setProperty(sPropKey, sPropValue); + } + FailoverInfo fo = null; String databaseNameProperty = SQLServerDriverStringProperty.DATABASE_NAME.toString(); String serverNameProperty = SQLServerDriverStringProperty.SERVER_NAME.toString(); @@ -2400,7 +2440,7 @@ private void connectHelper(ServerPortPlaceHolder serverInfo, int timeOutsliceInM */ void Prelogin(String serverName, int portNumber) throws SQLServerException { // Build a TDS Pre-Login packet to send to the server. - if ((!authenticationString.trim().equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())) + if ((!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())) || (null != accessTokenInByte)) { fedAuthRequiredByUser = true; } @@ -3490,6 +3530,9 @@ int writeFedAuthFeatureRequest(boolean write, TDSWriter tdsWriter, case ActiveDirectoryIntegrated: workflow = TDS.ADALWORKFLOW_ACTIVEDIRECTORYINTEGRATED; break; + case ActiveDirectoryMSI: + workflow = TDS.ADALWORKFLOW_ACTIVEDIRECTORYMSI; + break; default: assert (false); // Unrecognized Authentication type for fedauth ADAL request break; @@ -3562,8 +3605,9 @@ private void logon(LogonCommand command) throws SQLServerException { // for FEDAUTHREQUIRED option indicates Federated Authentication is required, we have to insert FedAuth Feature // Extension // in Login7, indicating the intent to use Active Directory Authentication Library for SQL Server. - if (authenticationString.trim().equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString()) - || (authenticationString.trim().equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString()) + if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString()) + || ((authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString()) + || authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) && fedAuthRequiredPreLoginResponse)) { federatedAuthenticationInfoRequested = true; fedAuthFeatureExtensionData = new FederatedAuthenticationFeatureExtensionData(TDS.TDS_FEDAUTH_LIBRARY_ADAL, @@ -3982,16 +4026,16 @@ final void processFedAuthInfo(TDSReader tdsReader, TDSTokenHandler tdsTokenHandl final class FedAuthTokenCommand extends UninterruptableTDSCommand { TDSTokenHandler tdsTokenHandler = null; - SqlFedAuthToken fedAuthToken = null; + SqlFedAuthToken sqlFedAuthToken = null; - FedAuthTokenCommand(SqlFedAuthToken fedAuthToken, TDSTokenHandler tdsTokenHandler) { + FedAuthTokenCommand(SqlFedAuthToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) { super("FedAuth"); this.tdsTokenHandler = tdsTokenHandler; - this.fedAuthToken = fedAuthToken; + this.sqlFedAuthToken = sqlFedAuthToken; } final boolean doExecute() throws SQLServerException { - sendFedAuthToken(this, fedAuthToken, tdsTokenHandler); + sendFedAuthToken(this, sqlFedAuthToken, tdsTokenHandler); return true; } } @@ -4003,8 +4047,10 @@ final boolean doExecute() throws SQLServerException { void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler) throws SQLServerException { assert (null != activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()) && null != activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())) - || ((authenticationString.trim().equalsIgnoreCase( - SqlAuthentication.ActiveDirectoryIntegrated.toString()) && fedAuthRequiredPreLoginResponse)); + || (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString()) + || authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString()) + && fedAuthRequiredPreLoginResponse); + assert null != fedAuthInfo; attemptRefreshTokenLocked = true; @@ -4031,14 +4077,20 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe int sleepInterval = 100; while (true) { - if (authenticationString.trim().equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString())) { + if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString())) { + validateAdalLibrary("R_ADALMissing"); fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, password, authenticationString); // Break out of the retry loop in successful case. break; - } else if (authenticationString.trim() - .equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) { + } else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) { + fedAuthToken = getMSIAuthToken(fedAuthInfo.spn, + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + + // Break out of the retry loop in successful case. + break; + } else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) { // If operating system is windows and sqljdbc_auth is loaded then choose the DLL authentication. if (System.getProperty("os.name").toLowerCase(Locale.ENGLISH).startsWith("windows") @@ -4051,11 +4103,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe // AccessToken should not be null. assert null != dllInfo.accessTokenBytes; - byte[] accessTokenFromDLL = dllInfo.accessTokenBytes; String accessToken = new String(accessTokenFromDLL, UTF_16LE); - fedAuthToken = new SqlFedAuthToken(accessToken, dllInfo.expiresIn); // Break out of the retry loop in successful case. @@ -4115,6 +4165,8 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe // so we don't need to check the // OS version here. else { + // Check if ADAL4J library is available + validateAdalLibrary("R_DLLandADALMissing"); fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); } // Break out of the retry loop in successful case. @@ -4125,6 +4177,166 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe return fedAuthToken; } + private void validateAdalLibrary(String errorMessage) throws SQLServerException { + try { + Class.forName("com.microsoft.aad.adal4j.AuthenticationContext"); + } catch (ClassNotFoundException e) { + // throw Exception for missing libraries + MessageFormat form = new MessageFormat(SQLServerException.getErrString(errorMessage)); + throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null); + } + } + + private SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) throws SQLServerException { + // IMDS upgrade time can take up to 70s + final int imdsUpgradeTimeInMs = 70 * 1000; + final List retrySlots = new ArrayList<>(); + final String msiEndpoint = System.getenv("MSI_ENDPOINT"); + final String msiSecret = System.getenv("MSI_SECRET"); + + StringBuilder urlString = new StringBuilder(); + int retry = 1, maxRetry = 1; + + /* + * isAzureFunction is used for identifying if the current client application is running in a Virtual Machine + * (without MSI environment variables) or App Service/Function (with MSI environment variables) as the APIs to + * be called for acquiring MSI Token are different for both cases. + */ + boolean isAzureFunction = null != msiEndpoint && !msiEndpoint.isEmpty() && null != msiSecret + && !msiSecret.isEmpty(); + + if (isAzureFunction) { + urlString.append(msiEndpoint).append("?api-version=2017-09-01&resource=").append(resource); + } else { + urlString.append(ActiveDirectoryAuthentication.AZURE_REST_MSI_URL).append("&resource=").append(resource); + // Retry acquiring access token upto 20 times due to possible IMDS upgrade (Applies to VM only) + maxRetry = 20; + // Simplified variant of Exponential BackOff + for (int x = 0; x < maxRetry; x++) { + retrySlots.add(500 * ((2 << 1) - 1) / 1000); + } + } + + // Append Client Id if available + if (null != msiClientId && !msiClientId.isEmpty()) { + if (isAzureFunction) { + urlString.append("&clientid=").append(msiClientId); + } else { + urlString.append("&client_id=").append(msiClientId); + } + } + + // Loop while maxRetry reaches its limit + while (retry <= maxRetry) { + HttpURLConnection connection = null; + + try { + connection = (HttpURLConnection) new URL(urlString.toString()).openConnection(); + connection.setRequestMethod("GET"); + + if (isAzureFunction) { + connection.setRequestProperty("Secret", msiSecret); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer(toString() + " Using Azure Function/App Service MSI auth: " + urlString); + } + } else { + connection.setRequestProperty("Metadata", "true"); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer(toString() + " Using Azure MSI auth: " + urlString); + } + } + + connection.connect(); + + try (InputStream stream = connection.getInputStream()) { + + BufferedReader reader = new BufferedReader(new InputStreamReader(stream, UTF_8), 100); + String result = reader.readLine(); + + int startIndex_AT = result.indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER.length(); + + String accessToken = result.substring(startIndex_AT, result.indexOf("\"", startIndex_AT + 1)); + + Calendar cal = new Calendar.Builder().setInstant(new Date()).build(); + + if (isAzureFunction) { + // Fetch expires_on + int startIndex_ATX = result + .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER.length(); + String accessTokenExpiry = result.substring(startIndex_ATX, + result.indexOf("\"", startIndex_ATX + 1)); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer(toString() + " MSI auth token expires on: " + accessTokenExpiry); + } + + DateFormat df = new SimpleDateFormat( + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_DATE_FORMAT); + cal = new Calendar.Builder().setInstant(df.parse(accessTokenExpiry)).build(); + } else { + // Fetch expires_in + int startIndex_ATX = result + .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER.length(); + String accessTokenExpiry = result.substring(startIndex_ATX, + result.indexOf("\"", startIndex_ATX + 1)); + cal.add(Calendar.SECOND, Integer.parseInt(accessTokenExpiry)); + } + + return new SqlFedAuthToken(accessToken, cal.getTime()); + } + } catch (Exception e) { + retry++; + // Below code applicable only when !isAzureFunctcion (VM) + if (retry > maxRetry) { + // Do not retry if maxRetry limit has been reached. + break; + } else { + try { + int responseCode = connection.getResponseCode(); + // Check Error Response Code from Connection + if (410 == responseCode || 429 == responseCode || 404 == responseCode + || (500 <= responseCode && 599 >= responseCode)) { + try { + int retryTimeoutInMs = retrySlots.get(ThreadLocalRandom.current().nextInt(retry - 1)); + // Error code 410 indicates IMDS upgrade is in progress, which can take up to 70s + retryTimeoutInMs = (responseCode == 410 + && retryTimeoutInMs < imdsUpgradeTimeInMs) ? imdsUpgradeTimeInMs + : retryTimeoutInMs; + Thread.sleep(retryTimeoutInMs); + } catch (InterruptedException ex) { + // Throw runtime exception as driver must not be interrupted here + throw new RuntimeException(ex); + } + } else { + if (null != msiClientId && !msiClientId.isEmpty()) { + SQLServerException.makeFromDriverError(this, null, + SQLServerException.getErrString("R_MSITokenFailureClientId"), null, true); + } else { + SQLServerException.makeFromDriverError(this, null, + SQLServerException.getErrString("R_MSITokenFailureImds"), null, true); + } + } + } catch (IOException io) { + // Throw error as unexpected if response code not available + SQLServerException.makeFromDriverError(this, null, + SQLServerException.getErrString("R_MSITokenFailureUnexpected"), null, true); + } + } + } finally { + if (connection != null) { + connection.disconnect(); + } + } + } + if (retry > maxRetry) { + SQLServerException.makeFromDriverError(this, null, SQLServerException + .getErrString(isAzureFunction ? "R_MSITokenFailureEndpoint" : "R_MSITokenFailureImds"), null, true); + } + return null; + } + /** * Send the access token to the server. */ diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index e20b0a8e87..b234e51156 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -866,6 +866,17 @@ public String getJASSConfigurationName() { SQLServerDriverStringProperty.JAAS_CONFIG_NAME.getDefaultValue()); } + @Override + public void setMSIClientId(String msiClientId) { + setStringProperty(connectionProps, SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(), msiClientId); + } + + @Override + public String getMSIClientId() { + return getStringProperty(connectionProps, SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(), + SQLServerDriverStringProperty.MSI_CLIENT_ID.getDefaultValue()); + } + /** * Sets a property string value. * diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index 30a0c4d761..222a897dc5 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -65,7 +65,8 @@ enum SqlAuthentication { NotSpecified, SqlPassword, ActiveDirectoryPassword, - ActiveDirectoryIntegrated; + ActiveDirectoryIntegrated, + ActiveDirectoryMSI; static SqlAuthentication valueOfString(String value) throws SQLServerException { SqlAuthentication method = null; @@ -80,6 +81,8 @@ static SqlAuthentication valueOfString(String value) throws SQLServerException { } else if (value.toLowerCase(Locale.US) .equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) { method = SqlAuthentication.ActiveDirectoryIntegrated; + } else if (value.toLowerCase(Locale.US).equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) { + method = SqlAuthentication.ActiveDirectoryMSI; } else { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidConnectionSetting")); Object[] msgArgs = {"authentication", value}; @@ -280,7 +283,8 @@ enum SQLServerDriverStringProperty { KEY_STORE_AUTHENTICATION("keyStoreAuthentication", ""), KEY_STORE_SECRET("keyStoreSecret", ""), KEY_STORE_LOCATION("keyStoreLocation", ""), - SSL_PROTOCOL("sslProtocol", SSLProtocol.TLS.toString()),; + SSL_PROTOCOL("sslProtocol", SSLProtocol.TLS.toString()), + MSI_CLIENT_ID("msiClientId", ""),; private final String name; private final String defaultValue; @@ -500,6 +504,8 @@ public final class SQLServerDriver implements java.sql.Driver { SQLServerDriverStringProperty.SSL_PROTOCOL.getDefaultValue(), false, new String[] {SSLProtocol.TLS.toString(), SSLProtocol.TLS_V10.toString(), SSLProtocol.TLS_V11.toString(), SSLProtocol.TLS_V12.toString()}), + new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(), + SQLServerDriverStringProperty.MSI_CLIENT_ID.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.toString(), Integer.toString(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.getDefaultValue()), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(), diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPooledConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPooledConnection.java index 91b8a51442..31ca83aa8e 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPooledConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPooledConnection.java @@ -29,11 +29,11 @@ public class SQLServerPooledConnection implements PooledConnection { private SQLServerConnectionPoolProxy lastProxyConnection; private String factoryUser, factoryPassword; private java.util.logging.Logger pcLogger; - static private final AtomicInteger basePooledConnectionID = new AtomicInteger(0); // Unique id generator for each - // PooledConnection instance - // (used for logging). private final String traceID; + // Unique id generator for each PooledConnection instance (used for logging). + static private final AtomicInteger basePooledConnectionID = new AtomicInteger(0); + SQLServerPooledConnection(SQLServerDataSource ds, String user, String password) throws SQLException { listeners = new Vector<>(); // Piggyback SQLServerDataSource logger for now. @@ -65,7 +65,12 @@ public String toString() { return traceID; } - // Helper function to create a new connection for the pool. + /** + * Helper function to create a new connection for the pool. + * + * @return SQLServerConnection instance + * @throws SQLException + */ private SQLServerConnection createNewConnection() throws SQLException { return factoryDataSource.getConnectionInternal(factoryUser, factoryPassword, this); } @@ -88,28 +93,38 @@ public Connection getConnection() throws SQLException { SQLServerException.getErrString("R_physicalConnectionIsClosed"), "", true); } - // Check with security manager to insure caller has rights to connect. - // This will throw a SecurityException if the caller does not have proper rights. + /* + * Check with security manager to insure caller has rights to connect. This will throw a SecurityException + * if the caller does not have proper rights. + */ physicalConnection.doSecurityCheck(); if (pcLogger.isLoggable(Level.FINE)) pcLogger.fine(toString() + " Physical connection, " + safeCID()); - if (null != physicalConnection.getAuthenticationResult()) { - if (Util.checkIfNeedNewAccessToken(physicalConnection)) { - physicalConnection = createNewConnection(); - } + if (physicalConnection.needsReconnect()) { + physicalConnection.close(); + physicalConnection = createNewConnection(); } - // The last proxy connection handle returned will be invalidated (moved to closed state) - // when getConnection is called. + /* + * The last proxy connection handle returned will be invalidated (moved to closed state) when getConnection + * is called. + */ if (null != lastProxyConnection) { // if there was a last proxy connection send reset physicalConnection.resetPooledConnection(); - if (pcLogger.isLoggable(Level.FINE) && !lastProxyConnection.isClosed()) - pcLogger.fine(toString() + "proxy " + lastProxyConnection.toString() - + " is not closed before getting the connection."); - // use internal close so there wont be an event due to us closing the connection, if not closed already. - lastProxyConnection.internalClose(); + + if (!lastProxyConnection.isClosed()) { + if (pcLogger.isLoggable(Level.FINE)) { + pcLogger.fine(toString() + "proxy " + lastProxyConnection.toString() + + " is not closed before getting the connection."); + } + /* + * use internal close so there wont be an event due to us closing the connection, if not closed + * already. + */ + lastProxyConnection.internalClose(); + } } lastProxyConnection = new SQLServerConnectionPoolProxy(physicalConnection); @@ -222,8 +237,10 @@ private static int nextPooledConnectionID() { return basePooledConnectionID.incrementAndGet(); } - // Helper function to return connectionID of the physicalConnection in a safe manner for logging. - // Returns (null) if physicalConnection is null, otherwise returns connectionID. + /** + * Helper function to return connectionID of the physicalConnection in a safe manner for logging. Returns (null) if + * physicalConnection is null, otherwise returns connectionID. + **/ private String safeCID() { if (null == physicalConnection) return " ConnectionID:(null)"; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index 7f30ca791e..d3a6e7303c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -245,6 +245,8 @@ protected Object[][] getContents() { {"R_statementPoolingCacheSizePropertyDescription", "This setting specifies the size of the prepared statement cache for a connection. A value less than 1 means no cache."}, {"R_gsscredentialPropertyDescription", "Impersonated GSS Credential to access SQL Server."}, + {"R_msiClientIdPropertyDescription", + "Client Id of User Assigned Managed Identity to be used for generating access token for Azure AD MSI Authentication"}, {"R_noParserSupport", "An error occurred while instantiating the required parser. Error: \"{0}\""}, {"R_writeOnlyXML", "Cannot read from this SQLXML instance. This instance is for writing data only."}, {"R_dataHasBeenReadXML", "Cannot read from this SQLXML instance. The data has already been read."}, @@ -386,6 +388,8 @@ protected Object[][] getContents() { "Cannot set the AccessToken property if the \"IntegratedSecurity\" connection string keyword has been set to \"true\"."}, {"R_IntegratedAuthenticationWithUserPassword", "Cannot use \"Authentication=ActiveDirectoryIntegrated\" with \"User\", \"UserName\" or \"Password\" connection string keywords."}, + {"R_MSIAuthenticationWithUserPassword", + "Cannot use \"Authentication=ActiveDirectoryMSI\" with \"User\", \"UserName\" or \"Password\" connection string keywords."}, {"R_AccessTokenWithUserPassword", "Cannot set the AccessToken property if \"User\", \"UserName\" or \"Password\" has been specified in the connection string."}, {"R_AccessTokenCannotBeEmpty", "AccesToken cannot be empty."}, @@ -534,5 +538,14 @@ protected Object[][] getContents() { {"R_unknownUTF8SupportValue", "Unknown value for UTF8 support."}, {"R_illegalWKT", "Illegal Well-Known text. Please make sure Well-Known text is valid."}, {"R_illegalTypeForGeometry", "{0} is not supported for Geometry."}, - {"R_illegalWKTposition", "Illegal character in Well-Known text at position {0}."},}; + {"R_illegalWKTposition", "Illegal character in Well-Known text at position {0}."}, + {"R_ADALMissing", "Failed to load ADAL4J Java library for performing {0} authentication."}, + {"R_DLLandADALMissing", + "Failed to load both sqljdbc_auth.dll and ADAL4J Java library for performing {0} authentication. Please install one of them to proceed."}, + {"R_MSITokenFailureImds", "MSI Token failure: Failed to acquire access token from IMDS"}, + {"R_MSITokenFailureImdsClientId", + "MSI Token failure: Failed to acquire access token from IMDS, verify your clientId."}, + {"R_MSITokenFailureUnexpected", + "MSI Token failure: Failed to acquire access token from IMDS, unexpected error occurred."}, + {"R_MSITokenFailureEndpoint", "MSI Token failure: Failed to acquire token from MSI Endpoint"}}; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java b/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java index 01adb70a10..21aedbb8ec 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java @@ -1,31 +1,27 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc; - -import java.util.Date; - - -class SqlFedAuthToken { - final Date expiresOn; - final String accessToken; - - SqlFedAuthToken(final String accessToken, final long expiresIn) { - this.accessToken = accessToken; - - Date now = new Date(); - now.setTime(now.getTime() + (expiresIn * 1000)); - this.expiresOn = now; - } - - SqlFedAuthToken(final String accessToken, final Date expiresOn) { - this.accessToken = accessToken; - this.expiresOn = expiresOn; - } - - Date getExpiresOnDate() { - return expiresOn; - } -} +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import java.util.Date; + + +class SqlFedAuthToken { + final Date expiresOn; + final String accessToken; + + SqlFedAuthToken(String accessToken, long expiresIn) { + this.accessToken = accessToken; + + Date now = new Date(); + now.setTime(now.getTime() + (expiresIn * 1000)); + this.expiresOn = now; + } + + SqlFedAuthToken(String accessToken, Date expiresOn) { + this.accessToken = accessToken; + this.expiresOn = expiresOn; + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/Util.java b/src/main/java/com/microsoft/sqlserver/jdbc/Util.java index 76617d31a0..0f08dbce90 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/Util.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/Util.java @@ -936,8 +936,7 @@ else if (("" + value).contains("E")) { // If the token is expiring within the next 45 mins, try to fetch a new token if there is no thread already doing // it. // If a thread is already doing the refresh, just use the existing token and proceed. - static synchronized boolean checkIfNeedNewAccessToken(SQLServerConnection connection) { - Date accessTokenExpireDate = connection.getAuthenticationResult().getExpiresOnDate(); + static synchronized boolean checkIfNeedNewAccessToken(SQLServerConnection connection, Date accessTokenExpireDate) { Date now = new Date(); // if the token's expiration is within the next 45 mins @@ -957,7 +956,6 @@ static synchronized boolean checkIfNeedNewAccessToken(SQLServerConnection connec } } } - return false; }