From d92d208e30300d6befcab93c56b4d8ad679e4704 Mon Sep 17 00:00:00 2001 From: irunika Date: Wed, 13 Jun 2018 16:53:11 +0530 Subject: [PATCH] Add SSL support to WebSocket client --- .../http/netty/common/Constants.java | 4 +- .../transport/http/netty/common/Util.java | 71 ++++++- .../config/OutboundSslConfiguration.java | 186 ++++++++++++++++++ .../netty/config/SenderConfiguration.java | 161 +-------------- .../netty/contract/ServerConnectorFuture.java | 2 +- .../WebSocketClientConnectorConfig.java | 15 +- .../websocket/WebSocketControlMessage.java | 2 +- .../DefaultHttpClientConnector.java | 2 +- .../DefaultClientHandshakeFuture.java | 3 +- .../DefaultWebSocketClientConnector.java | 17 +- .../WebSocketInboundFrameHandler.java | 6 +- .../message/DefaultWebSocketInitMessage.java | 4 +- .../sender/HttpClientChannelInitializer.java | 61 +----- .../netty/sender/OCSPStaplingHandler.java | 2 +- .../sender/websocket/WebSocketClient.java | 163 +++++++++------ .../WebSocketClientHandshakeHandler.java | 62 +++--- .../transport/http/netty/util/TestUtil.java | 2 + .../WebSocketClientFunctionalityTestCase.java | 7 +- ...tClientHandshakeFunctionalityTestCase.java | 50 ++++- .../WebSocketTestClientConnectorListener.java | 8 +- .../WebSocketSSLHandshakeFailureTestCase.java | 128 ++++++++++++ ...bSocketSSLHandshakeSuccessfulTestCase.java | 138 +++++++++++++ .../src/test/resources/testng.xml | 2 + 23 files changed, 738 insertions(+), 358 deletions(-) create mode 100644 components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/OutboundSslConfiguration.java create mode 100644 components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java create mode 100644 components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java index 879ca04be..0530f0e0f 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java @@ -172,8 +172,8 @@ public final class Constants { public static final String LOCALHOST = "localhost"; public static final String HTTP_OBJECT_AGGREGATOR = "HTTP_OBJECT_AGGREGATOR"; - public static final String WEBSOCKET_PROTOCOL = "ws"; - public static final String WEBSOCKET_PROTOCOL_SECURED = "wss"; + public static final String WS_SCHEME = "ws"; + public static final String WSS_SCHEME = "wss"; public static final String WEBSOCKET_UPGRADE = "websocket"; public static final String WEBSOCKET_FRAME_HANDLER = "WEBSOCKET_FRAME_HANDLER"; public static final String WEBSOCKET_FRAME_BLOCKING_HANDLER = "WEBSOCKET_FRAME_BLOCKING_HANDLER"; diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java index 9f5afbd90..f5e39604e 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java @@ -20,6 +20,7 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultHttpResponse; @@ -35,15 +36,22 @@ import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.HttpConversionUtil; +import io.netty.handler.ssl.ReferenceCountedOpenSslContext; +import io.netty.handler.ssl.ReferenceCountedOpenSslEngine; +import io.netty.handler.ssl.SslHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.ssl.SSLConfig; +import org.wso2.transport.http.netty.common.ssl.SSLHandlerFactory; import org.wso2.transport.http.netty.config.ChunkConfig; +import org.wso2.transport.http.netty.config.OutboundSslConfiguration; import org.wso2.transport.http.netty.config.Parameter; import org.wso2.transport.http.netty.contract.HttpResponseFuture; import org.wso2.transport.http.netty.message.DefaultListener; import org.wso2.transport.http.netty.message.HTTPCarbonMessage; import org.wso2.transport.http.netty.message.Listener; +import org.wso2.transport.http.netty.sender.CertificateValidationHandler; +import org.wso2.transport.http.netty.sender.OCSPStaplingHandler; import java.io.File; import java.io.IOException; @@ -54,6 +62,8 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; import static org.wso2.transport.http.netty.common.Constants.COLON; import static org.wso2.transport.http.netty.common.Constants.HTTP_HOST; @@ -341,7 +351,7 @@ public static SSLConfig getSSLConfigForSender(String certPass, String keyStorePa certPass = keyStorePass; } if (trustStoreFilePath == null || trustStorePass == null) { - throw new IllegalArgumentException("TrusStoreFile or trustStorePassword not defined for HTTPS scheme"); + throw new IllegalArgumentException("TrustStoreFile or trustStorePassword not defined for HTTPS/WSS scheme"); } SSLConfig sslConfig = new SSLConfig(null, null).setCertPass(null); @@ -374,6 +384,65 @@ public static SSLConfig getSSLConfigForSender(String certPass, String keyStorePa return sslConfig; } + /** + * Configure outbound HTTP pipeline for SSL configuration. + * + * @param socketChannel Socket channel of outbound connection + * @param sslConfiguration {@link OutboundSslConfiguration} + * @param host host of the connection + * @param port port of the connection + * @throws SSLException if any error occurs in the SSL connection + */ + public static void configureHttpPipelineForSSL(SocketChannel socketChannel, String host, int port, + OutboundSslConfiguration sslConfiguration) throws SSLException { + log.debug("adding ssl handler"); + SSLConfig sslConfig = sslConfiguration.generateSSLConfig(); + ChannelPipeline pipeline = socketChannel.pipeline(); + if (sslConfiguration.isOcspStaplingEnabled()) { + SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); + ReferenceCountedOpenSslContext referenceCountedOpenSslContext = sslHandlerFactory + .buildClientReferenceCountedOpenSslContext(); + + if (referenceCountedOpenSslContext != null) { + SslHandler sslHandler = referenceCountedOpenSslContext.newHandler(socketChannel.alloc()); + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + socketChannel.pipeline().addLast(sslHandler); + socketChannel.pipeline().addLast(new OCSPStaplingHandler(engine)); + } + } else { + SSLEngine sslEngine = instantiateAndConfigSSL(sslConfig, host, port, + sslConfiguration.hostNameVerificationEnabled()); + pipeline.addLast(Constants.SSL_HANDLER, new SslHandler(sslEngine)); + if (sslConfiguration.validateCertEnabled()) { + pipeline.addLast(Constants.HTTP_CERT_VALIDATION_HANDLER, + new CertificateValidationHandler(sslEngine, sslConfiguration.getCacheValidityPeriod(), + sslConfiguration.getCacheSize())); + } + } + } + + /** + * Set configurations to create ssl engine. + * + * @param sslConfig ssl related configurations + * @return ssl engine + */ + private static SSLEngine instantiateAndConfigSSL(SSLConfig sslConfig, String host, int port, + boolean hostNameVerificationEnabled) { + // set the pipeline factory, which creates the pipeline for each newly created channels + SSLEngine sslEngine = null; + if (sslConfig != null) { + SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); + sslEngine = sslHandlerFactory.buildClientSSLEngine(host, port); + sslEngine.setUseClientMode(true); + sslHandlerFactory.setSNIServerNames(sslEngine, host); + if (hostNameVerificationEnabled) { + sslHandlerFactory.setHostNameVerfication(sslEngine); + } + } + return sslEngine; + } + /** * Get integer type property value from a property map. *

diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/OutboundSslConfiguration.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/OutboundSslConfiguration.java new file mode 100644 index 000000000..3ac697e14 --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/OutboundSslConfiguration.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.config; + +import org.wso2.transport.http.netty.common.Util; +import org.wso2.transport.http.netty.common.ssl.SSLConfig; + +import java.util.ArrayList; +import java.util.List; +import javax.xml.bind.annotation.XmlAttribute; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; + +/** + * SSL configuration for Outbound HTTP connection. + */ +public class OutboundSslConfiguration { + + @XmlAttribute + private String scheme = "http"; + + @XmlAttribute + private String keyStoreFile; + + @XmlAttribute + private String keyStorePassword; + + @XmlAttribute + private String trustStoreFile; + + @XmlAttribute + private String trustStorePass; + + @XmlAttribute + private String certPass; + + @XmlAttribute + private String sslProtocol; + + @XmlElementWrapper(name = "parameters") + @XmlElement(name = "parameter") + private List parameters = new ArrayList<>(); + + private String tlsStoreType; + private boolean hostNameVerificationEnabled = true; + private boolean validateCertEnabled; + private int cacheValidityPeriod = 15; + private int cacheSize = 50; + private boolean ocspStaplingEnabled = false; + + public String getCertPass() { + return certPass; + } + + public void setCertPass(String certPass) { + this.certPass = certPass; + } + + public String getKeyStoreFile() { + return keyStoreFile; + } + + public void setKeyStoreFile(String keyStoreFile) { + this.keyStoreFile = keyStoreFile; + } + + public String getKeyStorePassword() { + return keyStorePassword; + } + + public void setKeyStorePassword(String keyStorePassword) { + this.keyStorePassword = keyStorePassword; + } + + public String getScheme() { + return scheme; + } + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public String getTrustStoreFile() { + return trustStoreFile; + } + + public void setTrustStoreFile(String trustStoreFile) { + this.trustStoreFile = trustStoreFile; + } + + public String getTrustStorePass() { + return trustStorePass; + } + + public void setTrustStorePass(String trustStorePass) { + this.trustStorePass = trustStorePass; + } + + public void setSSLProtocol(String sslProtocol) { + this.sslProtocol = sslProtocol; + } + + public String getSSLProtocol() { + return sslProtocol; + } + + public List getParameters() { + return parameters; + } + + public void setParameters(List parameters) { + this.parameters = parameters; + } + + public String getTLSStoreType() { + return tlsStoreType; + } + + public void setTLSStoreType(String storeType) { + this.tlsStoreType = storeType; + } + + public void setValidateCertEnabled(boolean validateCertEnabled) { + this.validateCertEnabled = validateCertEnabled; + } + + public boolean validateCertEnabled() { + return validateCertEnabled; + } + + public void setHostNameVerificationEnabled(boolean hostNameVerificationEnabled) { + this.hostNameVerificationEnabled = hostNameVerificationEnabled; + } + + public boolean hostNameVerificationEnabled() { + return hostNameVerificationEnabled; + } + + public void setCacheValidityPeriod(int cacheValidityPeriod) { + this.cacheValidityPeriod = cacheValidityPeriod; + } + + public int getCacheValidityPeriod() { + return cacheValidityPeriod; + } + + public void setCacheSize(int cacheSize) { + this.cacheSize = cacheSize; + } + + public int getCacheSize() { + return cacheSize; + } + + public void setOcspStaplingEnabled(boolean ocspStaplingEnabled) { + this.ocspStaplingEnabled = ocspStaplingEnabled; + } + + public boolean isOcspStaplingEnabled() { + return ocspStaplingEnabled; + } + + public SSLConfig generateSSLConfig() { + if (scheme == null || !scheme.equalsIgnoreCase("https")) { + return null; + } + return Util.getSSLConfigForSender(certPass, keyStorePassword, keyStoreFile, trustStoreFile, trustStorePass, + parameters, sslProtocol, tlsStoreType); + } +} diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java index a4766a75d..5026254b7 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java @@ -19,25 +19,18 @@ package org.wso2.transport.http.netty.config; import org.wso2.transport.http.netty.common.ProxyServerConfiguration; -import org.wso2.transport.http.netty.common.Util; -import org.wso2.transport.http.netty.common.ssl.SSLConfig; import org.wso2.transport.http.netty.sender.channel.pool.PoolConfiguration; -import java.util.ArrayList; -import java.util.List; import javax.xml.bind.annotation.XmlAccessType; import javax.xml.bind.annotation.XmlAccessorType; import javax.xml.bind.annotation.XmlAttribute; -import javax.xml.bind.annotation.XmlElement; -import javax.xml.bind.annotation.XmlElementWrapper; - /** * JAXB representation of the Netty transport sender configuration. */ @SuppressWarnings("unused") @XmlAccessorType(XmlAccessType.FIELD) -public class SenderConfiguration { +public class SenderConfiguration extends OutboundSslConfiguration { private static final String DEFAULT_KEY = "netty"; @@ -51,24 +44,6 @@ public static SenderConfiguration getDefault() { @XmlAttribute(required = true) private String id = DEFAULT_KEY; - @XmlAttribute - private String scheme = "http"; - - @XmlAttribute - private String keyStoreFile; - - @XmlAttribute - private String keyStorePassword; - - @XmlAttribute - private String trustStoreFile; - - @XmlAttribute - private String trustStorePass; - - @XmlAttribute - private String certPass; - @XmlAttribute private int socketIdleTimeout = 60000; @@ -77,28 +52,16 @@ public static SenderConfiguration getDefault() { private ChunkConfig chunkingConfig = ChunkConfig.AUTO; - @XmlAttribute - private String sslProtocol; - - @XmlElementWrapper(name = "parameters") - @XmlElement(name = "parameter") - private List parameters = new ArrayList<>(); - private KeepAliveConfig keepAliveConfig = KeepAliveConfig.AUTO; @XmlAttribute private boolean forceHttp2 = false; - private String tlsStoreType; private String httpVersion = "1.1"; private ProxyServerConfiguration proxyServerConfiguration; private PoolConfiguration poolConfiguration; - private boolean validateCertEnabled; - private int cacheSize = 50; - private int cacheValidityPeriod = 15; - private boolean hostNameVerificationEnabled = true; + private ForwardedExtensionConfig forwardedExtensionConfig; - private boolean ocspStaplingEnabled = false; public SenderConfiguration() { this.poolConfiguration = new PoolConfiguration(); @@ -109,30 +72,6 @@ public SenderConfiguration(String id) { this.poolConfiguration = new PoolConfiguration(); } - public void setSSLProtocol(String sslProtocol) { - this.sslProtocol = sslProtocol; - } - - public String getSSLProtocol() { - return sslProtocol; - } - - public String getCertPass() { - return certPass; - } - - public String getTLSStoreType() { - return tlsStoreType; - } - - public void setTLSStoreType(String storeType) { - this.tlsStoreType = storeType; - } - - public void setCertPass(String certPass) { - this.certPass = certPass; - } - public String getId() { return id; } @@ -141,62 +80,6 @@ public void setId(String id) { this.id = id; } - public String getKeyStoreFile() { - return keyStoreFile; - } - - public void setKeyStoreFile(String keyStoreFile) { - this.keyStoreFile = keyStoreFile; - } - - public String getKeyStorePassword() { - return keyStorePassword; - } - - public void setKeyStorePassword(String keyStorePassword) { - this.keyStorePassword = keyStorePassword; - } - - public String getScheme() { - return scheme; - } - - public void setScheme(String scheme) { - this.scheme = scheme; - } - - public List getParameters() { - return parameters; - } - - public void setParameters(List parameters) { - this.parameters = parameters; - } - - public String getTrustStoreFile() { - return trustStoreFile; - } - - public void setTrustStoreFile(String trustStoreFile) { - this.trustStoreFile = trustStoreFile; - } - - public String getTrustStorePass() { - return trustStorePass; - } - - public void setTrustStorePass(String trustStorePass) { - this.trustStorePass = trustStorePass; - } - - public SSLConfig getSSLConfig() { - if (scheme == null || !scheme.equalsIgnoreCase("https")) { - return null; - } - return Util.getSSLConfigForSender(certPass, keyStorePassword, keyStoreFile, trustStoreFile, trustStorePass, - parameters, sslProtocol, tlsStoreType); - } - public int getSocketIdleTimeout(int defaultValue) { if (socketIdleTimeout == 0) { return defaultValue; @@ -258,46 +141,6 @@ public void setForceHttp2(boolean forceHttp2) { this.forceHttp2 = forceHttp2; } - public void setValidateCertEnabled(boolean validateCertEnabled) { - this.validateCertEnabled = validateCertEnabled; - } - - public void setCacheSize(int cacheSize) { - this.cacheSize = cacheSize; - } - - public void setCacheValidityPeriod(int cacheValidityPeriod) { - this.cacheValidityPeriod = cacheValidityPeriod; - } - - public boolean validateCertEnabled() { - return validateCertEnabled; - } - - public int getCacheSize() { - return cacheSize; - } - - public void setHostNameVerificationEnabled(boolean hostNameVerificationEnabled) { - this.hostNameVerificationEnabled = hostNameVerificationEnabled; - } - - public boolean hostNameVerificationEnabled() { - return hostNameVerificationEnabled; - } - - public int getCacheValidityPeriod() { - return cacheValidityPeriod; - } - - public void setOcspStaplingEnabled(boolean ocspStaplingEnabled) { - this.ocspStaplingEnabled = ocspStaplingEnabled; - } - - public boolean isOcspStaplingEnabled() { - return ocspStaplingEnabled; - } - public PoolConfiguration getPoolConfiguration() { return poolConfiguration; } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java index 72197925b..f80d943f2 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java @@ -27,7 +27,7 @@ public interface ServerConnectorFuture extends HttpConnectorFuture, WebSocketConnectorFuture { /** - * Set life cycle event listener for the HTTP/WS connector + * Set life cycle event listener for the HTTP/WS_SCHEME connector * * @param portBindingEventListener The PortBindingEventListener implementation */ diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java index 6d4406ab8..ee61cd188 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java @@ -21,7 +21,10 @@ import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; +import org.wso2.transport.http.netty.common.Constants; +import org.wso2.transport.http.netty.config.OutboundSslConfiguration; +import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -29,7 +32,7 @@ /** * Configuration for WebSocket client connector. */ -public class WebSocketClientConnectorConfig { +public class WebSocketClientConnectorConfig extends OutboundSslConfiguration { private final String remoteAddress; private List subProtocols; @@ -38,16 +41,10 @@ public class WebSocketClientConnectorConfig { private final HttpHeaders headers; public WebSocketClientConnectorConfig(String remoteAddress) { - this(remoteAddress, null, -1, true); - } - - public WebSocketClientConnectorConfig(String remoteAddress, List subProtocols, - int idleTimeoutInSeconds, boolean autoRead) { this.remoteAddress = remoteAddress; - this.subProtocols = subProtocols; - this.idleTimeoutInSeconds = idleTimeoutInSeconds; - this.autoRead = autoRead; this.headers = new DefaultHttpHeaders(); + this.setScheme(Constants.WSS_SCHEME.equals(URI.create(remoteAddress).getScheme()) + ? Constants.HTTPS_SCHEME : Constants.HTTP_SCHEME); } /** diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java index 3863f89bb..cc6099980 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java @@ -20,7 +20,7 @@ package org.wso2.transport.http.netty.contract.websocket; /** - * This message contains the details of WebSocket bong message. + * This message contains the details of WebSocket control message. */ public interface WebSocketControlMessage extends WebSocketBinaryMessage { diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java index 9f3bc1502..96c0e4199 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java @@ -292,7 +292,7 @@ private void initTargetChannelProperties(SenderConfiguration senderConfiguration this.httpVersion = senderConfiguration.getHttpVersion(); this.chunkConfig = senderConfiguration.getChunkingConfig(); this.socketIdleTimeout = senderConfiguration.getSocketIdleTimeout(Constants.ENDPOINT_TIMEOUT); - this.sslConfig = senderConfiguration.getSSLConfig(); + this.sslConfig = senderConfiguration.generateSSLConfig(); this.keepAliveConfig = senderConfiguration.getKeepAliveConfig(); this.forwardedExtensionConfig = senderConfiguration.getForwardedExtensionConfig(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java index cb38c2a81..5886395b2 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java @@ -42,8 +42,7 @@ public void setClientHandshakeListener(ClientHandshakeListener clientHandshakeLi this.clientHandshakeListener = clientHandshakeListener; if (throwable != null) { clientHandshakeListener.onError(throwable, response); - } - if (webSocketConnection != null && response != null) { + } else if (webSocketConnection != null && response != null) { clientHandshakeListener.onSuccess(webSocketConnection, response); } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java index 7304c093f..594d947ee 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java @@ -20,7 +20,6 @@ package org.wso2.transport.http.netty.contractimpl.websocket; import io.netty.channel.EventLoopGroup; -import io.netty.handler.codec.http.HttpHeaders; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; @@ -31,27 +30,15 @@ */ public class DefaultWebSocketClientConnector implements WebSocketClientConnector { - private final String remoteUrl; - private final String subProtocols; - private final int idleTimeout; - private final HttpHeaders customHeaders; - private final EventLoopGroup wsClientEventLoopGroup; - private final boolean autoRead; + private final WebSocketClient webSocketClient; public DefaultWebSocketClientConnector(WebSocketClientConnectorConfig clientConnectorConfig, EventLoopGroup wsClientEventLoopGroup) { - this.remoteUrl = clientConnectorConfig.getRemoteAddress(); - this.subProtocols = clientConnectorConfig.getSubProtocolsAsCSV(); - this.customHeaders = clientConnectorConfig.getHeaders(); - this.idleTimeout = clientConnectorConfig.getIdleTimeoutInMillis(); - this.wsClientEventLoopGroup = wsClientEventLoopGroup; - this.autoRead = clientConnectorConfig.isAutoRead(); + this.webSocketClient = new WebSocketClient(wsClientEventLoopGroup, clientConnectorConfig); } @Override public ClientHandshakeFuture connect() { - WebSocketClient webSocketClient = new WebSocketClient(remoteUrl, subProtocols, idleTimeout, - wsClientEventLoopGroup, customHeaders, autoRead); return webSocketClient.handshake(); } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java index d7f22f8d2..2a564f523 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java @@ -131,7 +131,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Web @Override public void channelInactive(ChannelHandlerContext ctx) throws WebSocketConnectorException { - if (!caughtException && webSocketConnection != null && !this.isCloseFrameReceived() && closePromise == null) { + if (!caughtException && webSocketConnection != null && !closeFrameReceived && closePromise == null) { // Notify abnormal closure. DefaultWebSocketMessage webSocketCloseMessage = new DefaultWebSocketCloseMessage(Constants.WEBSOCKET_STATUS_CODE_ABNORMAL_CLOSURE); @@ -140,7 +140,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws WebSocketConnector return; } - if (closePromise != null && !closePromise.isDone()) { + if (closePromise != null && !closeFrameReceived) { String errMsg = "Connection is closed by remote endpoint without echoing a close frame"; ctx.close().addListener(closeFuture -> closePromise.setFailure(new IllegalStateException(errMsg))); } @@ -168,6 +168,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } notifyBinaryMessage(binaryFrame, binaryFrame.content(), binaryFrame.isFinalFragment()); } else if (msg instanceof CloseWebSocketFrame) { + closeFrameReceived = true; notifyCloseMessage((CloseWebSocketFrame) msg); } else if (msg instanceof PingWebSocketFrame) { notifyPingMessage((PingWebSocketFrame) msg); @@ -221,7 +222,6 @@ private void notifyCloseMessage(CloseWebSocketFrame closeWebSocketFrame) throws if (closePromise == null) { DefaultWebSocketMessage webSocketCloseMessage = new DefaultWebSocketCloseMessage(statusCode, reasonText); setupCommonProperties(webSocketCloseMessage); - closeFrameReceived = true; connectorFuture.notifyWebSocketListener((WebSocketCloseMessage) webSocketCloseMessage); } else { if (webSocketConnection.getCloseInitiatedStatusCode() != closeWebSocketFrame.statusCode()) { diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java index 6eba145bf..eeaccf851 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java @@ -198,9 +198,9 @@ private void configureFrameHandlingPipeline(int idleTimeout, WebSocketFramesBloc /* Get the URL of the given connection */ private String getWebSocketURL(HttpRequest req) { - String protocol = Constants.WEBSOCKET_PROTOCOL; + String protocol = Constants.WS_SCHEME; if (secureConnection) { - protocol = Constants.WEBSOCKET_PROTOCOL_SECURED; + protocol = Constants.WSS_SCHEME; } return protocol + "://" + req.headers().get("Host") + req.uri(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java index 41dca75d6..049ff5941 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java @@ -44,6 +44,7 @@ import org.wso2.transport.http.netty.common.FrameLogger; import org.wso2.transport.http.netty.common.HttpRoute; import org.wso2.transport.http.netty.common.ProxyServerConfiguration; +import org.wso2.transport.http.netty.common.Util; import org.wso2.transport.http.netty.common.ssl.SSLConfig; import org.wso2.transport.http.netty.common.ssl.SSLHandlerFactory; import org.wso2.transport.http.netty.config.KeepAliveConfig; @@ -55,7 +56,6 @@ import org.wso2.transport.http.netty.sender.http2.Http2ConnectionManager; import org.wso2.transport.http.netty.sender.http2.Http2TargetHandler; -import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import static io.netty.handler.logging.LogLevel.TRACE; @@ -99,7 +99,7 @@ public HttpClientChannelInitializer(SenderConfiguration senderConfiguration, Htt this.cacheSize = senderConfiguration.getCacheSize(); this.senderConfiguration = senderConfiguration; this.httpRoute = httpRoute; - this.sslConfig = senderConfiguration.getSSLConfig(); + this.sslConfig = senderConfiguration.generateSSLConfig(); this.connectionAvailabilityFuture = connectionAvailabilityFuture; String httpVersion = senderConfiguration.getHttpVersion(); @@ -129,7 +129,7 @@ protected void initChannel(SocketChannel socketChannel) throws Exception { targetHandler.setHttp2TargetHandler(http2TargetHandler); targetHandler.setKeepAliveConfig(getKeepAliveConfig()); if (http2) { - SSLConfig sslConfig = senderConfiguration.getSSLConfig(); + SSLConfig sslConfig = senderConfiguration.generateSSLConfig(); if (sslConfig != null) { configureSslForHttp2(socketChannel, clientPipeline, sslConfig); } else if (senderConfiguration.isForceHttp2()) { @@ -139,7 +139,11 @@ protected void initChannel(SocketChannel socketChannel) throws Exception { } } else { if (sslConfig != null) { - configureSslForHttp(clientPipeline, targetHandler, socketChannel); + connectionAvailabilityFuture.setSSLEnabled(true); + Util.configureHttpPipelineForSSL(socketChannel, httpRoute.getHost(), httpRoute.getPort(), + senderConfiguration); + clientPipeline.addLast(Constants.SSL_COMPLETION_HANDLER, + new SslHandshakeCompletionHandlerForClient(connectionAvailabilityFuture, this, targetHandler)); } else { configureHttpPipeline(clientPipeline, targetHandler); } @@ -162,34 +166,6 @@ private void configureProxyServer(ChannelPipeline clientPipeline) { } } - private void configureSslForHttp(ChannelPipeline clientPipeline, TargetHandler targetHandler, - SocketChannel socketChannel) - throws SSLException { - log.debug("adding ssl handler"); - connectionAvailabilityFuture.setSSLEnabled(true); - if (senderConfiguration.isOcspStaplingEnabled()) { - SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); - ReferenceCountedOpenSslContext referenceCountedOpenSslContext = sslHandlerFactory - .buildClientReferenceCountedOpenSslContext(); - - if (referenceCountedOpenSslContext != null) { - SslHandler sslHandler = referenceCountedOpenSslContext.newHandler(socketChannel.alloc()); - ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); - socketChannel.pipeline().addLast(sslHandler); - socketChannel.pipeline().addLast(new OCSPStaplingHandler(engine)); - } - } else { - SSLEngine sslEngine = instantiateAndConfigSSL(sslConfig); - clientPipeline.addLast(Constants.SSL_HANDLER, new SslHandler(sslEngine)); - if (validateCertEnabled) { - clientPipeline.addLast(Constants.HTTP_CERT_VALIDATION_HANDLER, - new CertificateValidationHandler(sslEngine, this.cacheDelay, this.cacheSize)); - } - } - clientPipeline.addLast(Constants.SSL_COMPLETION_HANDLER, - new SslHandshakeCompletionHandlerForClient(connectionAvailabilityFuture, this, targetHandler)); - } - private void configureSslForHttp2(SocketChannel ch, ChannelPipeline clientPipeline, SSLConfig sslConfig) throws SSLException { connectionAvailabilityFuture.setSSLEnabled(true); @@ -277,27 +253,6 @@ private void addCommonHandlers(ChannelPipeline pipeline) { } } - /** - * Set configurations to create ssl engine. - * - * @param sslConfig ssl related configurations - * @return ssl engine - */ - private SSLEngine instantiateAndConfigSSL(SSLConfig sslConfig) { - // set the pipeline factory, which creates the pipeline for each newly created channels - SSLEngine sslEngine = null; - if (sslConfig != null) { - SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); - sslEngine = sslHandlerFactory.buildClientSSLEngine(httpRoute.getHost(), httpRoute.getPort()); - sslEngine.setUseClientMode(true); - sslHandlerFactory.setSNIServerNames(sslEngine, httpRoute.getHost()); - if (senderConfiguration.hostNameVerificationEnabled()) { - sslHandlerFactory.setHostNameVerfication(sslEngine); - } - } - return sslEngine; - } - /** * Gets the associated {@link Http2Connection}. * diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java index fcb13fa47..18a21cb60 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java @@ -45,7 +45,7 @@ public class OCSPStaplingHandler extends OcspClientHandler { private static final Logger log = LoggerFactory.getLogger(OCSPStaplingHandler.class); - protected OCSPStaplingHandler(ReferenceCountedOpenSslEngine engine) { + public OCSPStaplingHandler(ReferenceCountedOpenSslEngine engine) { super(engine); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java index 4147bc4db..60b51b527 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java @@ -20,6 +20,8 @@ package org.wso2.transport.http.netty.sender.websocket; import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; @@ -32,21 +34,23 @@ import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.Constants; +import org.wso2.transport.http.netty.common.Util; +import org.wso2.transport.http.netty.common.ssl.SSLConfig; +import org.wso2.transport.http.netty.common.ssl.SSLHandlerFactory; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.websocket.DefaultClientHandshakeFuture; -import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketConnection; import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; import java.net.URI; import java.net.URISyntaxException; import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; /** @@ -64,23 +68,20 @@ public class WebSocketClient { private final HttpHeaders headers; private final EventLoopGroup wsClientEventLoopGroup; private final boolean autoRead; + private final WebSocketClientConnectorConfig connectorConfig; /** - * @param url url of the remote endpoint - * @param subProtocols subProtocols the negotiable sub-protocol if server is asking for it - * @param idleTimeout Idle timeout of the connection * @param wsClientEventLoopGroup of the client connector - * @param headers any specific headers which need to send to the server - * @param autoRead sets the read interest + * @param connectorConfig Connector configuration for WebSocket client. */ - public WebSocketClient(String url, String subProtocols, int idleTimeout, EventLoopGroup wsClientEventLoopGroup, - HttpHeaders headers, boolean autoRead) { - this.url = url; - this.subProtocols = subProtocols; - this.idleTimeout = idleTimeout; - this.headers = headers; + public WebSocketClient(EventLoopGroup wsClientEventLoopGroup, WebSocketClientConnectorConfig connectorConfig) { + this.url = connectorConfig.getRemoteAddress(); + this.subProtocols = connectorConfig.getSubProtocolsAsCSV(); + this.idleTimeout = connectorConfig.getIdleTimeoutInMillis(); + this.headers = connectorConfig.getHeaders(); this.wsClientEventLoopGroup = wsClientEventLoopGroup; - this.autoRead = autoRead; + this.autoRead = connectorConfig.isAutoRead(); + this.connectorConfig = connectorConfig; } /** @@ -89,71 +90,94 @@ public WebSocketClient(String url, String subProtocols, int idleTimeout, EventLo * @return handshake future for connection. */ public ClientHandshakeFuture handshake() { - ClientHandshakeFuture handshakeFuture = new DefaultClientHandshakeFuture(); + DefaultClientHandshakeFuture handshakeFuture = new DefaultClientHandshakeFuture(); try { URI uri = new URI(url); - String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); - final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); - final int port = getPort(uri); + String scheme = uri.getScheme(); if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { - log.error("Only WS(S) is supported."); - throw new URISyntaxException(url, "WebSocket client supports only WS(S) scheme"); + log.error("Only WS_SCHEME(S) is supported."); + throw new URISyntaxException(url, "WebSocket client supports only WS_SCHEME(S) scheme"); } + + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port = getPort(uri); final boolean ssl = "wss".equalsIgnoreCase(scheme); WebSocketClientHandshaker webSocketHandshaker = WebSocketClientHandshakerFactory.newHandshaker( uri, WebSocketVersion.V13, subProtocols, true, headers); WebSocketFramesBlockingHandler blockingHandler = new WebSocketFramesBlockingHandler(); - clientHandshakeHandler = new WebSocketClientHandshakeHandler(webSocketHandshaker, blockingHandler, ssl, - autoRead, url, handshakeFuture); - Bootstrap clientBootstrap = initClientBootstrap(host, port, getSslContext(ssl)); - clientBootstrap.connect(uri.getHost(), port).sync().channel(); - clientHandshakeHandler.handshakeFuture().addListener(clientHandshakeFuture -> { - Throwable cause = clientHandshakeFuture.cause(); - if (clientHandshakeFuture.isSuccess() && cause == null) { - DefaultWebSocketConnection webSocketConnection = - clientHandshakeHandler.getInboundFrameHandler().getWebSocketConnection(); - String actualSubProtocol = webSocketHandshaker.actualSubprotocol(); - webSocketConnection.getDefaultWebSocketSession().setNegotiatedSubProtocol(actualSubProtocol); - handshakeFuture.notifySuccess(webSocketConnection, clientHandshakeHandler.getHttpCarbonResponse()); - } else { - handshakeFuture.notifyError(cause, clientHandshakeHandler.getHttpCarbonResponse()); - } - }); + clientHandshakeHandler = new WebSocketClientHandshakeHandler(webSocketHandshaker, handshakeFuture, + blockingHandler, ssl, autoRead, url, handshakeFuture); + SSLEngine sslEngine = instantiateAndConfigSSL(connectorConfig.generateSSLConfig(), host, port); + Bootstrap clientBootstrap = initClientBootstrap(sslEngine, host, port, handshakeFuture); + clientBootstrap.connect(uri.getHost(), port); + return handshakeFuture; } catch (Throwable throwable) { - if (clientHandshakeHandler != null) { - handshakeFuture.notifyError(throwable, clientHandshakeHandler.getHttpCarbonResponse()); - } else { - handshakeFuture.notifyError(throwable, null); + handleHandshakeError(handshakeFuture, throwable); + return handshakeFuture; + } + } + + private void handleHandshakeError(DefaultClientHandshakeFuture handshakeFuture, Throwable throwable) { + if (clientHandshakeHandler != null) { + handshakeFuture.notifyError(throwable, clientHandshakeHandler.getHttpCarbonResponse()); + } else { + handshakeFuture.notifyError(throwable, null); + } + } + + /** + * Set configurations to create ssl engine. + * + * @param sslConfig ssl related configurations + * @return ssl engine + */ + private SSLEngine instantiateAndConfigSSL(SSLConfig sslConfig, String host, int port) { + // set the pipeline factory, which creates the pipeline for each newly created channels + SSLEngine sslEngine = null; + if (sslConfig != null) { + SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); + sslEngine = sslHandlerFactory.buildClientSSLEngine(host, port); + sslEngine.setUseClientMode(true); + sslHandlerFactory.setSNIServerNames(sslEngine, host); + if (connectorConfig.hostNameVerificationEnabled()) { + sslHandlerFactory.setHostNameVerfication(sslEngine); } } - return handshakeFuture; + return sslEngine; } - private Bootstrap initClientBootstrap(String host, int port, SslContext sslCtx) { + private Bootstrap initClientBootstrap(SSLEngine sslEngine, String host, int port, + DefaultClientHandshakeFuture handshakeFuture) { Bootstrap clientBootstrap = new Bootstrap(); clientBootstrap.group(wsClientEventLoopGroup).channel(NioSocketChannel.class).handler( new ChannelInitializer() { @Override - protected void initChannel(SocketChannel ch) { - ChannelPipeline pipeline = ch.pipeline(); - if (sslCtx != null) { - pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + protected void initChannel(SocketChannel socketChannel) throws SSLException { + if (sslEngine != null) { + Util.configureHttpPipelineForSSL(socketChannel, host, port, connectorConfig); + socketChannel.pipeline().addLast(Constants.SSL_COMPLETION_HANDLER, + new WebSocketClientSSLHandshakeCompletionHandler(handshakeFuture)); + } else { + configureHandshakePipeline(socketChannel.pipeline()); } - pipeline.addLast(new HttpClientCodec()); - // Assuming that WebSocket Handshake messages will not be large than 8KB - pipeline.addLast(new HttpObjectAggregator(8192)); - pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); - if (idleTimeout > 0) { - pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, - idleTimeout, TimeUnit.MILLISECONDS)); - } - pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); } }); return clientBootstrap; } + private void configureHandshakePipeline(ChannelPipeline pipeline) { + pipeline.addLast(new HttpClientCodec()); + // Assuming that WebSocket Handshake messages will not be large than 8KB + pipeline.addLast(new HttpObjectAggregator(8192)); + pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); + if (idleTimeout > 0) { + pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, + idleTimeout, TimeUnit.MILLISECONDS)); + } + pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); + } + private int getPort(URI uri) { String scheme = uri.getScheme(); if (uri.getPort() == -1) { @@ -170,10 +194,27 @@ private int getPort(URI uri) { } } - private SslContext getSslContext(boolean ssl) throws SSLException { - if (ssl) { - return SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + /** + * handler to identify the SSL handshake completion. + */ + private class WebSocketClientSSLHandshakeCompletionHandler extends ChannelInboundHandlerAdapter { + private final DefaultClientHandshakeFuture clientHandshakeFuture; + + private WebSocketClientSSLHandshakeCompletionHandler(DefaultClientHandshakeFuture clientHandshakeFuture) { + this.clientHandshakeFuture = clientHandshakeFuture; + } + + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; + if (event.isSuccess() && event.cause() == null) { + configureHandshakePipeline(ctx.channel().pipeline()); + ctx.pipeline().remove(Constants.SSL_COMPLETION_HANDLER); + ctx.fireChannelActive(); + } else { + clientHandshakeFuture.notifyError(event.cause(), null); + } + } } - return null; } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java index 0eaa8e17c..6e95314be 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java @@ -18,10 +18,8 @@ package org.wso2.transport.http.netty.sender.websocket; -import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; @@ -29,6 +27,8 @@ import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.Constants; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnectorFuture; +import org.wso2.transport.http.netty.contractimpl.websocket.DefaultClientHandshakeFuture; +import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketConnection; import org.wso2.transport.http.netty.contractimpl.websocket.WebSocketInboundFrameHandler; import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; import org.wso2.transport.http.netty.message.DefaultListener; @@ -46,14 +46,13 @@ public class WebSocketClientHandshakeHandler extends ChannelInboundHandlerAdapte private final boolean isSecure; private final boolean autoRead; private final String requestedUri; - private ChannelPromise handshakeFuture; + private DefaultClientHandshakeFuture handshakeFuture; private HttpCarbonResponse httpCarbonResponse; private final WebSocketConnectorFuture connectorFuture; - private WebSocketInboundFrameHandler inboundFrameHandler; public WebSocketClientHandshakeHandler(WebSocketClientHandshaker handshaker, - WebSocketFramesBlockingHandler framesBlockingHandler, boolean isSecure, - boolean autoRead, String requestedUri, WebSocketConnectorFuture connectorFuture) { + DefaultClientHandshakeFuture handshakeFuture, WebSocketFramesBlockingHandler framesBlockingHandler, + boolean isSecure, boolean autoRead, String requestedUri, WebSocketConnectorFuture connectorFuture) { this.handshaker = handshaker; this.blockingHandler = framesBlockingHandler; this.isSecure = isSecure; @@ -61,57 +60,52 @@ public WebSocketClientHandshakeHandler(WebSocketClientHandshaker handshaker, this.requestedUri = requestedUri; this.handshakeFuture = null; this.connectorFuture = connectorFuture; - } - - public ChannelFuture handshakeFuture() { - return handshakeFuture; + this.handshakeFuture = handshakeFuture; } public HttpCarbonResponse getHttpCarbonResponse() { return httpCarbonResponse; } - public WebSocketInboundFrameHandler getInboundFrameHandler() { - return this.inboundFrameHandler; - } - - @Override - public void handlerAdded(ChannelHandlerContext ctx) { - handshakeFuture = ctx.newPromise(); - } - @Override public void channelActive(ChannelHandlerContext ctx) { handshaker.handshake(ctx.channel()); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) { if (!(msg instanceof FullHttpResponse)) { throw new IllegalArgumentException("HTTP response is expected!"); } FullHttpResponse fullHttpResponse = (FullHttpResponse) msg; - httpCarbonResponse = setUpCarbonMessage(ctx, fullHttpResponse); - log.debug("WebSocket Client connected!"); - ctx.channel().config().setAutoRead(autoRead); - if (!autoRead) { - ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); + try { + httpCarbonResponse = setUpCarbonMessage(ctx, fullHttpResponse); + log.debug("WebSocket Client connected!"); + ctx.channel().config().setAutoRead(false); + if (!autoRead) { + ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); + } + WebSocketInboundFrameHandler inboundFrameHandler = new WebSocketInboundFrameHandler(connectorFuture, + blockingHandler, false, isSecure, requestedUri, null); + ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_HANDLER, inboundFrameHandler); + handshaker.finishHandshake(ctx.channel(), fullHttpResponse); + ctx.channel().pipeline().remove(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER); + ctx.fireChannelActive(); + DefaultWebSocketConnection webSocketConnection = inboundFrameHandler.getWebSocketConnection(); + String actualSubProtocol = handshaker.actualSubprotocol(); + webSocketConnection.getDefaultWebSocketSession().setNegotiatedSubProtocol(actualSubProtocol); + handshakeFuture.notifySuccess(webSocketConnection, httpCarbonResponse); + ctx.channel().config().setAutoRead(autoRead); + } finally { + fullHttpResponse.release(); } - inboundFrameHandler = - new WebSocketInboundFrameHandler(connectorFuture, blockingHandler, false, isSecure, requestedUri, null); - ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_HANDLER, inboundFrameHandler); - handshaker.finishHandshake(ctx.channel(), fullHttpResponse); - ctx.channel().pipeline().remove(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER); - ctx.fireChannelActive(); - handshakeFuture.setSuccess(); - fullHttpResponse.release(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { log.error("Caught exception", cause); - handshakeFuture.setFailure(cause); + handshakeFuture.notifyError(cause, httpCarbonResponse); } private HttpCarbonResponse setUpCarbonMessage(ChannelHandlerContext ctx, HttpResponse msg) { diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java index 093470ebf..8f7583470 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java @@ -95,6 +95,8 @@ public class TestUtil { public static final TimeUnit HTTP2_RESPONSE_TIME_UNIT = TimeUnit.SECONDS; public static final String WEBSOCKET_REMOTE_SERVER_URL = String.format("ws://%s:%d/%s", TEST_HOST, WEBSOCKET_REMOTE_SERVER_PORT, "websocket"); + public static final String WEBSOCKET_SECURE_REMOTE_SERVER_URL = + String.format("wss://%s:%d/%s", TEST_HOST, WEBSOCKET_REMOTE_SERVER_PORT, "websocket"); private static final DefaultHttpWsConnectorFactory httpConnectorFactory = new DefaultHttpWsConnectorFactory(); public static HttpServer startHTTPServer(int port, ChannelInitializer channelInitializer) { diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java index f579cbd01..676a86aac 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java @@ -29,6 +29,7 @@ import org.wso2.transport.http.netty.contract.ServerConnectorException; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketBinaryMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contract.websocket.WebSocketCloseMessage; @@ -63,6 +64,7 @@ public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + configuration.setAutoRead(true); clientConnector = httpConnectorFactory.createWsClientConnector(configuration); } @@ -129,9 +131,10 @@ public void testBinarySendAndReceive() throws Throwable { byte[] bytes = {1, 2, 3, 4, 5}; ByteBuffer bufferSent = ByteBuffer.wrap(bytes); WebSocketTestClientConnectorListener connectorListener = handshakeAndSendBinaryMessage(bufferSent); - ByteBuffer bufferReceived = connectorListener.getReceivedByteBufferToClient(); + WebSocketBinaryMessage receivedBinaryMessage = connectorListener.getReceivedBinaryMessageToClient(); - Assert.assertEquals(bufferReceived, bufferSent); + Assert.assertEquals(receivedBinaryMessage.getByteBuffer(), bufferSent); + Assert.assertEquals(receivedBinaryMessage.getByteArray(), bytes); } private WebSocketTestClientConnectorListener handshakeAndSendBinaryMessage(ByteBuffer bufferSent) diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java index 4228dc65b..5e8b04bdf 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java @@ -35,6 +35,7 @@ import org.wso2.transport.http.netty.message.HttpCarbonResponse; import org.wso2.transport.http.netty.util.server.websocket.WebSocketRemoteServer; +import java.net.URISyntaxException; import java.util.NoSuchElementException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -51,16 +52,14 @@ public class WebSocketClientHandshakeFunctionalityTestCase { private static final Logger log = LoggerFactory.getLogger(WebSocketClientHandshakeFunctionalityTestCase.class); - private DefaultHttpWsConnectorFactory httpConnectorFactory = new DefaultHttpWsConnectorFactory(); - private WebSocketClientConnector clientConnector; + private DefaultHttpWsConnectorFactory httpConnectorFactory; private WebSocketRemoteServer remoteServer; @BeforeClass public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); - WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); - clientConnector = httpConnectorFactory.createWsClientConnector(configuration); + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); } @Test(description = "Test the idle timeout for WebSocket") @@ -144,6 +143,42 @@ public void testReadNextFrame() throws Throwable { } } + @Test + public void testInvalidUrl() throws InterruptedException { + String invalidUrl = "myUrl"; + connectAndAssertInvalidUrl(invalidUrl); + + String urlWithWrongScheme = "http://localhost:9090/websocket"; + connectAndAssertInvalidUrl(urlWithWrongScheme); + } + + private void connectAndAssertInvalidUrl(String url) throws InterruptedException { + WebSocketClientConnectorConfig clientConnectorConfig = new WebSocketClientConnectorConfig(url); + HandshakeResult result = connectAndGetHandshakeResult(clientConnectorConfig); + Throwable throwable = result.getThrowable(); + + Assert.assertNull(result.getWebSocketConnection()); + Assert.assertNull(result.getHandshakeResponse()); + Assert.assertNotNull(throwable); + Assert.assertTrue(throwable instanceof URISyntaxException); + Assert.assertEquals(throwable.getMessage(), "WebSocket client supports only WS_SCHEME(S) scheme: " + url); + } + + @Test + public void testWssCallWithoutSslConfig() throws InterruptedException { + String url = "wss://localhost:9090/websocket"; + WebSocketClientConnectorConfig clientConnectorConfig = new WebSocketClientConnectorConfig(url); + HandshakeResult result = connectAndGetHandshakeResult(clientConnectorConfig); + Throwable throwable = result.getThrowable(); + + Assert.assertNull(result.getWebSocketConnection()); + Assert.assertNull(result.getHandshakeResponse()); + Assert.assertNotNull(throwable); + Assert.assertTrue(throwable instanceof IllegalArgumentException); + Assert.assertEquals(throwable.getMessage(), + "TrustStoreFile or trustStorePassword not defined for HTTPS/WSS_SCHEME scheme"); + } + private String readNextTextMsg(WebSocketTestClientConnectorListener connectorListener, WebSocketConnection webSocketConnection) throws Throwable { CountDownLatch latch = new CountDownLatch(1); @@ -166,9 +201,9 @@ private String[] sendTextMessages(WebSocketConnection webSocketConnection, int n private HandshakeResult connectAndGetHandshakeResult(WebSocketClientConnectorConfig configuration) throws InterruptedException { - clientConnector = httpConnectorFactory.createWsClientConnector(configuration); + WebSocketClientConnector clientConnector = httpConnectorFactory.createWsClientConnector(configuration); WebSocketTestClientConnectorListener connectorListener = new WebSocketTestClientConnectorListener(); - ClientHandshakeFuture handshakeFuture = handshake(connectorListener); + ClientHandshakeFuture handshakeFuture = handshake(clientConnector, connectorListener); CountDownLatch handshakeFutureLatch = new CountDownLatch(1); AtomicReference connectionAtomicReference = new AtomicReference<>(); @@ -231,7 +266,8 @@ public void cleanUp() throws ServerConnectorException, InterruptedException { httpConnectorFactory.shutdown(); } - private ClientHandshakeFuture handshake(WebSocketConnectorListener connectorListener) { + private ClientHandshakeFuture handshake(WebSocketClientConnector clientConnector, + WebSocketConnectorListener connectorListener) { ClientHandshakeFuture clientHandshakeFuture = clientConnector.connect(); clientHandshakeFuture.setWebSocketConnectorListener(connectorListener); return clientHandshakeFuture; diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java index 0e75dc709..1b0e266ac 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java @@ -42,7 +42,7 @@ public class WebSocketTestClientConnectorListener implements WebSocketConnectorL private static final Logger log = LoggerFactory.getLogger(WebSocketTestClientConnectorListener.class); private final Queue textQueue = new LinkedList<>(); - private final Queue bufferQueue = new LinkedList<>(); + private final Queue binaryMessageQueue = new LinkedList<>(); private final Queue errorsQueue = new LinkedList<>(); private static final String PING = "ping"; private WebSocketCloseMessage closeMessage = null; @@ -88,7 +88,7 @@ public void onMessage(WebSocketTextMessage textMessage) { @Override public void onMessage(WebSocketBinaryMessage binaryMessage) { - bufferQueue.add(binaryMessage.getByteBuffer()); + binaryMessageQueue.add(binaryMessage); countDownLatch(); } @@ -144,9 +144,9 @@ public String getReceivedTextToClient() throws Throwable { * * @return the latest {@link ByteBuffer} received to client. */ - public ByteBuffer getReceivedByteBufferToClient() throws Throwable { + public WebSocketBinaryMessage getReceivedBinaryMessageToClient() throws Throwable { if (errorsQueue.isEmpty()) { - return bufferQueue.remove(); + return binaryMessageQueue.remove(); } throw errorsQueue.remove(); } diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java new file mode 100644 index 000000000..7fad2489d --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.websocket.ssl; + +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import org.wso2.transport.http.netty.config.ListenerConfiguration; +import org.wso2.transport.http.netty.contract.HttpWsConnectorFactory; +import org.wso2.transport.http.netty.contract.ServerConnector; +import org.wso2.transport.http.netty.contract.ServerConnectorFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; +import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; +import org.wso2.transport.http.netty.message.HttpCarbonResponse; +import org.wso2.transport.http.netty.util.TestUtil; +import org.wso2.transport.http.netty.websocket.client.WebSocketTestClientConnectorListener; +import org.wso2.transport.http.netty.websocket.server.WebSocketTestServerConnectorListener; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_REMOTE_SERVER_PORT; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_SECURE_REMOTE_SERVER_URL; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_TEST_IDLE_TIMEOUT; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class WebSocketSSLHandshakeFailureTestCase { + + private String password = "ballerina"; + private String tlsStoreType = "PKCS12"; + private HttpWsConnectorFactory httpConnectorFactory; + private ServerConnector serverConnector; + + @BeforeClass + public void setup() throws InterruptedException { + //set PKCS12 truststore to ballerina client. + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); + + ListenerConfiguration listenerConfiguration = getListenerConfiguration(); + serverConnector = httpConnectorFactory + .createServerConnector(TestUtil.getDefaultServerBootstrapConfig(), listenerConfiguration); + ServerConnectorFuture future = serverConnector.start(); + future.setWebSocketConnectorListener(new WebSocketTestServerConnectorListener()); + future.sync(); + } + + private ListenerConfiguration getListenerConfiguration() { + ListenerConfiguration listenerConfiguration = ListenerConfiguration.getDefault(); + listenerConfiguration.setPort(WEBSOCKET_REMOTE_SERVER_PORT); + //set PKCS12 keystore to ballerina server. + String keyStoreFile = "/simple-test-config/wso2carbon.p12"; + listenerConfiguration.setKeyStoreFile(TestUtil.getAbsolutePath(keyStoreFile)); + listenerConfiguration.setScheme("https"); + listenerConfiguration.setKeyStorePass(password); + listenerConfiguration.setCertPass(password); + listenerConfiguration.setTLSStoreType(tlsStoreType); + return listenerConfiguration; + } + + private WebSocketClientConnectorConfig getWebSocketClientConnectorConfigWithSSL() { + WebSocketClientConnectorConfig clientConnectorConfig = + new WebSocketClientConnectorConfig(WEBSOCKET_SECURE_REMOTE_SERVER_URL); + String trustStoreFile = "/simple-test-config/cacerts.p12"; + clientConnectorConfig.setTrustStoreFile(TestUtil.getAbsolutePath(trustStoreFile)); + clientConnectorConfig.setTrustStorePass("cacertspassword"); + clientConnectorConfig.setTLSStoreType(tlsStoreType); + return clientConnectorConfig; + } + + @Test + public void testClientConnectionWithSSL() throws Throwable { + WebSocketClientConnector webSocketClientConnector = + httpConnectorFactory.createWsClientConnector(getWebSocketClientConnectorConfigWithSSL()); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference webSocketConnectionAtomicReference = new AtomicReference<>(); + AtomicReference throwableAtomicReference = new AtomicReference<>(); + ClientHandshakeFuture handshakeFuture = webSocketClientConnector.connect(); + WebSocketTestClientConnectorListener clientConnectorListener = new WebSocketTestClientConnectorListener(); + handshakeFuture.setWebSocketConnectorListener(clientConnectorListener); + handshakeFuture.setClientHandshakeListener(new ClientHandshakeListener() { + @Override public void onSuccess(WebSocketConnection webSocketConnection, HttpCarbonResponse response) { + webSocketConnectionAtomicReference.set(webSocketConnection); + countDownLatch.countDown(); + } + + @Override public void onError(Throwable throwable, HttpCarbonResponse response) { + throwableAtomicReference.set(throwable); + countDownLatch.countDown(); + } + }); + countDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + Throwable throwable = throwableAtomicReference.get(); + + Assert.assertNull(webSocketConnectionAtomicReference.get()); + Assert.assertNotNull(throwable); + Assert.assertEquals(throwable.getMessage(), "General SSLEngine problem"); + } + + + + @AfterClass + public void cleanup() throws InterruptedException { + serverConnector.stop(); + httpConnectorFactory.shutdown(); + } + +} diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java new file mode 100644 index 000000000..009a02e1c --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.websocket.ssl; + +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import org.wso2.transport.http.netty.config.ListenerConfiguration; +import org.wso2.transport.http.netty.contract.HttpWsConnectorFactory; +import org.wso2.transport.http.netty.contract.ServerConnector; +import org.wso2.transport.http.netty.contract.ServerConnectorFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; +import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; +import org.wso2.transport.http.netty.message.HttpCarbonResponse; +import org.wso2.transport.http.netty.util.TestUtil; +import org.wso2.transport.http.netty.websocket.client.WebSocketTestClientConnectorListener; +import org.wso2.transport.http.netty.websocket.server.WebSocketTestServerConnectorListener; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_REMOTE_SERVER_PORT; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_SECURE_REMOTE_SERVER_URL; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_TEST_IDLE_TIMEOUT; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class WebSocketSSLHandshakeSuccessfulTestCase { + + private String password = "ballerina"; + private String tlsStoreType = "PKCS12"; + private HttpWsConnectorFactory httpConnectorFactory; + private ServerConnector serverConnector; + + @BeforeClass + public void setup() throws InterruptedException { + //set PKCS12 truststore to ballerina client. + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); + + ListenerConfiguration listenerConfiguration = getListenerConfiguration(); + serverConnector = httpConnectorFactory + .createServerConnector(TestUtil.getDefaultServerBootstrapConfig(), listenerConfiguration); + ServerConnectorFuture future = serverConnector.start(); + future.setWebSocketConnectorListener(new WebSocketTestServerConnectorListener()); + future.sync(); + } + + private ListenerConfiguration getListenerConfiguration() { + ListenerConfiguration listenerConfiguration = ListenerConfiguration.getDefault(); + listenerConfiguration.setPort(WEBSOCKET_REMOTE_SERVER_PORT); + //set PKCS12 keystore to ballerina server. + String keyStoreFile = "/simple-test-config/wso2carbon.p12"; + listenerConfiguration.setKeyStoreFile(TestUtil.getAbsolutePath(keyStoreFile)); + listenerConfiguration.setScheme("https"); + listenerConfiguration.setKeyStorePass(password); + listenerConfiguration.setCertPass(password); + listenerConfiguration.setTLSStoreType(tlsStoreType); + return listenerConfiguration; + } + + private WebSocketClientConnectorConfig getWebSocketClientConnectorConfigWithSSL() { + WebSocketClientConnectorConfig senderConfiguration = + new WebSocketClientConnectorConfig(WEBSOCKET_SECURE_REMOTE_SERVER_URL); + String trustStoreFile = "/simple-test-config/client-truststore.p12"; + senderConfiguration.setTrustStoreFile(TestUtil.getAbsolutePath(trustStoreFile)); + senderConfiguration.setTrustStorePass(password); + senderConfiguration.setTLSStoreType(tlsStoreType); + return senderConfiguration; + } + + @Test + public void testClientConnectionWithSSL() throws Throwable { + WebSocketClientConnector webSocketClientConnector = + httpConnectorFactory.createWsClientConnector(getWebSocketClientConnectorConfigWithSSL()); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference webSocketConnectionAtomicReference = new AtomicReference<>(); + AtomicReference throwableAtomicReference = new AtomicReference<>(); + ClientHandshakeFuture handshakeFuture = webSocketClientConnector.connect(); + WebSocketTestClientConnectorListener clientConnectorListener = new WebSocketTestClientConnectorListener(); + handshakeFuture.setWebSocketConnectorListener(clientConnectorListener); + handshakeFuture.setClientHandshakeListener(new ClientHandshakeListener() { + @Override public void onSuccess(WebSocketConnection webSocketConnection, HttpCarbonResponse response) { + webSocketConnectionAtomicReference.set(webSocketConnection); + countDownLatch.countDown(); + } + + @Override public void onError(Throwable throwable, HttpCarbonResponse response) { + throwableAtomicReference.set(throwable); + countDownLatch.countDown(); + } + }); + countDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + WebSocketConnection webSocketConnection = webSocketConnectionAtomicReference.get(); + + Assert.assertNull(throwableAtomicReference.get()); + Assert.assertNotNull(webSocketConnection); + Assert.assertTrue(webSocketConnection.getSession().isSecure()); + + // Test whether message can be received after a successful handshake. + webSocketConnection.startReadingFrames(); + String testText = "testText"; + CountDownLatch msgCountDownLatch = new CountDownLatch(1); + clientConnectorListener.setCountDownLatch(msgCountDownLatch); + webSocketConnection.pushText(testText); + msgCountDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + + Assert.assertEquals(clientConnectorListener.getReceivedTextToClient(), testText); + } + + + + @AfterClass + public void cleanup() throws InterruptedException { + serverConnector.stop(); + httpConnectorFactory.shutdown(); + } + +} diff --git a/components/org.wso2.transport.http.netty/src/test/resources/testng.xml b/components/org.wso2.transport.http.netty/src/test/resources/testng.xml index c7282b84b..e23515d52 100644 --- a/components/org.wso2.transport.http.netty/src/test/resources/testng.xml +++ b/components/org.wso2.transport.http.netty/src/test/resources/testng.xml @@ -82,6 +82,8 @@ + +