Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -542,17 +543,6 @@ SSLContext sslContext() {
return context;
}

/**
* Invalidates the sessions in the provided {@link SSLSessionContext}
*/
private void invalidateSessions(SSLSessionContext sslSessionContext) {
Enumeration<byte[]> sessionIds = sslSessionContext.getIds();
while (sessionIds.hasMoreElements()) {
byte[] sessionId = sessionIds.nextElement();
sslSessionContext.getSession(sessionId).invalidate();
}
}

synchronized void reload() {
invalidateSessions(context.getClientSessionContext());
invalidateSessions(context.getServerSessionContext());
Expand Down Expand Up @@ -592,6 +582,24 @@ X509ExtendedTrustManager getEmptyTrustManager() throws GeneralSecurityException,
}
}

/**
* Invalidates the sessions in the provided {@link SSLSessionContext}
*/
static void invalidateSessions(SSLSessionContext sslSessionContext) {
Enumeration<byte[]> sessionIds = sslSessionContext.getIds();
while (sessionIds.hasMoreElements()) {
byte[] sessionId = sessionIds.nextElement();
SSLSession session = sslSessionContext.getSession(sessionId);
// a SSLSession could be null as there is no lock while iterating, the session cache
// could have evicted a value, the session could be timed out, or the session could
// have already been invalidated, which removes the value from the session cache in the
// sun implementation
if (session != null) {
session.invalidate();
}
}
}

/**
* @return A map of Settings prefix to Settings object
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.CheckedRunnable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.env.Environment;
Expand All @@ -35,19 +36,29 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.security.cert.X509Certificate;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.Principal;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
import static org.hamcrest.Matchers.contains;
Expand Down Expand Up @@ -654,6 +665,57 @@ public void testReadCertificateInformation() throws Exception {
assertFalse(iterator.hasNext());
}

public void testSSLSessionInvalidationHandlesNullSessions() {
final int numEntries = randomIntBetween(1, 32);
final AtomicInteger invalidationCounter = new AtomicInteger();
int numNull = 0;
final Map<byte[], SSLSession> sessionMap = new HashMap<>();
for (int i = 0; i < numEntries; i++) {
final byte[] id = randomByteArrayOfLength(2);
final SSLSession sslSession;
if (rarely()) {
sslSession = null;
numNull++;
} else {
sslSession = new MockSSLSession(id, invalidationCounter::incrementAndGet);
}
sessionMap.put(id, sslSession);
}

SSLSessionContext sslSessionContext = new SSLSessionContext() {
@Override
public SSLSession getSession(byte[] sessionId) {
return sessionMap.get(sessionId);
}

@Override
public Enumeration<byte[]> getIds() {
return Collections.enumeration(sessionMap.keySet());
}

@Override
public void setSessionTimeout(int seconds) throws IllegalArgumentException {
}

@Override
public int getSessionTimeout() {
return 0;
}

@Override
public void setSessionCacheSize(int size) throws IllegalArgumentException {
}

@Override
public int getSessionCacheSize() {
return 0;
}
};

SSLService.invalidateSessions(sslSessionContext);
assertEquals(numEntries - numNull, invalidationCounter.get());
}

@Network
public void testThatSSLContextWithoutSettingsWorks() throws Exception {
SSLService sslService = new SSLService(Settings.EMPTY, env);
Expand Down Expand Up @@ -761,4 +823,120 @@ private static void privilegedConnect(CheckedRunnable<Exception> runnable) throw
}
}

private static final class MockSSLSession implements SSLSession {

private final byte[] id;
private final Runnable invalidation;

private MockSSLSession(byte[] id, Runnable invalidation) {
this.id = id;
this.invalidation = invalidation;
}

@Override
public byte[] getId() {
return id;
}

@Override
public SSLSessionContext getSessionContext() {
return null;
}

@Override
public long getCreationTime() {
return 0;
}

@Override
public long getLastAccessedTime() {
return 0;
}

@Override
public void invalidate() {
invalidation.run();
}

@Override
public boolean isValid() {
return false;
}

@Override
public void putValue(String name, Object value) {

}

@Override
public Object getValue(String name) {
return null;
}

@Override
public void removeValue(String name) {

}

@Override
public String[] getValueNames() {
return new String[0];
}

@Override
public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
return new Certificate[0];
}

@Override
public Certificate[] getLocalCertificates() {
return new Certificate[0];
}

@SuppressForbidden(reason = "need to reference deprecated class to implement JDK interface")
@Override
public X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException {
return new X509Certificate[0];
}

@Override
public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
return null;
}

@Override
public Principal getLocalPrincipal() {
return null;
}

@Override
public String getCipherSuite() {
return null;
}

@Override
public String getProtocol() {
return null;
}

@Override
public String getPeerHost() {
return null;
}

@Override
public int getPeerPort() {
return 0;
}

@Override
public int getPacketBufferSize() {
return 0;
}

@Override
public int getApplicationBufferSize() {
return 0;
}
}
}