Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 45 additions & 52 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -1402,70 +1402,63 @@ private String parseCommonName(String distinguishedName) {
private boolean validateServerName(String nameInCert) {
// Failed to get the common name from DN or empty CN
if (null == nameInCert) {
if (logger.isLoggable(Level.FINER))
if (logger.isLoggable(Level.FINER)) {
logger.finer(logContext + " Failed to parse the name from the certificate or name is empty.");
}
return false;
}

int wildcardIndex = nameInCert.indexOf("*");

// Respect wildcard. If wildcardIndex is larger than -1, then we have a wildcard.
if (wildcardIndex >= 0) {
// We do not allow wildcards to exist past the first period.
if (wildcardIndex > nameInCert.indexOf(".")) {
return false;
}

// We do not allow wildcards in IDNs.
if (nameInCert.startsWith("xn--")) {
return false;
}

/* We do not allow * plus a top-level domain.
* This if statement counts the number of .s in the nameInCert. If it's 1 or less, then reject it.
* This also catches cases where nameInCert is just *
*/
if ((nameInCert.length() - nameInCert.replace(".", "").length()) <= 1) {
return false;
// We do not allow wildcards in IDNs (xn--).
if (!nameInCert.startsWith("xn--") && nameInCert.contains("*")) {
int hostIndex = 0, certIndex = 0, match = 0, startIndex = -1, periodCount = 0;
while (hostIndex < hostName.length()) {
if ('.' == hostName.charAt(hostIndex)) {
periodCount++;
}
if (certIndex < nameInCert.length() && hostName.charAt(hostIndex) == nameInCert.charAt(certIndex)) {
hostIndex++;
certIndex++;
} else if (certIndex < nameInCert.length() && '*' == nameInCert.charAt(certIndex)) {
startIndex = certIndex;
match = hostIndex;
certIndex++;
} else if (startIndex != -1 && 0 == periodCount) {
certIndex = startIndex + 1;
match++;
hostIndex = match;
} else {
logFailMessage(nameInCert);
return false;
}
}

String certBeforeWildcard = nameInCert.substring(0, wildcardIndex);
int firstPeriodAfterWildcard = nameInCert.indexOf(".", wildcardIndex);
String certAfterWildcard;

if (firstPeriodAfterWildcard < 0) {
/* if we get something like peter.database.c*, then make certAfterWildcard empty so that we accept
* anything after *.
* both startsWith("") and endswith("") will always resolve to "true".
*/
certAfterWildcard = "";
if (nameInCert.length() == certIndex && periodCount > 1) {
logSuccessMessage(nameInCert);
return true;
} else {
certAfterWildcard = nameInCert.substring(firstPeriodAfterWildcard);
}

if (hostName.startsWith(certBeforeWildcard) && hostName.endsWith(certAfterWildcard)) {
// now, find the string that the wildcard covers. If it contains any periods, reject it.
int wildcardCoveredStringIndexStart = hostName.indexOf(certBeforeWildcard) + certBeforeWildcard.length();
int wildcardCoveredStringIndexEnd = hostName.lastIndexOf(certAfterWildcard);
if (!hostName.substring(wildcardCoveredStringIndexStart, wildcardCoveredStringIndexEnd).contains(".")) {
return true;
}
logFailMessage(nameInCert);
return false;
}
}

// Verify that the name in certificate matches exactly with the host name
if (!nameInCert.equals(hostName)) {
if (logger.isLoggable(Level.FINER))
logger.finer(logContext + " The name in certificate " + nameInCert
+ " does not match with the server name " + hostName + ".");
logFailMessage(nameInCert);
return false;
}
logSuccessMessage(nameInCert);
return true;
}

if (logger.isLoggable(Level.FINER))
private void logFailMessage(String nameInCert) {
if (logger.isLoggable(Level.FINER)) {
logger.finer(logContext + " The name in certificate " + nameInCert
+ " does not match with the server name " + hostName + ".");
}
}

private void logSuccessMessage(String nameInCert) {
if (logger.isLoggable(Level.FINER)) {
logger.finer(logContext + " The name in certificate:" + nameInCert + " validated against server name "
+ hostName + ".");

return true;
}
}

public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
Expand Down Expand Up @@ -4636,8 +4629,8 @@ void writeTVPRows(TVP value) throws SQLServerException {
SQLServerError databaseError = new SQLServerError();
databaseError.setFromTDS(tdsReader);

SQLServerException.makeFromDatabaseError(con, null, databaseError.getErrorMessage(), databaseError,
false);
SQLServerException.makeFromDatabaseError(con, null, databaseError.getErrorMessage(),
databaseError, false);
}

command.setInterruptsEnabled(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public void testValidateServerName() throws Exception {
// Expected result: true
assertTrue((boolean) method.invoke(hsoObject, "msjdbc.database.windows.net"));

// Server Name = msjdbc.database.windows.net
// SAN = msjdbc***.database.windows.net
// Expected result: true
assertTrue((boolean) method.invoke(hsoObject, "msjdbc***.database.windows.net"));

// Server Name = msjdbc.database.windows.net
// SAN = ms*bc.database.windows.net
// Expected result: true
Expand Down