diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index 5d65f2a109..8229d4cd1b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -1402,70 +1402,63 @@ private String parseCommonName(String distinguishedName) { private boolean validateServerName(String nameInCert) { // Failed to get the common name from DN or empty CN if (null == nameInCert) { - if (logger.isLoggable(Level.FINER)) + if (logger.isLoggable(Level.FINER)) { logger.finer(logContext + " Failed to parse the name from the certificate or name is empty."); + } return false; } - - int wildcardIndex = nameInCert.indexOf("*"); - - // Respect wildcard. If wildcardIndex is larger than -1, then we have a wildcard. - if (wildcardIndex >= 0) { - // We do not allow wildcards to exist past the first period. - if (wildcardIndex > nameInCert.indexOf(".")) { - return false; - } - - // We do not allow wildcards in IDNs. - if (nameInCert.startsWith("xn--")) { - return false; - } - - /* We do not allow * plus a top-level domain. - * This if statement counts the number of .s in the nameInCert. If it's 1 or less, then reject it. - * This also catches cases where nameInCert is just * - */ - if ((nameInCert.length() - nameInCert.replace(".", "").length()) <= 1) { - return false; + // We do not allow wildcards in IDNs (xn--). + if (!nameInCert.startsWith("xn--") && nameInCert.contains("*")) { + int hostIndex = 0, certIndex = 0, match = 0, startIndex = -1, periodCount = 0; + while (hostIndex < hostName.length()) { + if ('.' == hostName.charAt(hostIndex)) { + periodCount++; + } + if (certIndex < nameInCert.length() && hostName.charAt(hostIndex) == nameInCert.charAt(certIndex)) { + hostIndex++; + certIndex++; + } else if (certIndex < nameInCert.length() && '*' == nameInCert.charAt(certIndex)) { + startIndex = certIndex; + match = hostIndex; + certIndex++; + } else if (startIndex != -1 && 0 == periodCount) { + certIndex = startIndex + 1; + match++; + hostIndex = match; + } else { + logFailMessage(nameInCert); + return false; + } } - - String certBeforeWildcard = nameInCert.substring(0, wildcardIndex); - int firstPeriodAfterWildcard = nameInCert.indexOf(".", wildcardIndex); - String certAfterWildcard; - - if (firstPeriodAfterWildcard < 0) { - /* if we get something like peter.database.c*, then make certAfterWildcard empty so that we accept - * anything after *. - * both startsWith("") and endswith("") will always resolve to "true". - */ - certAfterWildcard = ""; + if (nameInCert.length() == certIndex && periodCount > 1) { + logSuccessMessage(nameInCert); + return true; } else { - certAfterWildcard = nameInCert.substring(firstPeriodAfterWildcard); - } - - if (hostName.startsWith(certBeforeWildcard) && hostName.endsWith(certAfterWildcard)) { - // now, find the string that the wildcard covers. If it contains any periods, reject it. - int wildcardCoveredStringIndexStart = hostName.indexOf(certBeforeWildcard) + certBeforeWildcard.length(); - int wildcardCoveredStringIndexEnd = hostName.lastIndexOf(certAfterWildcard); - if (!hostName.substring(wildcardCoveredStringIndexStart, wildcardCoveredStringIndexEnd).contains(".")) { - return true; - } + logFailMessage(nameInCert); + return false; } } - // Verify that the name in certificate matches exactly with the host name if (!nameInCert.equals(hostName)) { - if (logger.isLoggable(Level.FINER)) - logger.finer(logContext + " The name in certificate " + nameInCert - + " does not match with the server name " + hostName + "."); + logFailMessage(nameInCert); return false; } + logSuccessMessage(nameInCert); + return true; + } - if (logger.isLoggable(Level.FINER)) + private void logFailMessage(String nameInCert) { + if (logger.isLoggable(Level.FINER)) { + logger.finer(logContext + " The name in certificate " + nameInCert + + " does not match with the server name " + hostName + "."); + } + } + + private void logSuccessMessage(String nameInCert) { + if (logger.isLoggable(Level.FINER)) { logger.finer(logContext + " The name in certificate:" + nameInCert + " validated against server name " + hostName + "."); - - return true; + } } public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { @@ -4636,8 +4629,8 @@ void writeTVPRows(TVP value) throws SQLServerException { SQLServerError databaseError = new SQLServerError(); databaseError.setFromTDS(tdsReader); - SQLServerException.makeFromDatabaseError(con, null, databaseError.getErrorMessage(), databaseError, - false); + SQLServerException.makeFromDatabaseError(con, null, databaseError.getErrorMessage(), + databaseError, false); } command.setInterruptsEnabled(true); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SSLCertificateValidation.java b/src/test/java/com/microsoft/sqlserver/jdbc/SSLCertificateValidation.java index 2008bb5d3f..9dd1e96441 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SSLCertificateValidation.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SSLCertificateValidation.java @@ -47,6 +47,11 @@ public void testValidateServerName() throws Exception { // Expected result: true assertTrue((boolean) method.invoke(hsoObject, "msjdbc.database.windows.net")); + // Server Name = msjdbc.database.windows.net + // SAN = msjdbc***.database.windows.net + // Expected result: true + assertTrue((boolean) method.invoke(hsoObject, "msjdbc***.database.windows.net")); + // Server Name = msjdbc.database.windows.net // SAN = ms*bc.database.windows.net // Expected result: true