From 0f0cc5beb4beea3c82f205b5aed361e713fdd461 Mon Sep 17 00:00:00 2001 From: dkimitsa Date: Sun, 14 Mar 2021 15:18:01 +0200 Subject: [PATCH 1/4] * fixed: mutex is allocated before tryingto get file descriptor. as in file descriptor op fails (like no more failes available) destructor of AppData will try to release not initialized mutex (will cause a crash) --- .../crypto/src/main/native/org_conscrypt_NativeCrypto.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp b/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp index 0daf2d9ac..180d058a4 100755 --- a/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp +++ b/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp @@ -5596,6 +5596,10 @@ class AppData { public: static AppData* create() { UniquePtr appData(new AppData()); + if (MUTEX_SETUP(appData.get()->mutex) == -1) { + ALOGE("pthread_mutex_init(3) failed: %s", strerror(errno)); + return NULL; + } if (pipe(appData.get()->fdsEmergency) == -1) { ALOGE("AppData::create pipe(2) failed: %s", strerror(errno)); return NULL; @@ -5604,10 +5608,6 @@ class AppData { ALOGE("AppData::create fcntl(2) failed: %s", strerror(errno)); return NULL; } - if (MUTEX_SETUP(appData.get()->mutex) == -1) { - ALOGE("pthread_mutex_init(3) failed: %s", strerror(errno)); - return NULL; - } return appData.release(); } From 149d6dadf8408a2af53794e6b7ddbe9106bc109f Mon Sep 17 00:00:00 2001 From: dkimitsa Date: Sun, 14 Mar 2021 15:19:30 +0200 Subject: [PATCH 2/4] * applied alpn protocol related fix from https://android-review.googlesource.com/c/platform/external/conscrypt/+/89408 --- .../main/java/org/conscrypt/NativeCrypto.java | 4 ++-- .../java/org/conscrypt/OpenSSLSocketImpl.java | 2 +- .../main/native/org_conscrypt_NativeCrypto.cpp | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/NativeCrypto.java b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/NativeCrypto.java index 8b93a63cb..290107bd8 100755 --- a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/NativeCrypto.java +++ b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/NativeCrypto.java @@ -1015,9 +1015,9 @@ public static native void SSL_set_tlsext_host_name(long sslNativePointer, String /** * For clients, sets the list of supported ALPN protocols in wire-format - * (length-prefixed 8-bit strings) on an SSL context. + * (length-prefixed 8-bit strings). */ - public static native int SSL_CTX_set_alpn_protos(long sslCtxPointer, byte[] protos); + public static native int SSL_set_alpn_protos(long sslPointer, byte[] protos); /** * Returns the selected ALPN protocol. If the server did not select a diff --git a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java index ece7893bf..3c082f5a4 100755 --- a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java +++ b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java @@ -282,7 +282,7 @@ private void checkOpen() throws SocketException { } if (client && alpnProtocols != null) { - NativeCrypto.SSL_CTX_set_alpn_protos(sslCtxNativePointer, alpnProtocols); + NativeCrypto.SSL_set_alpn_protos(sslNativePointer, alpnProtocols); } // setup server certificates and private keys. diff --git a/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp b/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp index 180d058a4..fee8d3b18 100755 --- a/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp +++ b/compiler/vm/rt/android/libcore/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp @@ -6912,30 +6912,30 @@ extern "C" jbyteArray Java_com_android_org_conscrypt_NativeCrypto_SSL_1get_1npn_ return result; } -extern "C" int Java_com_android_org_conscrypt_NativeCrypto_SSL_1CTX_1set_1alpn_1protos(JNIEnv* env, jclass, jlong ssl_ctx_address, +extern "C" int Java_com_android_org_conscrypt_NativeCrypto_SSL_1set_1alpn_1protos(JNIEnv* env, jclass, jlong ssl_address, jbyteArray protos) { - SSL_CTX* ssl_ctx = to_SSL_CTX(env, ssl_ctx_address, true); - if (ssl_ctx == NULL) { + SSL* ssl = to_SSL(env, ssl_address, true); + if (ssl == NULL) { return 0; } - JNI_TRACE("ssl_ctx=%p SSL_CTX_set_alpn_protos protos=%p", ssl_ctx, protos); + JNI_TRACE("ssl=%p SSL_set_alpn_protos protos=%p", ssl, protos); if (protos == NULL) { - JNI_TRACE("ssl_ctx=%p SSL_CTX_set_alpn_protos protos=NULL", ssl_ctx); + JNI_TRACE("ssl=%p SSL_set_alpn_protos protos=NULL", ssl); return 1; } ScopedByteArrayRO protosBytes(env, protos); if (protosBytes.get() == NULL) { - JNI_TRACE("ssl_ctx=%p SSL_CTX_set_alpn_protos protos=%p => protosBytes == NULL", ssl_ctx, + JNI_TRACE("ssl=%p SSL_set_alpn_protos protos=%p => protosBytes == NULL", ssl, protos); return 0; } const unsigned char *tmp = reinterpret_cast(protosBytes.get()); - int ret = SSL_CTX_set_alpn_protos(ssl_ctx, tmp, protosBytes.size()); - JNI_TRACE("ssl_ctx=%p SSL_CTX_set_alpn_protos protos=%p => ret=%d", ssl_ctx, protos, ret); + int ret = SSL_set_alpn_protos(ssl, tmp, protosBytes.size()); + JNI_TRACE("ssl=%p SSL_set_alpn_protos protos=%p => ret=%d", ssl, protos, ret); return ret; } From 4ce0a23fdda1598531ec06767dd91fe2377e3b23 Mon Sep 17 00:00:00 2001 From: dkimitsa Date: Mon, 15 Mar 2021 16:43:08 +0200 Subject: [PATCH 3/4] * applied missing part of CVE-2010-3864 (https://www.openssl.org/news/secadv/20101116-2.txt) that were causing the crash due multiple release of s->session->tlsext_ecpointformatlist due race in multithreading --- .../rt/android/external/openssl/ssl/t1_lib.c | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/compiler/vm/rt/android/external/openssl/ssl/t1_lib.c b/compiler/vm/rt/android/external/openssl/ssl/t1_lib.c index f1700561d..c9bbaf265 100755 --- a/compiler/vm/rt/android/external/openssl/ssl/t1_lib.c +++ b/compiler/vm/rt/android/external/openssl/ssl/t1_lib.c @@ -1592,15 +1592,22 @@ int ssl_parse_serverhello_tlsext(SSL *s, unsigned char **p, unsigned char *d, in *al = TLS1_AD_DECODE_ERROR; return 0; } - s->session->tlsext_ecpointformatlist_length = 0; - if (s->session->tlsext_ecpointformatlist != NULL) OPENSSL_free(s->session->tlsext_ecpointformatlist); - if ((s->session->tlsext_ecpointformatlist = OPENSSL_malloc(ecpointformatlist_length)) == NULL) - { - *al = TLS1_AD_INTERNAL_ERROR; - return 0; - } - s->session->tlsext_ecpointformatlist_length = ecpointformatlist_length; - memcpy(s->session->tlsext_ecpointformatlist, sdata, ecpointformatlist_length); + if (!s->hit) + { + if(s->session->tlsext_ecpointformatlist) + { + OPENSSL_free(s->session->tlsext_ecpointformatlist); + s->session->tlsext_ecpointformatlist = NULL; + } + s->session->tlsext_ecpointformatlist_length = 0; + if ((s->session->tlsext_ecpointformatlist = OPENSSL_malloc(ecpointformatlist_length)) == NULL) + { + *al = TLS1_AD_INTERNAL_ERROR; + return 0; + } + s->session->tlsext_ecpointformatlist_length = ecpointformatlist_length; + memcpy(s->session->tlsext_ecpointformatlist, sdata, ecpointformatlist_length); + } #if 0 fprintf(stderr,"ssl_parse_serverhello_tlsext s->session->tlsext_ecpointformatlist "); sdata = s->session->tlsext_ecpointformatlist; From 3c6e25346890489ee9240481e554bd46a7f8e2af Mon Sep 17 00:00:00 2001 From: dkimitsa Date: Mon, 15 Mar 2021 16:47:57 +0200 Subject: [PATCH 4/4] * applied 68360-Tidy up locking in OpenSSLSocketImpl that improves multithreading (https://android-review.googlesource.com/c/platform/external/conscrypt/+/68360) We guard all state with a single lock "stateLock", which replaces usages of "this" and "handshakeLock". We do not perform any blocking operations while holding this lock. In particular, startHandshake is no longer synchronized. We use a single integer to keep track of handshake state instead of a pair of booleans. Also fix a bug in getSession, the previous implementation wouldn't work in cut-through mode. --- .../java/org/conscrypt/OpenSSLSocketImpl.java | 651 +++++++++++++----- 1 file changed, 470 insertions(+), 181 deletions(-) diff --git a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java index 3c082f5a4..23f7120d7 100755 --- a/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java +++ b/compiler/rt/libcore/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java @@ -39,17 +39,18 @@ import javax.net.ssl.HandshakeCompletedListener; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLProtocolException; import javax.net.ssl.SSLSession; import javax.net.ssl.X509TrustManager; import javax.security.auth.x500.X500Principal; -import static libcore.io.OsConstants.*; import libcore.io.ErrnoException; import libcore.io.Libcore; import libcore.io.Streams; import libcore.io.StructTimeval; +import static libcore.io.OsConstants.SOL_SOCKET; +import static libcore.io.OsConstants.SO_SNDTIMEO; + /** * Implementation of the class OpenSSLSocketImpl based on OpenSSL. *

@@ -64,36 +65,100 @@ public class OpenSSLSocketImpl extends javax.net.ssl.SSLSocket implements NativeCrypto.SSLHandshakeCallbacks { + private static final boolean DBG_STATE = false; + + /** + * Protects handshakeStarted and handshakeCompleted. + */ + private final Object stateLock = new Object(); + + /** + * The {@link OpenSSLSocketImpl} object is constructed, but {@link #startHandshake()} + * has not yet been called. + */ + private static final int STATE_NEW = 0; + + /** + * {@link #startHandshake()} has been called at least once. + */ + private static final int STATE_HANDSHAKE_STARTED = 1; + + /** + * {@link #handshakeCompleted()} has been called, but {@link #startHandshake()} hasn't + * returned yet. + */ + private static final int STATE_HANDSHAKE_COMPLETED = 2; + + /** + * {@link #startHandshake()} has completed but {@link #handshakeCompleted()} hasn't + * been called. This is expected behaviour in cut-through mode, where SSL_do_handshake + * returns before the handshake is complete. We can now start writing data to the socket. + */ + private static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 3; + + /** + * {@link #startHandshake()} has completed and {@link #handshakeCompleted()} has been + * called. + */ + private static final int STATE_READY = 4; + + /** + * {@link #close()} has been called at least once. + */ + private static final int STATE_CLOSED = 5; + + // @GuardedBy("stateLock"); + private int state = STATE_NEW; + + /** + * Protected by synchronizing on stateLock. Starts as 0, set by + * startHandshake, reset to 0 on close. + */ + // @GuardedBy("stateLock"); private long sslNativePointer; - private InputStream is; - private OutputStream os; - private final Object handshakeLock = new Object(); - private final Object readLock = new Object(); - private final Object writeLock = new Object(); - private SSLParametersImpl sslParameters; - private byte[] npnProtocols; - private byte[] alpnProtocols; + + /** + * Protected by synchronizing on stateLock. Starts as null, set by + * getInputStream. + */ + // @GuardedBy("stateLock"); + private SSLInputStream is; + + /** + * Protected by synchronizing on stateLock. Starts as null, set by + * getInputStream. + */ + // @GuardedBy("stateLock"); + private SSLOutputStream os; + + private final Socket socket; + private final boolean autoClose; + private final String wrappedHost; + private final int wrappedPort; + private final SSLParametersImpl sslParameters; + private final CloseGuard guard = CloseGuard.get(); + private String[] enabledProtocols; private String[] enabledCipherSuites; + private byte[] npnProtocols; + private byte[] alpnProtocols; private boolean useSessionTickets; private String hostname; - /** Whether the TLS Channel ID extension is enabled. This field is server-side only. */ + + /** + * Whether the TLS Channel ID extension is enabled. This field is + * server-side only. + */ private boolean channelIdEnabled; - /** Private key for the TLS Channel ID extension. This field is client-side only. */ - private OpenSSLKey channelIdPrivateKey; - private OpenSSLSessionImpl sslSession; - private final Socket socket; - private boolean autoClose; - private boolean handshakeStarted = false; - private final CloseGuard guard = CloseGuard.get(); /** - * Not set to true until the update from native that tells us the - * full handshake is complete, since SSL_do_handshake can return - * before the handshake is completely done due to - * handshake_cutthrough support. + * Private key for the TLS Channel ID extension. This field is + * client-side only. Set during startHandshake. */ - private boolean handshakeCompleted = false; + private OpenSSLKey channelIdPrivateKey; + + /** Set during startHandshake. */ + private OpenSSLSessionImpl sslSession; private ArrayList listeners; @@ -107,33 +172,51 @@ public class OpenSSLSocketImpl private int writeTimeoutMilliseconds = 0; private int handshakeTimeoutMilliseconds = -1; // -1 = same as timeout; 0 = infinite - private String wrappedHost; - private int wrappedPort; protected OpenSSLSocketImpl(SSLParametersImpl sslParameters) throws IOException { this.socket = this; - init(sslParameters); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); } protected OpenSSLSocketImpl(SSLParametersImpl sslParameters, String[] enabledProtocols, String[] enabledCipherSuites) throws IOException { this.socket = this; - init(sslParameters, enabledProtocols, enabledCipherSuites); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = enabledProtocols; + this.enabledCipherSuites = enabledCipherSuites; } protected OpenSSLSocketImpl(String host, int port, SSLParametersImpl sslParameters) throws IOException { super(host, port); this.socket = this; - init(sslParameters); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); } protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters) throws IOException { super(address, port); this.socket = this; - init(sslParameters); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); } @@ -142,7 +225,12 @@ protected OpenSSLSocketImpl(String host, int port, SSLParametersImpl sslParameters) throws IOException { super(host, port, clientAddress, clientPort); this.socket = this; - init(sslParameters); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); } protected OpenSSLSocketImpl(InetAddress address, int port, @@ -150,7 +238,12 @@ protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters) throws IOException { super(address, port, clientAddress, clientPort); this.socket = this; - init(sslParameters); + this.wrappedHost = null; + this.wrappedPort = -1; + this.autoClose = false; + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); } /** @@ -158,40 +251,20 @@ protected OpenSSLSocketImpl(InetAddress address, int port, * OpenSSLSocketImplWrapper constructor. */ protected OpenSSLSocketImpl(Socket socket, String host, int port, - boolean autoClose, SSLParametersImpl sslParameters) throws IOException { + boolean autoClose, SSLParametersImpl sslParameters) throws IOException { this.socket = socket; this.wrappedHost = host; this.wrappedPort = port; this.autoClose = autoClose; - init(sslParameters); + this.sslParameters = sslParameters; + this.enabledProtocols = NativeCrypto.getDefaultProtocols(); + this.enabledCipherSuites = NativeCrypto.getDefaultCipherSuites(); // this.timeout is not set intentionally. // OpenSSLSocketImplWrapper.getSoTimeout will delegate timeout // to wrapped socket } - /** - * Initialize the SSL socket and set the certificates for the - * future handshaking. - */ - private void init(SSLParametersImpl sslParameters) throws IOException { - init(sslParameters, - NativeCrypto.getDefaultProtocols(), - NativeCrypto.getDefaultCipherSuites()); - } - - /** - * Initialize the SSL socket and set the certificates for the - * future handshaking. - */ - private void init(SSLParametersImpl sslParameters, - String[] enabledProtocols, - String[] enabledCipherSuites) throws IOException { - this.sslParameters = sslParameters; - this.enabledProtocols = enabledProtocols; - this.enabledCipherSuites = enabledCipherSuites; - } - /** * Gets the suitable session reference from the session cache container. */ @@ -246,12 +319,14 @@ private void checkOpen() throws SocketException { * verified if the correspondent property in java.Security is set. All * listeners are notified at the end of the TLS/SSL handshake. */ - @Override public synchronized void startHandshake() throws IOException { - synchronized (handshakeLock) { - checkOpen(); - if (!handshakeStarted) { - handshakeStarted = true; + @Override public void startHandshake() throws IOException { + checkOpen(); + synchronized (stateLock) { + if (state == STATE_NEW) { + state = STATE_HANDSHAKE_STARTED; } else { + // We've either started the handshake already or have been closed. + // Do nothing in both cases. return; } } @@ -268,11 +343,11 @@ private void checkOpen() throws SocketException { final boolean client = sslParameters.getUseClientMode(); final long sslCtxNativePointer = (client) ? - sslParameters.getClientSessionContext().sslCtxNativePointer : - sslParameters.getServerSessionContext().sslCtxNativePointer; + sslParameters.getClientSessionContext().sslCtxNativePointer : + sslParameters.getServerSessionContext().sslCtxNativePointer; - this.sslNativePointer = 0; - boolean exception = true; + sslNativePointer = 0; + boolean releaseResources = true; try { sslNativePointer = NativeCrypto.SSL_new(sslCtxNativePointer); guard.open("close"); @@ -301,8 +376,8 @@ private void checkOpen() throws SocketException { for (String keyType : keyTypes) { try { setCertificate(sslParameters.getKeyManager().chooseServerAlias(keyType, - null, - this)); + null, + this)); } catch (CertificateEncodingException e) { throw new IOException(e); } @@ -321,7 +396,7 @@ private void checkOpen() throws SocketException { boolean enableSessionCreation = sslParameters.getEnableSessionCreation(); if (!enableSessionCreation) { NativeCrypto.SSL_set_session_creation_enabled(sslNativePointer, - enableSessionCreation); + enableSessionCreation); } AbstractSessionContext sessionContext; @@ -333,7 +408,7 @@ private void checkOpen() throws SocketException { sessionToReuse = getCachedClientSession(clientSessionContext); if (sessionToReuse != null) { NativeCrypto.SSL_set_session(sslNativePointer, - sessionToReuse.sslSessionNativePointer); + sessionToReuse.sslSessionNativePointer); } } else { sessionContext = sslParameters.getServerSessionContext(); @@ -349,15 +424,15 @@ private void checkOpen() throws SocketException { boolean certRequested; if (sslParameters.getNeedClientAuth()) { NativeCrypto.SSL_set_verify(sslNativePointer, - NativeCrypto.SSL_VERIFY_PEER - | NativeCrypto.SSL_VERIFY_FAIL_IF_NO_PEER_CERT); + NativeCrypto.SSL_VERIFY_PEER + | NativeCrypto.SSL_VERIFY_FAIL_IF_NO_PEER_CERT); certRequested = true; - // ... over just wanting it... + // ... over just wanting it... } else if (sslParameters.getWantClientAuth()) { NativeCrypto.SSL_set_verify(sslNativePointer, - NativeCrypto.SSL_VERIFY_PEER); + NativeCrypto.SSL_VERIFY_PEER); certRequested = true; - // ... and it defaults properly so don't call SSL_set_verify in the common case. + // ... and it defaults properly so don't call SSL_set_verify in the common case. } else { certRequested = false; } @@ -400,6 +475,12 @@ private void checkOpen() throws SocketException { } } + synchronized (stateLock) { + if (state == STATE_CLOSED) { + return; + } + } + long sslSessionNativePointer; try { sslSessionNativePointer = NativeCrypto.SSL_do_handshake(sslNativePointer, @@ -409,7 +490,32 @@ private void checkOpen() throws SocketException { SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage()); wrapper.initCause(e); throw wrapper; + } catch (SSLException e) { + // Swallow this exception if it's thrown as the result of an interruption. + // + // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake + // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ + // (or WANT_WRITE). Catching that exception here doesn't seem much worse than + // changing the native code to return a "special" native pointer value when that + // happens. + synchronized (stateLock) { + if (state == STATE_CLOSED) { + return; + } + } + + throw e; } + + boolean handshakeCompleted = false; + synchronized (stateLock) { + if (state == STATE_HANDSHAKE_COMPLETED) { + handshakeCompleted = true; + } else if (state == STATE_CLOSED) { + return; + } + } + byte[] sessionId = NativeCrypto.SSL_SESSION_session_id(sslSessionNativePointer); if (sessionToReuse != null && Arrays.equals(sessionToReuse.getId(), sessionId)) { this.sslSession = sessionToReuse; @@ -443,13 +549,41 @@ private void checkOpen() throws SocketException { notifyHandshakeCompletedListeners(); } - exception = false; + synchronized (stateLock) { + releaseResources = (state == STATE_CLOSED); + + if (state == STATE_HANDSHAKE_STARTED) { + state = STATE_READY_HANDSHAKE_CUT_THROUGH; + } else if (state == STATE_HANDSHAKE_COMPLETED) { + state = STATE_READY; + } + + if (!releaseResources) { + // Unblock threads that are waiting for our state to transition + // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH. + stateLock.notifyAll(); + } + } } catch (SSLProtocolException e) { throw new SSLHandshakeException(e); } finally { // on exceptional exit, treat the socket as closed - if (exception) { - close(); + if (releaseResources) { + synchronized (stateLock) { + // Mark the socket as closed since we might have reached this as + // a result on an exception thrown by the handshake process. + // + // The state will already be set to closed if we reach this as a result of + // an early return or an interruption due to a concurrent call to close(). + state = STATE_CLOSED; + stateLock.notifyAll(); + } + + try { + shutdownAndFreeSslNative(); + } catch (IOException ignored) { + + } } } } @@ -546,36 +680,52 @@ public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEnco @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback public void handshakeCompleted() { - handshakeCompleted = true; - - // If sslSession is null, the handshake was completed during - // the call to NativeCrypto.SSL_do_handshake and not during a - // later read operation. That means we do not need to fix up - // the SSLSession and session cache or notify - // HandshakeCompletedListeners, it will be done in - // startHandshake. - if (sslSession == null) { - return; + synchronized (stateLock) { + if (state == STATE_HANDSHAKE_STARTED) { + // If sslSession is null, the handshake was completed during + // the call to NativeCrypto.SSL_do_handshake and not during a + // later read operation. That means we do not need to fix up + // the SSLSession and session cache or notify + // HandshakeCompletedListeners, it will be done in + // startHandshake. + + state = STATE_HANDSHAKE_COMPLETED; + return; + } else if (state == STATE_READY_HANDSHAKE_CUT_THROUGH) { + // We've returned from startHandshake, which means we've set a sslSession etc. + // we need to fix them up, which we'll do outside this lock. + } else if (state == STATE_CLOSED) { + // Someone called "close" but the handshake hasn't been interrupted yet. + return; + } } // reset session id from the native pointer and update the // appropriate cache. sslSession.resetId(); AbstractSessionContext sessionContext = - (sslParameters.getUseClientMode()) - ? sslParameters.getClientSessionContext() - : sslParameters.getServerSessionContext(); + (sslParameters.getUseClientMode()) + ? sslParameters.getClientSessionContext() + : sslParameters.getServerSessionContext(); sessionContext.putSession(sslSession); // let listeners know we are finally done notifyHandshakeCompletedListeners(); + + synchronized (stateLock) { + // Now that we've fixed up our state, we can tell waiting threads that + // we're ready. + state = STATE_READY; + // Notify all threads waiting for the handshake to complete. + stateLock.notifyAll(); + } } private void notifyHandshakeCompletedListeners() { if (listeners != null && !listeners.isEmpty()) { // notify the listeners HandshakeCompletedEvent event = - new HandshakeCompletedEvent(this, sslSession); + new HandshakeCompletedEvent(this, sslSession); for (HandshakeCompletedListener listener : listeners) { try { listener.handshakeCompleted(event); @@ -615,7 +765,7 @@ private void notifyHandshakeCompletedListeners() { } else { String authType = peerCertificateChain[0].getPublicKey().getAlgorithm(); sslParameters.getTrustManager().checkClientTrusted(peerCertificateChain, - authType); + authType); } } catch (CertificateException e) { @@ -627,23 +777,80 @@ private void notifyHandshakeCompletedListeners() { @Override public InputStream getInputStream() throws IOException { checkOpen(); - synchronized (this) { + + InputStream returnVal; + synchronized (stateLock) { + if (state == STATE_CLOSED) { + throw new SocketException("Socket is closed."); + } + if (is == null) { is = new SSLInputStream(); } - return is; + returnVal = is; } + + // Block waiting for a handshake without a lock held. It's possible that the socket + // is closed at this point. If that happens, we'll still return the input stream but + // all reads on it will throw. + waitForHandshake(); + return returnVal; } @Override public OutputStream getOutputStream() throws IOException { checkOpen(); - synchronized (this) { + + OutputStream returnVal; + synchronized (stateLock) { + if (state == STATE_CLOSED) { + throw new SocketException("Socket is closed."); + } + if (os == null) { os = new SSLOutputStream(); } - return os; + returnVal = os; + } + + // Block waiting for a handshake without a lock held. It's possible that the socket + // is closed at this point. If that happens, we'll still return the output stream but + // all writes on it will throw. + waitForHandshake(); + return returnVal; + } + + private void assertReadableOrWriteableState() { + if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) { + return; + } + + throw new AssertionError("Invalid state: " + state); + } + + + private void waitForHandshake() throws IOException { + startHandshake(); + + synchronized (stateLock) { + while (state != STATE_READY && + state != STATE_READY_HANDSHAKE_CUT_THROUGH && + state != STATE_CLOSED) { + try { + stateLock.wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + IOException ioe = new IOException("Interrupted waiting for handshake"); + ioe.initCause(e); + + throw ioe; + } + } + + if (state == STATE_CLOSED) { + throw new SocketException("Socket is closed"); + } } } @@ -653,12 +860,14 @@ private void notifyHandshakeCompletedListeners() { * read data received via SSL protocol. */ private class SSLInputStream extends InputStream { - SSLInputStream() throws IOException { - /* - * Note: When startHandshake() throws an exception, no - * SSLInputStream object will be created. - */ - OpenSSLSocketImpl.this.startHandshake(); + /** + * OpenSSL only lets one thread read at a time, so this is used to + * make sure we serialize callers of SSL_read. Thread is already + * expected to have completed handshaking. + */ + private final Object readLock = new Object(); + + SSLInputStream() { } /** @@ -666,7 +875,7 @@ private class SSLInputStream extends InputStream { * this operation can block until the data will be * available. * @return read value. - * @throws IOException + * @throws IOException */ @Override public int read() throws IOException { @@ -680,16 +889,36 @@ public int read() throws IOException { @Override public int read(byte[] buf, int offset, int byteCount) throws IOException { BlockGuard.getThreadPolicy().onNetwork(); + + checkOpen(); + Arrays.checkOffsetAndCount(buf.length, offset, byteCount); + if (byteCount == 0) { + return 0; + } + synchronized (readLock) { - checkOpen(); - Arrays.checkOffsetAndCount(buf.length, offset, byteCount); - if (byteCount == 0) { - return 0; + synchronized (stateLock) { + if (state == STATE_CLOSED) { + throw new SocketException("socket is closed"); + } + + if (DBG_STATE) assertReadableOrWriteableState(); } + return NativeCrypto.SSL_read(sslNativePointer, socket.getFileDescriptor$(), OpenSSLSocketImpl.this, buf, offset, byteCount, getSoTimeout()); } } + + public void awaitPendingOps() { + if (DBG_STATE) { + synchronized (stateLock) { + if (state != STATE_CLOSED) throw new AssertionError("State is: " + state); + } + } + + synchronized (readLock) { } + } } /** @@ -698,12 +927,15 @@ public int read(byte[] buf, int offset, int byteCount) throws IOException { * write data according to the encryption parameters given in SSL context. */ private class SSLOutputStream extends OutputStream { - SSLOutputStream() throws IOException { - /* - * Note: When startHandshake() throws an exception, no - * SSLOutputStream object will be created. - */ - OpenSSLSocketImpl.this.startHandshake(); + + /** + * OpenSSL only lets one thread write at a time, so this is used + * to make sure we serialize callers of SSL_write. Thread is + * already expected to have completed handshaking. + */ + private final Object writeLock = new Object(); + + SSLOutputStream() { } /** @@ -722,23 +954,43 @@ public void write(int oneByte) throws IOException { @Override public void write(byte[] buf, int offset, int byteCount) throws IOException { BlockGuard.getThreadPolicy().onNetwork(); + checkOpen(); + Arrays.checkOffsetAndCount(buf.length, offset, byteCount); + if (byteCount == 0) { + return; + } + synchronized (writeLock) { - checkOpen(); - Arrays.checkOffsetAndCount(buf.length, offset, byteCount); - if (byteCount == 0) { - return; + synchronized (stateLock) { + if (state == STATE_CLOSED) { + throw new SocketException("socket is closed"); + } + + if (DBG_STATE) assertReadableOrWriteableState(); } + NativeCrypto.SSL_write(sslNativePointer, socket.getFileDescriptor$(), OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds); } } + + + public void awaitPendingOps() { + if (DBG_STATE) { + synchronized (stateLock) { + if (state != STATE_CLOSED) throw new AssertionError("State is: " + state); + } + } + + synchronized (writeLock) { } + } } @Override public SSLSession getSession() { if (sslSession == null) { try { - startHandshake(); + waitForHandshake(); } catch (IOException e) { // return an invalid session with // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL" @@ -837,10 +1089,13 @@ public void setChannelIdEnabled(boolean enabled) { if (getUseClientMode()) { throw new IllegalStateException("Client mode"); } - if (handshakeStarted) { - throw new IllegalStateException( - "Could not enable/disable Channel ID after the initial handshake has" - + " begun."); + + synchronized (stateLock) { + if (state != STATE_NEW) { + throw new IllegalStateException( + "Could not enable/disable Channel ID after the initial handshake has" + + " begun."); + } } this.channelIdEnabled = enabled; } @@ -859,9 +1114,12 @@ public byte[] getChannelId() throws SSLException { if (getUseClientMode()) { throw new IllegalStateException("Client mode"); } - if (!handshakeCompleted) { - throw new IllegalStateException( - "Channel ID is only available after handshake completes"); + + synchronized (stateLock) { + if (state != STATE_READY) { + throw new IllegalStateException( + "Channel ID is only available after handshake completes"); + } } return NativeCrypto.SSL_get_tls_channel_id(sslNativePointer); } @@ -882,11 +1140,15 @@ public void setChannelIdPrivateKey(PrivateKey privateKey) { if (!getUseClientMode()) { throw new IllegalStateException("Server mode"); } - if (handshakeStarted) { - throw new IllegalStateException( - "Could not change Channel ID private key after the initial handshake has" - + " begun."); + + synchronized (stateLock) { + if (state != STATE_NEW) { + throw new IllegalStateException( + "Could not change Channel ID private key after the initial handshake has" + + " begun."); + } } + if (privateKey == null) { this.channelIdEnabled = false; this.channelIdPrivateKey = null; @@ -905,9 +1167,11 @@ public void setChannelIdPrivateKey(PrivateKey privateKey) { } @Override public void setUseClientMode(boolean mode) { - if (handshakeStarted) { - throw new IllegalArgumentException( - "Could not change the mode after the initial handshake has begun."); + synchronized (stateLock) { + if (state != STATE_NEW) { + throw new IllegalArgumentException( + "Could not change the mode after the initial handshake has begun."); + } } sslParameters.setUseClientMode(mode); } @@ -977,65 +1241,90 @@ public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketE @Override public void close() throws IOException { // TODO: Close SSL sockets using a background thread so they close gracefully. - synchronized (handshakeLock) { - if (!handshakeStarted) { - // prevent further attempts to start handshake - handshakeStarted = true; + SSLInputStream sslInputStream = null; + SSLOutputStream sslOutputStream = null; - synchronized (this) { - free(); + synchronized (stateLock) { + if (state == STATE_CLOSED) { + // close() has already been called, so do nothing and return. + return; + } - if (socket != this) { - if (autoClose && !socket.isClosed()) socket.close(); - } else { - if (!super.isClosed()) super.close(); - } - } + int oldState = state; + state = STATE_CLOSED; + + if (oldState == STATE_NEW) { + // The handshake hasn't been started yet, so there's no OpenSSL related + // state to clean up. We still need to close the underlying socket if + // we're wrapping it and were asked to autoClose. + closeUnderlyingSocket(); + stateLock.notifyAll(); return; } - } - synchronized (this) { + if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) { + // If we're in these states, we still haven't returned from startHandshake. + // We call SSL_interrupt so that we can interrupt SSL_do_handshake and then + // set the state to STATE_CLOSED. startHandshake will handle all cleanup + // after SSL_do_handshake returns, so we don't have anything to do here. + NativeCrypto.SSL_interrupt(sslNativePointer); + + stateLock.notifyAll(); + return; + } - // Interrupt any outstanding reads or writes before taking the writeLock and readLock + stateLock.notifyAll(); + // We've already returned from startHandshake, so we potentially have + // input and output streams to clean up. + sslInputStream = is; + sslOutputStream = os; + } + + // Don't bother interrupting unless we have something to interrupt. + if (sslInputStream != null || sslOutputStream != null) { NativeCrypto.SSL_interrupt(sslNativePointer); + } - synchronized (writeLock) { - synchronized (readLock) { - // Shut down the SSL connection, per se. - try { - if (handshakeStarted) { - BlockGuard.getThreadPolicy().onNetwork(); - NativeCrypto.SSL_shutdown(sslNativePointer, socket.getFileDescriptor$(), - this); - } - } catch (IOException ignored) { - /* - * Note that although close() can throw - * IOException, the RI does not throw if there - * is problem sending a "close notify" which - * can happen if the underlying socket is closed. - */ - } finally { - /* - * Even if the above call failed, it is still safe to free - * the native structs, and we need to do so lest we leak - * memory. - */ - free(); - - if (socket != this) { - if (autoClose && !socket.isClosed()) { - socket.close(); - } - } else { - if (!super.isClosed()) { - super.close(); - } - } - } - } + // Wait for the input and output streams to finish any reads they have in + // progress. If there are no reads in progress at this point, future reads will + // throw because state == STATE_CLOSED + if (sslInputStream != null) { + sslInputStream.awaitPendingOps(); + } + if (sslOutputStream != null) { + sslOutputStream.awaitPendingOps(); + } + + shutdownAndFreeSslNative(); + } + + private void shutdownAndFreeSslNative() throws IOException { + try { + BlockGuard.getThreadPolicy().onNetwork(); + NativeCrypto.SSL_shutdown(sslNativePointer, socket.getFileDescriptor$(), + this); + } catch (IOException ignored) { + /* + * Note that although close() can throw + * IOException, the RI does not throw if there + * is problem sending a "close notify" which + * can happen if the underlying socket is closed. + */ + } finally { + free(); + closeUnderlyingSocket(); + } + } + + private void closeUnderlyingSocket() throws IOException { + if (socket != this) { + if (autoClose && !socket.isClosed()) { + socket.close(); + } + } else { + if (!super.isClosed()) { + super.close(); } } }