diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java index 07d438b243b5b..08513ce7412a4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java @@ -24,6 +24,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSessionContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; @@ -542,17 +543,6 @@ SSLContext sslContext() { return context; } - /** - * Invalidates the sessions in the provided {@link SSLSessionContext} - */ - private void invalidateSessions(SSLSessionContext sslSessionContext) { - Enumeration sessionIds = sslSessionContext.getIds(); - while (sessionIds.hasMoreElements()) { - byte[] sessionId = sessionIds.nextElement(); - sslSessionContext.getSession(sessionId).invalidate(); - } - } - synchronized void reload() { invalidateSessions(context.getClientSessionContext()); invalidateSessions(context.getServerSessionContext()); @@ -592,6 +582,24 @@ X509ExtendedTrustManager getEmptyTrustManager() throws GeneralSecurityException, } } + /** + * Invalidates the sessions in the provided {@link SSLSessionContext} + */ + static void invalidateSessions(SSLSessionContext sslSessionContext) { + Enumeration sessionIds = sslSessionContext.getIds(); + while (sessionIds.hasMoreElements()) { + byte[] sessionId = sessionIds.nextElement(); + SSLSession session = sslSessionContext.getSession(sessionId); + // a SSLSession could be null as there is no lock while iterating, the session cache + // could have evicted a value, the session could be timed out, or the session could + // have already been invalidated, which removes the value from the session cache in the + // sun implementation + if (session != null) { + session.invalidate(); + } + } + } + /** * @return A map of Settings prefix to Settings object */ diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ssl/SSLServiceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ssl/SSLServiceTests.java index 048ad2e8e3692..e0fee670d8dc1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ssl/SSLServiceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ssl/SSLServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.CheckedRunnable; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; @@ -35,19 +36,29 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.X509ExtendedTrustManager; +import javax.security.cert.X509Certificate; import java.nio.file.Path; import java.security.AccessController; +import java.security.Principal; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.security.cert.Certificate; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.Enumeration; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.contains; @@ -654,6 +665,57 @@ public void testReadCertificateInformation() throws Exception { assertFalse(iterator.hasNext()); } + public void testSSLSessionInvalidationHandlesNullSessions() { + final int numEntries = randomIntBetween(1, 32); + final AtomicInteger invalidationCounter = new AtomicInteger(); + int numNull = 0; + final Map sessionMap = new HashMap<>(); + for (int i = 0; i < numEntries; i++) { + final byte[] id = randomByteArrayOfLength(2); + final SSLSession sslSession; + if (rarely()) { + sslSession = null; + numNull++; + } else { + sslSession = new MockSSLSession(id, invalidationCounter::incrementAndGet); + } + sessionMap.put(id, sslSession); + } + + SSLSessionContext sslSessionContext = new SSLSessionContext() { + @Override + public SSLSession getSession(byte[] sessionId) { + return sessionMap.get(sessionId); + } + + @Override + public Enumeration getIds() { + return Collections.enumeration(sessionMap.keySet()); + } + + @Override + public void setSessionTimeout(int seconds) throws IllegalArgumentException { + } + + @Override + public int getSessionTimeout() { + return 0; + } + + @Override + public void setSessionCacheSize(int size) throws IllegalArgumentException { + } + + @Override + public int getSessionCacheSize() { + return 0; + } + }; + + SSLService.invalidateSessions(sslSessionContext); + assertEquals(numEntries - numNull, invalidationCounter.get()); + } + @Network public void testThatSSLContextWithoutSettingsWorks() throws Exception { SSLService sslService = new SSLService(Settings.EMPTY, env); @@ -761,4 +823,120 @@ private static void privilegedConnect(CheckedRunnable runnable) throw } } + private static final class MockSSLSession implements SSLSession { + + private final byte[] id; + private final Runnable invalidation; + + private MockSSLSession(byte[] id, Runnable invalidation) { + this.id = id; + this.invalidation = invalidation; + } + + @Override + public byte[] getId() { + return id; + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() { + invalidation.run(); + } + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String name, Object value) { + + } + + @Override + public Object getValue(String name) { + return null; + } + + @Override + public void removeValue(String name) { + + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @SuppressForbidden(reason = "need to reference deprecated class to implement JDK interface") + @Override + public X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException { + return new X509Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } + } }