diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java index 6706d84a61..c0973007bb 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java @@ -8,15 +8,22 @@ package com.microsoft.sqlserver.jdbc; +import java.lang.reflect.Method; import java.net.IDN; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.security.AccessControlContext; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import java.util.logging.Level; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import javax.naming.NamingException; import javax.security.auth.Subject; import javax.security.auth.login.AppConfigurationEntry; import javax.security.auth.login.Configuration; @@ -30,6 +37,8 @@ import org.ietf.jgss.GSSName; import org.ietf.jgss.Oid; +import com.microsoft.sqlserver.jdbc.dns.DNSKerberosLocator; + /** * KerbAuthentication for int auth. */ @@ -247,6 +256,7 @@ private String makeSpn(String server, // Get user provided SPN string; if not provided then build the generic one String userSuppliedServerSpn = con.activeConnectionProperties.getProperty(SQLServerDriverStringProperty.SERVER_SPN.toString()); + String spn; if (null != userSuppliedServerSpn) { // serverNameAsACE is true, translate the user supplied serverSPN to ASCII if (con.serverNameAsACE()) { @@ -260,6 +270,152 @@ private String makeSpn(String server, else { spn = makeSpn(address, port); } + this.spn = enrichSpnWithRealm(spn, null == userSuppliedServerSpn); + if (!this.spn.equals(spn) && authLogger.isLoggable(Level.FINER)){ + authLogger.finer(toString() + "SPN enriched: " + spn + " := " + this.spn); + } + } + + private static final Pattern SPN_PATTERN = Pattern.compile("MSSQLSvc/(.*):([^:@]+)(@.+)?", Pattern.CASE_INSENSITIVE); + + private String enrichSpnWithRealm(String spn, + boolean allowHostnameCanonicalization) { + if (spn == null) { + return spn; + } + Matcher m = SPN_PATTERN.matcher(spn); + if (!m.matches()) { + return spn; + } + if (m.group(3) != null) { + // Realm is already present, no need to enrich, the job has already been done + return spn; + } + String dnsName = m.group(1); + String portOrInstance = m.group(2); + RealmValidator realmValidator = getRealmValidator(dnsName); + String realm = findRealmFromHostname(realmValidator, dnsName); + if (realm == null && allowHostnameCanonicalization) { + // We failed, try with canonical host name to find a better match + try { + String canonicalHostName = InetAddress.getByName(dnsName).getCanonicalHostName(); + realm = findRealmFromHostname(realmValidator, canonicalHostName); + // Since we have a match, our hostname is the correct one (for instance of server + // name was an IP), so we override dnsName as well + dnsName = canonicalHostName; + } + catch (UnknownHostException cannotCanonicalize) { + // ignored, but we are in a bad shape + } + } + if (realm == null) { + return spn; + } + else { + StringBuilder sb = new StringBuilder("MSSQLSvc/"); + sb.append(dnsName).append(":").append(portOrInstance).append("@").append(realm.toUpperCase(Locale.ENGLISH)); + return sb.toString(); + } + } + + private static RealmValidator validator; + + /** + * Find a suitable way of validating a REALM for given JVM. + * + * @param hostnameToTest + * an example hostname we are gonna use to test our realm validator. + * @return a not null realm Validator. + */ + static RealmValidator getRealmValidator(String hostnameToTest) { + if (validator != null) { + return validator; + } + // JVM Specific, here Sun/Oracle JVM + try { + Class clz = Class.forName("sun.security.krb5.Config"); + Method getInstance = clz.getMethod("getInstance", new Class[0]); + final Method getKDCList = clz.getMethod("getKDCList", new Class[] {String.class}); + final Object instance = getInstance.invoke(null); + RealmValidator oracleRealmValidator = new RealmValidator() { + + @Override + public boolean isRealmValid(String realm) { + try { + Object ret = getKDCList.invoke(instance, realm); + return ret != null; + } + catch (Exception err) { + return false; + } + } + }; + validator = oracleRealmValidator; + // As explained here: https://github.com/Microsoft/mssql-jdbc/pull/40#issuecomment-281509304 + // The default Oracle Resolution mechanism is not bulletproof + // If it resolves a crappy name, drop it. + if (!validator.isRealmValid("this.might.not.exist." + hostnameToTest)) { + // Our realm validator is well working, return it + authLogger.fine("Kerberos Realm Validator: Using Built-in Oracle Realm Validation method."); + return oracleRealmValidator; + } + authLogger.fine("Kerberos Realm Validator: Detected buggy Oracle Realm Validator, using DNSKerberosLocator."); + } + catch (ReflectiveOperationException notTheRightJVMException) { + // Ignored, we simply are not using the right JVM + authLogger.fine("Kerberos Realm Validator: No Oracle Realm Validator Available, using DNSKerberosLocator."); + } + // No implementation found, default one, not any realm is valid + validator = new RealmValidator() { + @Override + public boolean isRealmValid(String realm) { + try { + return DNSKerberosLocator.isRealmValid(realm); + } + catch (NamingException err) { + return false; + } + } + }; + return validator; + } + + /** + * Try to find a REALM in the different parts of a host name. + * + * @param realmValidator + * a function that return true if REALM is valid and exists + * @param hostname + * the name we are looking a REALM for + * @return the realm if found, null otherwise + */ + private String findRealmFromHostname(RealmValidator realmValidator, + String hostname) { + if (hostname == null) { + return null; + } + int index = 0; + while (index != -1 && index < hostname.length() - 2) { + String realm = hostname.substring(index); + if (authLogger.isLoggable(Level.FINEST)) { + authLogger.finest(toString() + " looking up REALM candidate " + realm); + } + if (realmValidator.isRealmValid(realm)) { + return realm.toUpperCase(); + } + index = hostname.indexOf(".", index + 1); + if (index != -1) { + index = index + 1; + } + } + return null; + } + + /** + * JVM Specific implementation to decide whether a realm is valid or not + */ + interface RealmValidator { + boolean isRealmValid(String realm); } byte[] GenerateClientContext(byte[] pin, diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java new file mode 100644 index 0000000000..77bb67b0f1 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java @@ -0,0 +1,44 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.Set; + +import javax.naming.NameNotFoundException; +import javax.naming.NamingException; + +public final class DNSKerberosLocator { + + private DNSKerberosLocator() { + } + + /** + * Tells whether a realm is valid. + * + * @param realmName + * the realm to test + * @return true if realm is valid, false otherwise + * @throws NamingException + * if DNS failed, so realm existence cannot be determined + */ + public static boolean isRealmValid(String realmName) throws NamingException { + if (realmName == null || realmName.length() < 2) { + return false; + } + if (realmName.startsWith(".")) { + realmName = realmName.substring(1); + } + try { + Set records = DNSUtilities.findSrvRecords("_kerberos._udp." + realmName); + return !records.isEmpty(); + } + catch (NameNotFoundException wrongDomainException) { + return false; + } + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java new file mode 100644 index 0000000000..ba419bb755 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java @@ -0,0 +1,170 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Describe an DNS SRV Record. + */ +public class DNSRecordSRV implements Comparable { + + private static final Pattern PATTERN = Pattern.compile("^([0-9]+) ([0-9]+) ([0-9]+) (.+)$"); + + private final int priority; + + /** + * Parse a DNS SRC Record from a DNS String record. + * + * @param record + * the record to parse + * @return a not null DNS Record + * @throws IllegalArgumentException + * if record is not correct and cannot be parsed + */ + public static DNSRecordSRV parseFromDNSRecord(String record) throws IllegalArgumentException { + Matcher m = PATTERN.matcher(record); + if (!m.matches()) { + throw new IllegalArgumentException("record '" + record + "' cannot be matched as a valid DNS SRV Record"); + } + try { + int priority = Integer.parseInt(m.group(1)); + int weight = Integer.parseInt(m.group(2)); + int port = Integer.parseInt(m.group(3)); + String serverName = m.group(4); + // Avoid issues with Kerberos SPN when fully qualified records ends with '.' + if (serverName.endsWith(".")) { + serverName = serverName.substring(0, serverName.length() - 1); + } + return new DNSRecordSRV(priority, weight, port, serverName); + } + catch (IllegalArgumentException err) { + throw err; + } + catch (Exception err) { + throw new IllegalArgumentException("Failed to parse DNS SRV record '" + record + "'", err); + } + } + + @Override + public String toString() { + return String.format("DNS.SRV[pri=%d w=%d port=%d h='%s']", priority, weight, port, serverName); + } + + /** + * Constructor. + * + * @param priority + * is lowest + * @param weight + * 1 at minimum + * @param port + * the port of service + * @param serverName + * the host + * @throws IllegalArgumentException + * if priority < 0 or weight <= 1 + */ + public DNSRecordSRV(int priority, + int weight, + int port, + String serverName) throws IllegalArgumentException { + if (priority < 0) { + throw new IllegalArgumentException("priority must be >= 0, but was: " + priority); + } + this.priority = priority; + if (weight < 0) { + // Weight == 0 is OK to disable load balancing, but not below + throw new IllegalArgumentException("weight must be >= 0, but was: " + weight); + } + this.weight = weight; + if (port < 0 || port > 65535) { + throw new IllegalArgumentException("port must be between 0 and 65535, but was: " + port); + } + this.port = port; + if (serverName == null || serverName.trim().isEmpty()) { + throw new IllegalArgumentException("hostname is not supposed to be null or empty in a SRV Record"); + } + this.serverName = serverName; + } + + private final int weight; + private final int port; + private final String serverName; + + @Override + public int hashCode() { + return serverName.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof DNSRecordSRV)) { + return false; + } + + DNSRecordSRV r = (DNSRecordSRV) other; + return port == r.port && weight == r.weight && priority == r.priority && serverName.equals(r.serverName); + } + + @Override + public int compareTo(DNSRecordSRV o) { + if (o == null) { + return 1; + } + int p = Integer.compare(priority, o.priority); + if (p != 0) { + return p; + } + p = Integer.compare(weight, o.weight); + if (p != 0) { + return p; + } + p = Integer.compare(port, o.port); + if (p != 0) { + return p; + } + return serverName.compareTo(o.serverName); + } + + /** + * Get the priority of DNS SRV record. + * @return a positive priority, where lowest values have to be considered first. + */ + public int getPriority() { + return priority; + } + + /** + * Get the weight of DNS record from 0 to 65535. + * @return The weight, higher value means higher probability of selecting the given record for a given priority. + */ + public int getWeight() { + return weight; + } + + /** + * IP port of record. + * @return a value from 1 to 65535. + */ + public int getPort() { + return port; + } + + /** + * The DNS server name. + * @return a not null server name. + */ + public String getServerName() { + return serverName; + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java new file mode 100644 index 0000000000..f7eb6de0cd --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java @@ -0,0 +1,68 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.Hashtable; +import java.util.Set; +import java.util.TreeSet; +import java.util.logging.Level; +import java.util.logging.Logger; + +import javax.naming.NamingEnumeration; +import javax.naming.NamingException; +import javax.naming.directory.Attribute; +import javax.naming.directory.Attributes; +import javax.naming.directory.DirContext; +import javax.naming.directory.InitialDirContext; + +public class DNSUtilities { + + private final static Logger LOG = Logger.getLogger(DNSUtilities.class.getName()); + + private static final Level DNS_ERR_LOG_LEVEL = Level.FINE; + + /** + * Find all SRV Record using DNS. + * + * You can then use {@link DNSRecordsSRVCollection#getBestRecord()} to find the best candidate (for instance for Round-Robin calls) + * + * @param dnsSrvRecordToFind + * the DNS record, for instance: _ldap._tcp.dc._msdcs.DOMAIN.COM to find all LDAP servers in DOMAIN.COM + * @return the collection of records with facilities to find the best candidate + * @throws NamingException + * if DNS is not available + */ + public static Set findSrvRecords(final String dnsSrvRecordToFind) throws NamingException { + Hashtable env = new Hashtable(); + env.put("java.naming.factory.initial", "com.sun.jndi.dns.DnsContextFactory"); + env.put("java.naming.provider.url", "dns:"); + DirContext ctx = new InitialDirContext(env); + Attributes attrs = ctx.getAttributes(dnsSrvRecordToFind, new String[] {"SRV"}); + NamingEnumeration allServers = attrs.getAll(); + TreeSet records = new TreeSet(); + while (allServers.hasMoreElements()) { + Attribute a = allServers.nextElement(); + NamingEnumeration srvRecord = a.getAll(); + while (srvRecord.hasMore()) { + final String record = String.valueOf(srvRecord.nextElement()); + try { + DNSRecordSRV rec = DNSRecordSRV.parseFromDNSRecord(record); + if (rec != null) { + records.add(rec); + } + } + catch (IllegalArgumentException errorParsingRecord) { + if (LOG.isLoggable(DNS_ERR_LOG_LEVEL)) { + LOG.log(DNS_ERR_LOG_LEVEL, String.format("Failed to parse SRV DNS Record: '%s'", record), errorParsingRecord); + } + } + } + } + return records; + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java new file mode 100644 index 0000000000..8a16cffc9e --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java @@ -0,0 +1,28 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.dns; + +import javax.naming.NamingException; + +public class DNSRealmsTest { + + public static void main(String... args) { + if (args.length < 1) { + System.err.println("USAGE: list of domains to test for kerberos realms"); + } + for (String realmName : args) { + try { + System.out.print(DNSKerberosLocator.isRealmValid(realmName) ? "[ VALID ] " : "[INVALID] "); + } catch (NamingException err) { + System.err.print("[ FAILED] : " + err.getClass().getName() + ":" + err.getMessage()); + } + System.out.println(realmName); + } + } + +}