diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java index 879ca04be..0530f0e0f 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java @@ -172,8 +172,8 @@ public final class Constants { public static final String LOCALHOST = "localhost"; public static final String HTTP_OBJECT_AGGREGATOR = "HTTP_OBJECT_AGGREGATOR"; - public static final String WEBSOCKET_PROTOCOL = "ws"; - public static final String WEBSOCKET_PROTOCOL_SECURED = "wss"; + public static final String WS_SCHEME = "ws"; + public static final String WSS_SCHEME = "wss"; public static final String WEBSOCKET_UPGRADE = "websocket"; public static final String WEBSOCKET_FRAME_HANDLER = "WEBSOCKET_FRAME_HANDLER"; public static final String WEBSOCKET_FRAME_BLOCKING_HANDLER = "WEBSOCKET_FRAME_BLOCKING_HANDLER"; diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java index 9f5afbd90..7fecc32c7 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Util.java @@ -20,6 +20,7 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultHttpResponse; @@ -35,15 +36,22 @@ import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.HttpConversionUtil; +import io.netty.handler.ssl.ReferenceCountedOpenSslContext; +import io.netty.handler.ssl.ReferenceCountedOpenSslEngine; +import io.netty.handler.ssl.SslHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.ssl.SSLConfig; +import org.wso2.transport.http.netty.common.ssl.SSLHandlerFactory; import org.wso2.transport.http.netty.config.ChunkConfig; import org.wso2.transport.http.netty.config.Parameter; +import org.wso2.transport.http.netty.config.SslConfiguration; import org.wso2.transport.http.netty.contract.HttpResponseFuture; import org.wso2.transport.http.netty.message.DefaultListener; import org.wso2.transport.http.netty.message.HTTPCarbonMessage; import org.wso2.transport.http.netty.message.Listener; +import org.wso2.transport.http.netty.sender.CertificateValidationHandler; +import org.wso2.transport.http.netty.sender.OCSPStaplingHandler; import java.io.File; import java.io.IOException; @@ -54,6 +62,8 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; import static org.wso2.transport.http.netty.common.Constants.COLON; import static org.wso2.transport.http.netty.common.Constants.HTTP_HOST; @@ -341,7 +351,7 @@ public static SSLConfig getSSLConfigForSender(String certPass, String keyStorePa certPass = keyStorePass; } if (trustStoreFilePath == null || trustStorePass == null) { - throw new IllegalArgumentException("TrusStoreFile or trustStorePassword not defined for HTTPS scheme"); + throw new IllegalArgumentException("TrustStoreFile or trustStorePassword not defined for HTTPS/WSS scheme"); } SSLConfig sslConfig = new SSLConfig(null, null).setCertPass(null); @@ -374,6 +384,64 @@ public static SSLConfig getSSLConfigForSender(String certPass, String keyStorePa return sslConfig; } + /** + * Configure outbound HTTP pipeline for SSL configuration. + * + * @param socketChannel Socket channel of outbound connection + * @param sslConfiguration {@link SslConfiguration} + * @param host host of the connection + * @param port port of the connection + * @throws SSLException if any error occurs in the SSL connection + */ + public static void configureHttpPipelineForSSL(SocketChannel socketChannel, String host, int port, + SslConfiguration sslConfiguration) throws SSLException { + log.debug("adding ssl handler"); + SSLConfig sslConfig = sslConfiguration.generateSSLConfig(); + ChannelPipeline pipeline = socketChannel.pipeline(); + if (sslConfiguration.isOcspStaplingEnabled()) { + SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); + ReferenceCountedOpenSslContext referenceCountedOpenSslContext = sslHandlerFactory + .buildClientReferenceCountedOpenSslContext(); + + if (referenceCountedOpenSslContext != null) { + SslHandler sslHandler = referenceCountedOpenSslContext.newHandler(socketChannel.alloc()); + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + socketChannel.pipeline().addLast(sslHandler); + socketChannel.pipeline().addLast(new OCSPStaplingHandler(engine)); + } + } else { + SSLEngine sslEngine = instantiateAndConfigSSL(sslConfig, host, port, + sslConfiguration.hostNameVerificationEnabled()); + pipeline.addLast(Constants.SSL_HANDLER, new SslHandler(sslEngine)); + if (sslConfiguration.validateCertEnabled()) { + pipeline.addLast(Constants.HTTP_CERT_VALIDATION_HANDLER, new CertificateValidationHandler( + sslEngine, sslConfiguration.getCacheValidityPeriod(), sslConfiguration.getCacheSize())); + } + } + } + + /** + * Set configurations to create ssl engine. + * + * @param sslConfig ssl related configurations + * @return ssl engine + */ + public static SSLEngine instantiateAndConfigSSL(SSLConfig sslConfig, String host, int port, + boolean hostNameVerificationEnabled) { + // set the pipeline factory, which creates the pipeline for each newly created channels + SSLEngine sslEngine = null; + if (sslConfig != null) { + SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); + sslEngine = sslHandlerFactory.buildClientSSLEngine(host, port); + sslEngine.setUseClientMode(true); + sslHandlerFactory.setSNIServerNames(sslEngine, host); + if (hostNameVerificationEnabled) { + sslHandlerFactory.setHostNameVerfication(sslEngine); + } + } + return sslEngine; + } + /** * Get integer type property value from a property map. *

diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java index a4766a75d..97aaf1bd8 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SenderConfiguration.java @@ -19,25 +19,12 @@ package org.wso2.transport.http.netty.config; import org.wso2.transport.http.netty.common.ProxyServerConfiguration; -import org.wso2.transport.http.netty.common.Util; -import org.wso2.transport.http.netty.common.ssl.SSLConfig; import org.wso2.transport.http.netty.sender.channel.pool.PoolConfiguration; -import java.util.ArrayList; -import java.util.List; -import javax.xml.bind.annotation.XmlAccessType; -import javax.xml.bind.annotation.XmlAccessorType; -import javax.xml.bind.annotation.XmlAttribute; -import javax.xml.bind.annotation.XmlElement; -import javax.xml.bind.annotation.XmlElementWrapper; - - /** * JAXB representation of the Netty transport sender configuration. */ -@SuppressWarnings("unused") -@XmlAccessorType(XmlAccessType.FIELD) -public class SenderConfiguration { +public class SenderConfiguration extends SslConfiguration { private static final String DEFAULT_KEY = "netty"; @@ -48,57 +35,17 @@ public static SenderConfiguration getDefault() { return defaultConfig; } - @XmlAttribute(required = true) private String id = DEFAULT_KEY; - - @XmlAttribute - private String scheme = "http"; - - @XmlAttribute - private String keyStoreFile; - - @XmlAttribute - private String keyStorePassword; - - @XmlAttribute - private String trustStoreFile; - - @XmlAttribute - private String trustStorePass; - - @XmlAttribute - private String certPass; - - @XmlAttribute private int socketIdleTimeout = 60000; - - @XmlAttribute private boolean httpTraceLogEnabled; - private ChunkConfig chunkingConfig = ChunkConfig.AUTO; - - @XmlAttribute - private String sslProtocol; - - @XmlElementWrapper(name = "parameters") - @XmlElement(name = "parameter") - private List parameters = new ArrayList<>(); - private KeepAliveConfig keepAliveConfig = KeepAliveConfig.AUTO; - - @XmlAttribute private boolean forceHttp2 = false; - - private String tlsStoreType; private String httpVersion = "1.1"; private ProxyServerConfiguration proxyServerConfiguration; private PoolConfiguration poolConfiguration; - private boolean validateCertEnabled; - private int cacheSize = 50; - private int cacheValidityPeriod = 15; - private boolean hostNameVerificationEnabled = true; + private ForwardedExtensionConfig forwardedExtensionConfig; - private boolean ocspStaplingEnabled = false; public SenderConfiguration() { this.poolConfiguration = new PoolConfiguration(); @@ -109,30 +56,6 @@ public SenderConfiguration(String id) { this.poolConfiguration = new PoolConfiguration(); } - public void setSSLProtocol(String sslProtocol) { - this.sslProtocol = sslProtocol; - } - - public String getSSLProtocol() { - return sslProtocol; - } - - public String getCertPass() { - return certPass; - } - - public String getTLSStoreType() { - return tlsStoreType; - } - - public void setTLSStoreType(String storeType) { - this.tlsStoreType = storeType; - } - - public void setCertPass(String certPass) { - this.certPass = certPass; - } - public String getId() { return id; } @@ -141,62 +64,6 @@ public void setId(String id) { this.id = id; } - public String getKeyStoreFile() { - return keyStoreFile; - } - - public void setKeyStoreFile(String keyStoreFile) { - this.keyStoreFile = keyStoreFile; - } - - public String getKeyStorePassword() { - return keyStorePassword; - } - - public void setKeyStorePassword(String keyStorePassword) { - this.keyStorePassword = keyStorePassword; - } - - public String getScheme() { - return scheme; - } - - public void setScheme(String scheme) { - this.scheme = scheme; - } - - public List getParameters() { - return parameters; - } - - public void setParameters(List parameters) { - this.parameters = parameters; - } - - public String getTrustStoreFile() { - return trustStoreFile; - } - - public void setTrustStoreFile(String trustStoreFile) { - this.trustStoreFile = trustStoreFile; - } - - public String getTrustStorePass() { - return trustStorePass; - } - - public void setTrustStorePass(String trustStorePass) { - this.trustStorePass = trustStorePass; - } - - public SSLConfig getSSLConfig() { - if (scheme == null || !scheme.equalsIgnoreCase("https")) { - return null; - } - return Util.getSSLConfigForSender(certPass, keyStorePassword, keyStoreFile, trustStoreFile, trustStorePass, - parameters, sslProtocol, tlsStoreType); - } - public int getSocketIdleTimeout(int defaultValue) { if (socketIdleTimeout == 0) { return defaultValue; @@ -258,46 +125,6 @@ public void setForceHttp2(boolean forceHttp2) { this.forceHttp2 = forceHttp2; } - public void setValidateCertEnabled(boolean validateCertEnabled) { - this.validateCertEnabled = validateCertEnabled; - } - - public void setCacheSize(int cacheSize) { - this.cacheSize = cacheSize; - } - - public void setCacheValidityPeriod(int cacheValidityPeriod) { - this.cacheValidityPeriod = cacheValidityPeriod; - } - - public boolean validateCertEnabled() { - return validateCertEnabled; - } - - public int getCacheSize() { - return cacheSize; - } - - public void setHostNameVerificationEnabled(boolean hostNameVerificationEnabled) { - this.hostNameVerificationEnabled = hostNameVerificationEnabled; - } - - public boolean hostNameVerificationEnabled() { - return hostNameVerificationEnabled; - } - - public int getCacheValidityPeriod() { - return cacheValidityPeriod; - } - - public void setOcspStaplingEnabled(boolean ocspStaplingEnabled) { - this.ocspStaplingEnabled = ocspStaplingEnabled; - } - - public boolean isOcspStaplingEnabled() { - return ocspStaplingEnabled; - } - public PoolConfiguration getPoolConfiguration() { return poolConfiguration; } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SslConfiguration.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SslConfiguration.java new file mode 100644 index 000000000..4ed7f3bee --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/config/SslConfiguration.java @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.config; + +import org.wso2.transport.http.netty.common.Util; +import org.wso2.transport.http.netty.common.ssl.SSLConfig; + +import java.util.ArrayList; +import java.util.List; + +/** + * SSL configuration for HTTP connection. + */ +public class SslConfiguration { + + private String scheme = "http"; + private String keyStoreFile; + private String keyStorePassword; + private String trustStoreFile; + private String trustStorePass; + private String certPass; + private String sslProtocol; + private List parameters = new ArrayList<>(); + private String tlsStoreType; + private boolean hostNameVerificationEnabled = true; + private boolean validateCertEnabled; + private int cacheValidityPeriod = 15; + private int cacheSize = 50; + private boolean ocspStaplingEnabled = false; + + public String getCertPass() { + return certPass; + } + + public void setCertPass(String certPass) { + this.certPass = certPass; + } + + public String getKeyStoreFile() { + return keyStoreFile; + } + + public void setKeyStoreFile(String keyStoreFile) { + this.keyStoreFile = keyStoreFile; + } + + public String getKeyStorePassword() { + return keyStorePassword; + } + + public void setKeyStorePassword(String keyStorePassword) { + this.keyStorePassword = keyStorePassword; + } + + public String getScheme() { + return scheme; + } + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public String getTrustStoreFile() { + return trustStoreFile; + } + + public void setTrustStoreFile(String trustStoreFile) { + this.trustStoreFile = trustStoreFile; + } + + public String getTrustStorePass() { + return trustStorePass; + } + + public void setTrustStorePass(String trustStorePass) { + this.trustStorePass = trustStorePass; + } + + public void setSSLProtocol(String sslProtocol) { + this.sslProtocol = sslProtocol; + } + + public String getSSLProtocol() { + return sslProtocol; + } + + public List getParameters() { + return parameters; + } + + public void setParameters(List parameters) { + this.parameters = parameters; + } + + public String getTLSStoreType() { + return tlsStoreType; + } + + public void setTLSStoreType(String storeType) { + this.tlsStoreType = storeType; + } + + public void setValidateCertEnabled(boolean validateCertEnabled) { + this.validateCertEnabled = validateCertEnabled; + } + + public boolean validateCertEnabled() { + return validateCertEnabled; + } + + public void setHostNameVerificationEnabled(boolean hostNameVerificationEnabled) { + this.hostNameVerificationEnabled = hostNameVerificationEnabled; + } + + public boolean hostNameVerificationEnabled() { + return hostNameVerificationEnabled; + } + + public void setCacheValidityPeriod(int cacheValidityPeriod) { + this.cacheValidityPeriod = cacheValidityPeriod; + } + + public int getCacheValidityPeriod() { + return cacheValidityPeriod; + } + + public void setCacheSize(int cacheSize) { + this.cacheSize = cacheSize; + } + + public int getCacheSize() { + return cacheSize; + } + + public void setOcspStaplingEnabled(boolean ocspStaplingEnabled) { + this.ocspStaplingEnabled = ocspStaplingEnabled; + } + + public boolean isOcspStaplingEnabled() { + return ocspStaplingEnabled; + } + + public SSLConfig generateSSLConfig() { + if (scheme == null || !scheme.equalsIgnoreCase("https")) { + return null; + } + return Util.getSSLConfigForSender(certPass, keyStorePassword, keyStoreFile, trustStoreFile, trustStorePass, + parameters, sslProtocol, tlsStoreType); + } +} diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java index 72197925b..f80d943f2 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/ServerConnectorFuture.java @@ -27,7 +27,7 @@ public interface ServerConnectorFuture extends HttpConnectorFuture, WebSocketConnectorFuture { /** - * Set life cycle event listener for the HTTP/WS connector + * Set life cycle event listener for the HTTP/WS_SCHEME connector * * @param portBindingEventListener The PortBindingEventListener implementation */ diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java index 6d4406ab8..503279191 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java @@ -21,7 +21,10 @@ import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; +import org.wso2.transport.http.netty.common.Constants; +import org.wso2.transport.http.netty.config.SslConfiguration; +import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -29,7 +32,7 @@ /** * Configuration for WebSocket client connector. */ -public class WebSocketClientConnectorConfig { +public class WebSocketClientConnectorConfig extends SslConfiguration { private final String remoteAddress; private List subProtocols; @@ -38,16 +41,10 @@ public class WebSocketClientConnectorConfig { private final HttpHeaders headers; public WebSocketClientConnectorConfig(String remoteAddress) { - this(remoteAddress, null, -1, true); - } - - public WebSocketClientConnectorConfig(String remoteAddress, List subProtocols, - int idleTimeoutInSeconds, boolean autoRead) { this.remoteAddress = remoteAddress; - this.subProtocols = subProtocols; - this.idleTimeoutInSeconds = idleTimeoutInSeconds; - this.autoRead = autoRead; this.headers = new DefaultHttpHeaders(); + this.setScheme(Constants.WSS_SCHEME.equals(URI.create(remoteAddress).getScheme()) + ? Constants.HTTPS_SCHEME : Constants.HTTP_SCHEME); } /** diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java index 3863f89bb..cc6099980 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java @@ -20,7 +20,7 @@ package org.wso2.transport.http.netty.contract.websocket; /** - * This message contains the details of WebSocket bong message. + * This message contains the details of WebSocket control message. */ public interface WebSocketControlMessage extends WebSocketBinaryMessage { diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java index 9f3bc1502..96c0e4199 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpClientConnector.java @@ -292,7 +292,7 @@ private void initTargetChannelProperties(SenderConfiguration senderConfiguration this.httpVersion = senderConfiguration.getHttpVersion(); this.chunkConfig = senderConfiguration.getChunkingConfig(); this.socketIdleTimeout = senderConfiguration.getSocketIdleTimeout(Constants.ENDPOINT_TIMEOUT); - this.sslConfig = senderConfiguration.getSSLConfig(); + this.sslConfig = senderConfiguration.generateSSLConfig(); this.keepAliveConfig = senderConfiguration.getKeepAliveConfig(); this.forwardedExtensionConfig = senderConfiguration.getForwardedExtensionConfig(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java index cb38c2a81..5886395b2 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultClientHandshakeFuture.java @@ -42,8 +42,7 @@ public void setClientHandshakeListener(ClientHandshakeListener clientHandshakeLi this.clientHandshakeListener = clientHandshakeListener; if (throwable != null) { clientHandshakeListener.onError(throwable, response); - } - if (webSocketConnection != null && response != null) { + } else if (webSocketConnection != null && response != null) { clientHandshakeListener.onSuccess(webSocketConnection, response); } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java index 7304c093f..594d947ee 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java @@ -20,7 +20,6 @@ package org.wso2.transport.http.netty.contractimpl.websocket; import io.netty.channel.EventLoopGroup; -import io.netty.handler.codec.http.HttpHeaders; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; @@ -31,27 +30,15 @@ */ public class DefaultWebSocketClientConnector implements WebSocketClientConnector { - private final String remoteUrl; - private final String subProtocols; - private final int idleTimeout; - private final HttpHeaders customHeaders; - private final EventLoopGroup wsClientEventLoopGroup; - private final boolean autoRead; + private final WebSocketClient webSocketClient; public DefaultWebSocketClientConnector(WebSocketClientConnectorConfig clientConnectorConfig, EventLoopGroup wsClientEventLoopGroup) { - this.remoteUrl = clientConnectorConfig.getRemoteAddress(); - this.subProtocols = clientConnectorConfig.getSubProtocolsAsCSV(); - this.customHeaders = clientConnectorConfig.getHeaders(); - this.idleTimeout = clientConnectorConfig.getIdleTimeoutInMillis(); - this.wsClientEventLoopGroup = wsClientEventLoopGroup; - this.autoRead = clientConnectorConfig.isAutoRead(); + this.webSocketClient = new WebSocketClient(wsClientEventLoopGroup, clientConnectorConfig); } @Override public ClientHandshakeFuture connect() { - WebSocketClient webSocketClient = new WebSocketClient(remoteUrl, subProtocols, idleTimeout, - wsClientEventLoopGroup, customHeaders, autoRead); return webSocketClient.handshake(); } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketConnection.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketConnection.java index 19ec58a84..271e54b45 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketConnection.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketConnection.java @@ -15,7 +15,7 @@ import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; import org.wso2.transport.http.netty.contract.websocket.WebSocketFrameType; import org.wso2.transport.http.netty.internal.websocket.DefaultWebSocketSession; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import java.nio.ByteBuffer; import javax.websocket.Session; @@ -28,13 +28,13 @@ public class DefaultWebSocketConnection implements WebSocketConnection { private final WebSocketInboundFrameHandler frameHandler; private final ChannelHandlerContext ctx; private final DefaultWebSocketSession session; - private WebSocketFramesBlockingHandler blockingHandler; + private WebSocketFramesQueueHandler blockingHandler; private WebSocketFrameType continuationFrameType; private boolean closeFrameSent; private int closeInitiatedStatusCode; public DefaultWebSocketConnection(ChannelHandlerContext ctx, WebSocketInboundFrameHandler frameHandler, - WebSocketFramesBlockingHandler blockingHandler, + WebSocketFramesQueueHandler blockingHandler, DefaultWebSocketSession session) { this.ctx = ctx; this.frameHandler = frameHandler; @@ -67,7 +67,7 @@ public void startReadingFrames() { public void stopReadingFrames() { ctx.channel().config().setAutoRead(false); ctx.pipeline().addBefore(Constants.WEBSOCKET_FRAME_HANDLER, Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, - blockingHandler = new WebSocketFramesBlockingHandler()); + blockingHandler = new WebSocketFramesQueueHandler()); } @Override diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java index 58fe62c82..f76105b48 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java @@ -47,7 +47,7 @@ import org.wso2.transport.http.netty.contractimpl.websocket.message.DefaultWebSocketControlMessage; import org.wso2.transport.http.netty.exception.UnknownWebSocketFrameTypeException; import org.wso2.transport.http.netty.internal.websocket.WebSocketUtil; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import java.net.InetSocketAddress; @@ -63,17 +63,17 @@ public class WebSocketInboundFrameHandler extends ChannelInboundHandlerAdapter { private final boolean securedConnection; private final String target; private final String interfaceId; - private final WebSocketFramesBlockingHandler blockingHandler; private DefaultWebSocketConnection webSocketConnection; private ChannelHandlerContext ctx; private ChannelPromise closePromise; private WebSocketFrameType continuationFrameType; + private final WebSocketFramesQueueHandler blockingHandler; private boolean caughtException; private boolean closeFrameReceived; private boolean closeInitialized; public WebSocketInboundFrameHandler(WebSocketConnectorFuture connectorFuture, - WebSocketFramesBlockingHandler blockingHandler, boolean isServer, + WebSocketFramesQueueHandler blockingHandler, boolean isServer, boolean securedConnection, String target, String interfaceId) { this.connectorFuture = connectorFuture; this.blockingHandler = blockingHandler; @@ -147,7 +147,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws WebSocketConnector return; } - if (closePromise != null && !closePromise.isDone()) { + if (closePromise != null && !closeFrameReceived) { String errMsg = "Connection is closed by remote endpoint without echoing a close frame"; ctx.close().addListener(closeFuture -> closePromise.setFailure(new IllegalStateException(errMsg))); } @@ -175,6 +175,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } notifyBinaryMessage(binaryFrame, binaryFrame.content(), binaryFrame.isFinalFragment()); } else if (msg instanceof CloseWebSocketFrame) { + closeFrameReceived = true; notifyCloseMessage((CloseWebSocketFrame) msg); } else if (msg instanceof PingWebSocketFrame) { notifyPingMessage((PingWebSocketFrame) msg); @@ -228,7 +229,6 @@ private void notifyCloseMessage(CloseWebSocketFrame closeWebSocketFrame) throws if (closePromise == null) { DefaultWebSocketMessage webSocketCloseMessage = new DefaultWebSocketCloseMessage(statusCode, reasonText); setupCommonProperties(webSocketCloseMessage); - closeFrameReceived = true; connectorFuture.notifyWebSocketListener((WebSocketCloseMessage) webSocketCloseMessage); } else { if (webSocketConnection.getCloseInitiatedStatusCode() != closeWebSocketFrame.statusCode()) { diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java index 6eba145bf..2a2ce3fb7 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java @@ -42,7 +42,7 @@ import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketMessage; import org.wso2.transport.http.netty.contractimpl.websocket.WebSocketInboundFrameHandler; import org.wso2.transport.http.netty.internal.websocket.WebSocketUtil; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import org.wso2.transport.http.netty.message.HttpCarbonRequest; import java.nio.charset.StandardCharsets; @@ -165,7 +165,7 @@ private ServerHandshakeFuture handleHandshake(WebSocketServerHandshaker handshak channelFuture.addListener(future -> { if (future.isSuccess() && future.cause() == null) { String selectedSubProtocol = handshaker.selectedSubprotocol(); - WebSocketFramesBlockingHandler blockingHandler = new WebSocketFramesBlockingHandler(); + WebSocketFramesQueueHandler blockingHandler = new WebSocketFramesQueueHandler(); WebSocketInboundFrameHandler frameHandler = new WebSocketInboundFrameHandler(connectorFuture, blockingHandler, true, secureConnection, target, listenerInterface); configureFrameHandlingPipeline(idleTimeout, blockingHandler, frameHandler); @@ -180,7 +180,7 @@ private ServerHandshakeFuture handleHandshake(WebSocketServerHandshaker handshak return handshakeFuture; } - private void configureFrameHandlingPipeline(int idleTimeout, WebSocketFramesBlockingHandler blockingHandler, + private void configureFrameHandlingPipeline(int idleTimeout, WebSocketFramesQueueHandler blockingHandler, WebSocketInboundFrameHandler frameHandler) { ChannelPipeline pipeline = ctx.pipeline(); if (idleTimeout > 0) { @@ -198,9 +198,9 @@ private void configureFrameHandlingPipeline(int idleTimeout, WebSocketFramesBloc /* Get the URL of the given connection */ private String getWebSocketURL(HttpRequest req) { - String protocol = Constants.WEBSOCKET_PROTOCOL; + String protocol = Constants.WS_SCHEME; if (secureConnection) { - protocol = Constants.WEBSOCKET_PROTOCOL_SECURED; + protocol = Constants.WSS_SCHEME; } return protocol + "://" + req.headers().get("Host") + req.uri(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/internal/websocket/WebSocketUtil.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/internal/websocket/WebSocketUtil.java index b9e0c3b9b..c414eb997 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/internal/websocket/WebSocketUtil.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/internal/websocket/WebSocketUtil.java @@ -30,7 +30,7 @@ import org.wso2.transport.http.netty.contractimpl.websocket.message.DefaultWebSocketBinaryMessage; import org.wso2.transport.http.netty.contractimpl.websocket.message.DefaultWebSocketControlMessage; import org.wso2.transport.http.netty.contractimpl.websocket.message.DefaultWebSocketTextMessage; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import java.net.URISyntaxException; import java.nio.ByteBuffer; @@ -46,7 +46,7 @@ public static String getSessionID(ChannelHandlerContext ctx) { public static DefaultWebSocketConnection getWebSocketConnection(ChannelHandlerContext ctx, WebSocketInboundFrameHandler frameHandler, - WebSocketFramesBlockingHandler blockingHandler, + WebSocketFramesQueueHandler blockingHandler, boolean isSecured, String uri) throws URISyntaxException { DefaultWebSocketSession session = new DefaultWebSocketSession(ctx, isSecured, uri, getSessionID(ctx)); diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesBlockingHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesQueueHandler.java similarity index 93% rename from components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesBlockingHandler.java rename to components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesQueueHandler.java index 05ef3b2de..760d2f852 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesBlockingHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketFramesQueueHandler.java @@ -28,13 +28,13 @@ * This Handler is responsible for issuing frame by frame when the WebSocket connection is asked to read next frame * when autoRead is set to false. */ -public class WebSocketFramesBlockingHandler extends ChannelInboundHandlerAdapter { +public class WebSocketFramesQueueHandler extends ChannelInboundHandlerAdapter { private final Queue frameCollectorQueue; private ChannelHandlerContext ctx; private boolean readNext; - public WebSocketFramesBlockingHandler() { + public WebSocketFramesQueueHandler() { this.frameCollectorQueue = new ConcurrentLinkedQueue<>(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java index 41dca75d6..049ff5941 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/HttpClientChannelInitializer.java @@ -44,6 +44,7 @@ import org.wso2.transport.http.netty.common.FrameLogger; import org.wso2.transport.http.netty.common.HttpRoute; import org.wso2.transport.http.netty.common.ProxyServerConfiguration; +import org.wso2.transport.http.netty.common.Util; import org.wso2.transport.http.netty.common.ssl.SSLConfig; import org.wso2.transport.http.netty.common.ssl.SSLHandlerFactory; import org.wso2.transport.http.netty.config.KeepAliveConfig; @@ -55,7 +56,6 @@ import org.wso2.transport.http.netty.sender.http2.Http2ConnectionManager; import org.wso2.transport.http.netty.sender.http2.Http2TargetHandler; -import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import static io.netty.handler.logging.LogLevel.TRACE; @@ -99,7 +99,7 @@ public HttpClientChannelInitializer(SenderConfiguration senderConfiguration, Htt this.cacheSize = senderConfiguration.getCacheSize(); this.senderConfiguration = senderConfiguration; this.httpRoute = httpRoute; - this.sslConfig = senderConfiguration.getSSLConfig(); + this.sslConfig = senderConfiguration.generateSSLConfig(); this.connectionAvailabilityFuture = connectionAvailabilityFuture; String httpVersion = senderConfiguration.getHttpVersion(); @@ -129,7 +129,7 @@ protected void initChannel(SocketChannel socketChannel) throws Exception { targetHandler.setHttp2TargetHandler(http2TargetHandler); targetHandler.setKeepAliveConfig(getKeepAliveConfig()); if (http2) { - SSLConfig sslConfig = senderConfiguration.getSSLConfig(); + SSLConfig sslConfig = senderConfiguration.generateSSLConfig(); if (sslConfig != null) { configureSslForHttp2(socketChannel, clientPipeline, sslConfig); } else if (senderConfiguration.isForceHttp2()) { @@ -139,7 +139,11 @@ protected void initChannel(SocketChannel socketChannel) throws Exception { } } else { if (sslConfig != null) { - configureSslForHttp(clientPipeline, targetHandler, socketChannel); + connectionAvailabilityFuture.setSSLEnabled(true); + Util.configureHttpPipelineForSSL(socketChannel, httpRoute.getHost(), httpRoute.getPort(), + senderConfiguration); + clientPipeline.addLast(Constants.SSL_COMPLETION_HANDLER, + new SslHandshakeCompletionHandlerForClient(connectionAvailabilityFuture, this, targetHandler)); } else { configureHttpPipeline(clientPipeline, targetHandler); } @@ -162,34 +166,6 @@ private void configureProxyServer(ChannelPipeline clientPipeline) { } } - private void configureSslForHttp(ChannelPipeline clientPipeline, TargetHandler targetHandler, - SocketChannel socketChannel) - throws SSLException { - log.debug("adding ssl handler"); - connectionAvailabilityFuture.setSSLEnabled(true); - if (senderConfiguration.isOcspStaplingEnabled()) { - SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); - ReferenceCountedOpenSslContext referenceCountedOpenSslContext = sslHandlerFactory - .buildClientReferenceCountedOpenSslContext(); - - if (referenceCountedOpenSslContext != null) { - SslHandler sslHandler = referenceCountedOpenSslContext.newHandler(socketChannel.alloc()); - ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); - socketChannel.pipeline().addLast(sslHandler); - socketChannel.pipeline().addLast(new OCSPStaplingHandler(engine)); - } - } else { - SSLEngine sslEngine = instantiateAndConfigSSL(sslConfig); - clientPipeline.addLast(Constants.SSL_HANDLER, new SslHandler(sslEngine)); - if (validateCertEnabled) { - clientPipeline.addLast(Constants.HTTP_CERT_VALIDATION_HANDLER, - new CertificateValidationHandler(sslEngine, this.cacheDelay, this.cacheSize)); - } - } - clientPipeline.addLast(Constants.SSL_COMPLETION_HANDLER, - new SslHandshakeCompletionHandlerForClient(connectionAvailabilityFuture, this, targetHandler)); - } - private void configureSslForHttp2(SocketChannel ch, ChannelPipeline clientPipeline, SSLConfig sslConfig) throws SSLException { connectionAvailabilityFuture.setSSLEnabled(true); @@ -277,27 +253,6 @@ private void addCommonHandlers(ChannelPipeline pipeline) { } } - /** - * Set configurations to create ssl engine. - * - * @param sslConfig ssl related configurations - * @return ssl engine - */ - private SSLEngine instantiateAndConfigSSL(SSLConfig sslConfig) { - // set the pipeline factory, which creates the pipeline for each newly created channels - SSLEngine sslEngine = null; - if (sslConfig != null) { - SSLHandlerFactory sslHandlerFactory = new SSLHandlerFactory(sslConfig); - sslEngine = sslHandlerFactory.buildClientSSLEngine(httpRoute.getHost(), httpRoute.getPort()); - sslEngine.setUseClientMode(true); - sslHandlerFactory.setSNIServerNames(sslEngine, httpRoute.getHost()); - if (senderConfiguration.hostNameVerificationEnabled()) { - sslHandlerFactory.setHostNameVerfication(sslEngine); - } - } - return sslEngine; - } - /** * Gets the associated {@link Http2Connection}. * diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java index fcb13fa47..18a21cb60 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/OCSPStaplingHandler.java @@ -45,7 +45,7 @@ public class OCSPStaplingHandler extends OcspClientHandler { private static final Logger log = LoggerFactory.getLogger(OCSPStaplingHandler.class); - protected OCSPStaplingHandler(ReferenceCountedOpenSslEngine engine) { + public OCSPStaplingHandler(ReferenceCountedOpenSslEngine engine) { super(engine); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java index 4147bc4db..2e85238c4 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java @@ -20,6 +20,8 @@ package org.wso2.transport.http.netty.sender.websocket; import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; @@ -32,21 +34,21 @@ import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.Constants; +import org.wso2.transport.http.netty.common.Util; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.websocket.DefaultClientHandshakeFuture; -import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketConnection; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import java.net.URI; import java.net.URISyntaxException; import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; /** @@ -64,23 +66,20 @@ public class WebSocketClient { private final HttpHeaders headers; private final EventLoopGroup wsClientEventLoopGroup; private final boolean autoRead; + private final WebSocketClientConnectorConfig connectorConfig; /** - * @param url url of the remote endpoint - * @param subProtocols subProtocols the negotiable sub-protocol if server is asking for it - * @param idleTimeout Idle timeout of the connection * @param wsClientEventLoopGroup of the client connector - * @param headers any specific headers which need to send to the server - * @param autoRead sets the read interest + * @param connectorConfig Connector configuration for WebSocket client. */ - public WebSocketClient(String url, String subProtocols, int idleTimeout, EventLoopGroup wsClientEventLoopGroup, - HttpHeaders headers, boolean autoRead) { - this.url = url; - this.subProtocols = subProtocols; - this.idleTimeout = idleTimeout; - this.headers = headers; + public WebSocketClient(EventLoopGroup wsClientEventLoopGroup, WebSocketClientConnectorConfig connectorConfig) { + this.url = connectorConfig.getRemoteAddress(); + this.subProtocols = connectorConfig.getSubProtocolsAsCSV(); + this.idleTimeout = connectorConfig.getIdleTimeoutInMillis(); + this.headers = connectorConfig.getHeaders(); this.wsClientEventLoopGroup = wsClientEventLoopGroup; - this.autoRead = autoRead; + this.autoRead = connectorConfig.isAutoRead(); + this.connectorConfig = connectorConfig; } /** @@ -89,91 +88,109 @@ public WebSocketClient(String url, String subProtocols, int idleTimeout, EventLo * @return handshake future for connection. */ public ClientHandshakeFuture handshake() { - ClientHandshakeFuture handshakeFuture = new DefaultClientHandshakeFuture(); + final DefaultClientHandshakeFuture handshakeFuture = new DefaultClientHandshakeFuture(); try { URI uri = new URI(url); - String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); - final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); - final int port = getPort(uri); - if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { + String scheme = uri.getScheme(); + if (!Constants.WS_SCHEME.equalsIgnoreCase(scheme) && !Constants.WSS_SCHEME.equalsIgnoreCase(scheme)) { log.error("Only WS(S) is supported."); throw new URISyntaxException(url, "WebSocket client supports only WS(S) scheme"); } - final boolean ssl = "wss".equalsIgnoreCase(scheme); + + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port = getPort(uri); + final boolean ssl = Constants.WSS_SCHEME.equalsIgnoreCase(scheme); WebSocketClientHandshaker webSocketHandshaker = WebSocketClientHandshakerFactory.newHandshaker( uri, WebSocketVersion.V13, subProtocols, true, headers); - WebSocketFramesBlockingHandler blockingHandler = new WebSocketFramesBlockingHandler(); - clientHandshakeHandler = new WebSocketClientHandshakeHandler(webSocketHandshaker, blockingHandler, ssl, - autoRead, url, handshakeFuture); - Bootstrap clientBootstrap = initClientBootstrap(host, port, getSslContext(ssl)); - clientBootstrap.connect(uri.getHost(), port).sync().channel(); - clientHandshakeHandler.handshakeFuture().addListener(clientHandshakeFuture -> { - Throwable cause = clientHandshakeFuture.cause(); - if (clientHandshakeFuture.isSuccess() && cause == null) { - DefaultWebSocketConnection webSocketConnection = - clientHandshakeHandler.getInboundFrameHandler().getWebSocketConnection(); - String actualSubProtocol = webSocketHandshaker.actualSubprotocol(); - webSocketConnection.getDefaultWebSocketSession().setNegotiatedSubProtocol(actualSubProtocol); - handshakeFuture.notifySuccess(webSocketConnection, clientHandshakeHandler.getHttpCarbonResponse()); - } else { - handshakeFuture.notifyError(cause, clientHandshakeHandler.getHttpCarbonResponse()); - } - }); + WebSocketFramesQueueHandler blockingHandler = new WebSocketFramesQueueHandler(); + clientHandshakeHandler = new WebSocketClientHandshakeHandler(webSocketHandshaker, handshakeFuture, + blockingHandler, ssl, autoRead, url, handshakeFuture); + Bootstrap clientBootstrap = initClientBootstrap(host, port, handshakeFuture); + clientBootstrap.connect(uri.getHost(), port); } catch (Throwable throwable) { - if (clientHandshakeHandler != null) { - handshakeFuture.notifyError(throwable, clientHandshakeHandler.getHttpCarbonResponse()); - } else { - handshakeFuture.notifyError(throwable, null); - } + handleHandshakeError(handshakeFuture, throwable); } return handshakeFuture; } - private Bootstrap initClientBootstrap(String host, int port, SslContext sslCtx) { + private void handleHandshakeError(DefaultClientHandshakeFuture handshakeFuture, Throwable throwable) { + if (clientHandshakeHandler != null) { + handshakeFuture.notifyError(throwable, clientHandshakeHandler.getHttpCarbonResponse()); + } else { + handshakeFuture.notifyError(throwable, null); + } + } + + private Bootstrap initClientBootstrap(String host, int port, DefaultClientHandshakeFuture handshakeFuture) { Bootstrap clientBootstrap = new Bootstrap(); + SSLEngine sslEngine = Util.instantiateAndConfigSSL(connectorConfig.generateSSLConfig(), host, port, + connectorConfig.hostNameVerificationEnabled()); clientBootstrap.group(wsClientEventLoopGroup).channel(NioSocketChannel.class).handler( new ChannelInitializer() { @Override - protected void initChannel(SocketChannel ch) { - ChannelPipeline pipeline = ch.pipeline(); - if (sslCtx != null) { - pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + protected void initChannel(SocketChannel socketChannel) throws SSLException { + if (sslEngine != null) { + Util.configureHttpPipelineForSSL(socketChannel, host, port, connectorConfig); + socketChannel.pipeline().addLast(Constants.SSL_COMPLETION_HANDLER, + new WebSocketClientSSLHandshakeCompletionHandler(handshakeFuture)); + } else { + configureHandshakePipeline(socketChannel.pipeline()); } - pipeline.addLast(new HttpClientCodec()); - // Assuming that WebSocket Handshake messages will not be large than 8KB - pipeline.addLast(new HttpObjectAggregator(8192)); - pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); - if (idleTimeout > 0) { - pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, - idleTimeout, TimeUnit.MILLISECONDS)); - } - pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); } }); return clientBootstrap; } + private void configureHandshakePipeline(ChannelPipeline pipeline) { + pipeline.addLast(new HttpClientCodec()); + // Assuming that WebSocket Handshake messages will not be large than 8KB + pipeline.addLast(new HttpObjectAggregator(8192)); // TODO: Use constant if has + pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); + if (idleTimeout > 0) { + pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, + idleTimeout, TimeUnit.MILLISECONDS)); + } + pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); + } + private int getPort(URI uri) { String scheme = uri.getScheme(); if (uri.getPort() == -1) { switch (scheme) { - case "ws": - return 80; - case "wss": - return 443; - default: - return -1; + case "ws": // TODO: Constants + return 80; + case "wss": + return 443; + default: + return -1; } } else { return uri.getPort(); } } - private SslContext getSslContext(boolean ssl) throws SSLException { - if (ssl) { - return SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + /** + * handler to identify the SSL handshake completion. + */ + private class WebSocketClientSSLHandshakeCompletionHandler extends ChannelInboundHandlerAdapter { + private final DefaultClientHandshakeFuture clientHandshakeFuture; + + private WebSocketClientSSLHandshakeCompletionHandler(DefaultClientHandshakeFuture clientHandshakeFuture) { + this.clientHandshakeFuture = clientHandshakeFuture; + } + + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; + if (event.isSuccess() && event.cause() == null) { + configureHandshakePipeline(ctx.channel().pipeline()); + ctx.pipeline().remove(Constants.SSL_COMPLETION_HANDLER); + ctx.fireChannelActive(); + } else { + clientHandshakeFuture.notifyError(event.cause(), null); + } + } } - return null; } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java index 0eaa8e17c..a69dbd058 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java @@ -18,10 +18,9 @@ package org.wso2.transport.http.netty.sender.websocket; -import io.netty.channel.ChannelFuture; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; @@ -29,8 +28,10 @@ import org.slf4j.LoggerFactory; import org.wso2.transport.http.netty.common.Constants; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnectorFuture; +import org.wso2.transport.http.netty.contractimpl.websocket.DefaultClientHandshakeFuture; +import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketConnection; import org.wso2.transport.http.netty.contractimpl.websocket.WebSocketInboundFrameHandler; -import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; +import org.wso2.transport.http.netty.listener.WebSocketFramesQueueHandler; import org.wso2.transport.http.netty.message.DefaultListener; import org.wso2.transport.http.netty.message.HttpCarbonResponse; @@ -42,76 +43,70 @@ public class WebSocketClientHandshakeHandler extends ChannelInboundHandlerAdapte private static final Logger log = LoggerFactory.getLogger(WebSocketClientHandshakeHandler.class); private final WebSocketClientHandshaker handshaker; - private final WebSocketFramesBlockingHandler blockingHandler; + private final WebSocketFramesQueueHandler blockingHandler; private final boolean isSecure; private final boolean autoRead; private final String requestedUri; - private ChannelPromise handshakeFuture; - private HttpCarbonResponse httpCarbonResponse; + private final DefaultClientHandshakeFuture handshakeFuture; private final WebSocketConnectorFuture connectorFuture; - private WebSocketInboundFrameHandler inboundFrameHandler; + private HttpCarbonResponse httpCarbonResponse; public WebSocketClientHandshakeHandler(WebSocketClientHandshaker handshaker, - WebSocketFramesBlockingHandler framesBlockingHandler, boolean isSecure, - boolean autoRead, String requestedUri, WebSocketConnectorFuture connectorFuture) { + DefaultClientHandshakeFuture handshakeFuture, WebSocketFramesQueueHandler framesBlockingHandler, + boolean isSecure, boolean autoRead, String requestedUri, WebSocketConnectorFuture connectorFuture) { this.handshaker = handshaker; this.blockingHandler = framesBlockingHandler; this.isSecure = isSecure; this.autoRead = autoRead; this.requestedUri = requestedUri; - this.handshakeFuture = null; this.connectorFuture = connectorFuture; - } - - public ChannelFuture handshakeFuture() { - return handshakeFuture; + this.handshakeFuture = handshakeFuture; } public HttpCarbonResponse getHttpCarbonResponse() { return httpCarbonResponse; } - public WebSocketInboundFrameHandler getInboundFrameHandler() { - return this.inboundFrameHandler; - } - - @Override - public void handlerAdded(ChannelHandlerContext ctx) { - handshakeFuture = ctx.newPromise(); - } - @Override public void channelActive(ChannelHandlerContext ctx) { handshaker.handshake(ctx.channel()); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) { if (!(msg instanceof FullHttpResponse)) { - throw new IllegalArgumentException("HTTP response is expected!"); + throw new IllegalArgumentException("HTTP response is expected"); } - FullHttpResponse fullHttpResponse = (FullHttpResponse) msg; - httpCarbonResponse = setUpCarbonMessage(ctx, fullHttpResponse); - log.debug("WebSocket Client connected!"); - ctx.channel().config().setAutoRead(autoRead); - if (!autoRead) { - ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); + FullHttpResponse handshakeResponse = (FullHttpResponse) msg; + httpCarbonResponse = setUpCarbonMessage(ctx, handshakeResponse); + try { + ctx.channel().config().setAutoRead(false); + handshaker.finishHandshake(ctx.channel(), handshakeResponse); + Channel channel = ctx.channel(); + if (!autoRead) { + channel.pipeline().addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); + } + WebSocketInboundFrameHandler inboundFrameHandler = new WebSocketInboundFrameHandler(connectorFuture, + blockingHandler, false, isSecure, requestedUri, null); + channel.pipeline().addLast(Constants.WEBSOCKET_FRAME_HANDLER, inboundFrameHandler); + channel.pipeline().remove(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER); + ctx.fireChannelActive(); + DefaultWebSocketConnection webSocketConnection = inboundFrameHandler.getWebSocketConnection(); + String actualSubProtocol = handshaker.actualSubprotocol(); + webSocketConnection.getDefaultWebSocketSession().setNegotiatedSubProtocol(actualSubProtocol); + handshakeFuture.notifySuccess(webSocketConnection, httpCarbonResponse); + channel.config().setAutoRead(autoRead); + log.debug("WebSocket Client connected"); + } finally { + handshakeResponse.release(); } - inboundFrameHandler = - new WebSocketInboundFrameHandler(connectorFuture, blockingHandler, false, isSecure, requestedUri, null); - ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_HANDLER, inboundFrameHandler); - handshaker.finishHandshake(ctx.channel(), fullHttpResponse); - ctx.channel().pipeline().remove(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER); - ctx.fireChannelActive(); - handshakeFuture.setSuccess(); - fullHttpResponse.release(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { log.error("Caught exception", cause); - handshakeFuture.setFailure(cause); + handshakeFuture.notifyError(cause, httpCarbonResponse); } private HttpCarbonResponse setUpCarbonMessage(ChannelHandlerContext ctx, HttpResponse msg) { diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/pkcs/PKCSTest.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/pkcs/PKCSTest.java index a967e883f..19ff68870 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/pkcs/PKCSTest.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/pkcs/PKCSTest.java @@ -61,8 +61,6 @@ public class PKCSTest { @BeforeClass public void setup() throws InterruptedException { - - //set PKCS12 truststore to ballerina client. httpConnectorFactory = new DefaultHttpWsConnectorFactory(); ListenerConfiguration listenerConfiguration = getListenerConfiguration(); diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java index 093470ebf..8f7583470 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/util/TestUtil.java @@ -95,6 +95,8 @@ public class TestUtil { public static final TimeUnit HTTP2_RESPONSE_TIME_UNIT = TimeUnit.SECONDS; public static final String WEBSOCKET_REMOTE_SERVER_URL = String.format("ws://%s:%d/%s", TEST_HOST, WEBSOCKET_REMOTE_SERVER_PORT, "websocket"); + public static final String WEBSOCKET_SECURE_REMOTE_SERVER_URL = + String.format("wss://%s:%d/%s", TEST_HOST, WEBSOCKET_REMOTE_SERVER_PORT, "websocket"); private static final DefaultHttpWsConnectorFactory httpConnectorFactory = new DefaultHttpWsConnectorFactory(); public static HttpServer startHTTPServer(int port, ChannelInitializer channelInitializer) { diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java index 3e5746498..22bf47fa0 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java @@ -29,6 +29,7 @@ import org.wso2.transport.http.netty.contract.ServerConnectorException; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketBinaryMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contract.websocket.WebSocketCloseMessage; @@ -63,6 +64,7 @@ public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + configuration.setAutoRead(true); clientConnector = httpConnectorFactory.createWsClientConnector(configuration); } @@ -129,9 +131,10 @@ public void testBinarySendAndReceive() throws Throwable { byte[] bytes = {1, 2, 3, 4, 5}; ByteBuffer bufferSent = ByteBuffer.wrap(bytes); WebSocketTestClientConnectorListener connectorListener = handshakeAndSendBinaryMessage(bufferSent); - ByteBuffer bufferReceived = connectorListener.getReceivedByteBufferToClient(); + WebSocketBinaryMessage receivedBinaryMessage = connectorListener.getReceivedBinaryMessageToClient(); - Assert.assertEquals(bufferReceived, bufferSent); + Assert.assertEquals(receivedBinaryMessage.getByteBuffer(), bufferSent); + Assert.assertEquals(receivedBinaryMessage.getByteArray(), bytes); } private WebSocketTestClientConnectorListener handshakeAndSendBinaryMessage(ByteBuffer bufferSent) diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java index 4228dc65b..744732342 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java @@ -35,6 +35,7 @@ import org.wso2.transport.http.netty.message.HttpCarbonResponse; import org.wso2.transport.http.netty.util.server.websocket.WebSocketRemoteServer; +import java.net.URISyntaxException; import java.util.NoSuchElementException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -51,16 +52,14 @@ public class WebSocketClientHandshakeFunctionalityTestCase { private static final Logger log = LoggerFactory.getLogger(WebSocketClientHandshakeFunctionalityTestCase.class); - private DefaultHttpWsConnectorFactory httpConnectorFactory = new DefaultHttpWsConnectorFactory(); - private WebSocketClientConnector clientConnector; + private DefaultHttpWsConnectorFactory httpConnectorFactory; private WebSocketRemoteServer remoteServer; @BeforeClass public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); - WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); - clientConnector = httpConnectorFactory.createWsClientConnector(configuration); + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); } @Test(description = "Test the idle timeout for WebSocket") @@ -144,6 +143,42 @@ public void testReadNextFrame() throws Throwable { } } + @Test + public void testInvalidUrl() throws InterruptedException { + String invalidUrl = "myUrl"; + connectAndAssertInvalidUrl(invalidUrl); + + String urlWithWrongScheme = "http://localhost:9090/websocket"; + connectAndAssertInvalidUrl(urlWithWrongScheme); + } + + private void connectAndAssertInvalidUrl(String url) throws InterruptedException { + WebSocketClientConnectorConfig clientConnectorConfig = new WebSocketClientConnectorConfig(url); + HandshakeResult result = connectAndGetHandshakeResult(clientConnectorConfig); + Throwable throwable = result.getThrowable(); + + Assert.assertNull(result.getWebSocketConnection()); + Assert.assertNull(result.getHandshakeResponse()); + Assert.assertNotNull(throwable); + Assert.assertTrue(throwable instanceof URISyntaxException); + Assert.assertEquals(throwable.getMessage(), "WebSocket client supports only WS(S) scheme: " + url); + } + + @Test + public void testWssCallWithoutSslConfig() throws InterruptedException { + String url = "wss://localhost:9090/websocket"; + WebSocketClientConnectorConfig clientConnectorConfig = new WebSocketClientConnectorConfig(url); + HandshakeResult result = connectAndGetHandshakeResult(clientConnectorConfig); + Throwable throwable = result.getThrowable(); + + Assert.assertNull(result.getWebSocketConnection()); + Assert.assertNull(result.getHandshakeResponse()); + Assert.assertNotNull(throwable); + Assert.assertTrue(throwable instanceof IllegalArgumentException); + Assert.assertEquals(throwable.getMessage(), + "TrustStoreFile or trustStorePassword not defined for HTTPS/WSS scheme"); + } + private String readNextTextMsg(WebSocketTestClientConnectorListener connectorListener, WebSocketConnection webSocketConnection) throws Throwable { CountDownLatch latch = new CountDownLatch(1); @@ -166,9 +201,9 @@ private String[] sendTextMessages(WebSocketConnection webSocketConnection, int n private HandshakeResult connectAndGetHandshakeResult(WebSocketClientConnectorConfig configuration) throws InterruptedException { - clientConnector = httpConnectorFactory.createWsClientConnector(configuration); + WebSocketClientConnector clientConnector = httpConnectorFactory.createWsClientConnector(configuration); WebSocketTestClientConnectorListener connectorListener = new WebSocketTestClientConnectorListener(); - ClientHandshakeFuture handshakeFuture = handshake(connectorListener); + ClientHandshakeFuture handshakeFuture = handshake(clientConnector, connectorListener); CountDownLatch handshakeFutureLatch = new CountDownLatch(1); AtomicReference connectionAtomicReference = new AtomicReference<>(); @@ -231,7 +266,8 @@ public void cleanUp() throws ServerConnectorException, InterruptedException { httpConnectorFactory.shutdown(); } - private ClientHandshakeFuture handshake(WebSocketConnectorListener connectorListener) { + private ClientHandshakeFuture handshake(WebSocketClientConnector clientConnector, + WebSocketConnectorListener connectorListener) { ClientHandshakeFuture clientHandshakeFuture = clientConnector.connect(); clientHandshakeFuture.setWebSocketConnectorListener(connectorListener); return clientHandshakeFuture; diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java index 0e75dc709..1b0e266ac 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketTestClientConnectorListener.java @@ -42,7 +42,7 @@ public class WebSocketTestClientConnectorListener implements WebSocketConnectorL private static final Logger log = LoggerFactory.getLogger(WebSocketTestClientConnectorListener.class); private final Queue textQueue = new LinkedList<>(); - private final Queue bufferQueue = new LinkedList<>(); + private final Queue binaryMessageQueue = new LinkedList<>(); private final Queue errorsQueue = new LinkedList<>(); private static final String PING = "ping"; private WebSocketCloseMessage closeMessage = null; @@ -88,7 +88,7 @@ public void onMessage(WebSocketTextMessage textMessage) { @Override public void onMessage(WebSocketBinaryMessage binaryMessage) { - bufferQueue.add(binaryMessage.getByteBuffer()); + binaryMessageQueue.add(binaryMessage); countDownLatch(); } @@ -144,9 +144,9 @@ public String getReceivedTextToClient() throws Throwable { * * @return the latest {@link ByteBuffer} received to client. */ - public ByteBuffer getReceivedByteBufferToClient() throws Throwable { + public WebSocketBinaryMessage getReceivedBinaryMessageToClient() throws Throwable { if (errorsQueue.isEmpty()) { - return bufferQueue.remove(); + return binaryMessageQueue.remove(); } throw errorsQueue.remove(); } diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java new file mode 100644 index 000000000..b08417686 --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeFailureTestCase.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.websocket.ssl; + +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import org.wso2.transport.http.netty.config.ListenerConfiguration; +import org.wso2.transport.http.netty.contract.HttpWsConnectorFactory; +import org.wso2.transport.http.netty.contract.ServerConnector; +import org.wso2.transport.http.netty.contract.ServerConnectorFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; +import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; +import org.wso2.transport.http.netty.message.HttpCarbonResponse; +import org.wso2.transport.http.netty.util.TestUtil; +import org.wso2.transport.http.netty.websocket.client.WebSocketTestClientConnectorListener; +import org.wso2.transport.http.netty.websocket.server.WebSocketTestServerConnectorListener; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_REMOTE_SERVER_PORT; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_SECURE_REMOTE_SERVER_URL; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_TEST_IDLE_TIMEOUT; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class WebSocketSSLHandshakeFailureTestCase { + + private String password = "ballerina"; + private String tlsStoreType = "PKCS12"; + private HttpWsConnectorFactory httpConnectorFactory; + private ServerConnector serverConnector; + + @BeforeClass + public void setup() throws InterruptedException { + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); + + ListenerConfiguration listenerConfiguration = getListenerConfiguration(); + serverConnector = httpConnectorFactory + .createServerConnector(TestUtil.getDefaultServerBootstrapConfig(), listenerConfiguration); + ServerConnectorFuture future = serverConnector.start(); + future.setWebSocketConnectorListener(new WebSocketTestServerConnectorListener()); + future.sync(); + } + + private ListenerConfiguration getListenerConfiguration() { + ListenerConfiguration listenerConfiguration = ListenerConfiguration.getDefault(); + listenerConfiguration.setPort(WEBSOCKET_REMOTE_SERVER_PORT); + //set PKCS12 keystore to ballerina server. + String keyStoreFile = "/simple-test-config/wso2carbon.p12"; + listenerConfiguration.setKeyStoreFile(TestUtil.getAbsolutePath(keyStoreFile)); + listenerConfiguration.setScheme("https"); + listenerConfiguration.setKeyStorePass(password); + listenerConfiguration.setCertPass(password); + listenerConfiguration.setTLSStoreType(tlsStoreType); + return listenerConfiguration; + } + + private WebSocketClientConnectorConfig getWebSocketClientConnectorConfigWithSSL() { + WebSocketClientConnectorConfig clientConnectorConfig = + new WebSocketClientConnectorConfig(WEBSOCKET_SECURE_REMOTE_SERVER_URL); + String trustStoreFile = "/simple-test-config/cacerts.p12"; + clientConnectorConfig.setTrustStoreFile(TestUtil.getAbsolutePath(trustStoreFile)); + clientConnectorConfig.setTrustStorePass("cacertspassword"); + clientConnectorConfig.setTLSStoreType(tlsStoreType); + return clientConnectorConfig; + } + + @Test + public void testClientConnectionWithSSL() throws Throwable { + WebSocketClientConnector webSocketClientConnector = + httpConnectorFactory.createWsClientConnector(getWebSocketClientConnectorConfigWithSSL()); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference webSocketConnectionAtomicReference = new AtomicReference<>(); + AtomicReference throwableAtomicReference = new AtomicReference<>(); + ClientHandshakeFuture handshakeFuture = webSocketClientConnector.connect(); + WebSocketTestClientConnectorListener clientConnectorListener = new WebSocketTestClientConnectorListener(); + handshakeFuture.setWebSocketConnectorListener(clientConnectorListener); + handshakeFuture.setClientHandshakeListener(new ClientHandshakeListener() { + @Override public void onSuccess(WebSocketConnection webSocketConnection, HttpCarbonResponse response) { + webSocketConnectionAtomicReference.set(webSocketConnection); + countDownLatch.countDown(); + } + + @Override public void onError(Throwable throwable, HttpCarbonResponse response) { + throwableAtomicReference.set(throwable); + countDownLatch.countDown(); + } + }); + countDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + Throwable throwable = throwableAtomicReference.get(); + + Assert.assertNull(webSocketConnectionAtomicReference.get()); + Assert.assertNotNull(throwable); + Assert.assertEquals(throwable.getMessage(), "General SSLEngine problem"); + } + + + + @AfterClass + public void cleanup() throws InterruptedException { + serverConnector.stop(); + httpConnectorFactory.shutdown(); + } + +} diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java new file mode 100644 index 000000000..7e877afca --- /dev/null +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/ssl/WebSocketSSLHandshakeSuccessfulTestCase.java @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2018, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. + * + * WSO2 Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.transport.http.netty.websocket.ssl; + +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import org.wso2.transport.http.netty.config.ListenerConfiguration; +import org.wso2.transport.http.netty.contract.HttpWsConnectorFactory; +import org.wso2.transport.http.netty.contract.ServerConnector; +import org.wso2.transport.http.netty.contract.ServerConnectorFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; +import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; +import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; +import org.wso2.transport.http.netty.message.HttpCarbonResponse; +import org.wso2.transport.http.netty.util.TestUtil; +import org.wso2.transport.http.netty.websocket.client.WebSocketTestClientConnectorListener; +import org.wso2.transport.http.netty.websocket.server.WebSocketTestServerConnectorListener; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_REMOTE_SERVER_PORT; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_SECURE_REMOTE_SERVER_URL; +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_TEST_IDLE_TIMEOUT; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class WebSocketSSLHandshakeSuccessfulTestCase { + + private String password = "ballerina"; + private String tlsStoreType = "PKCS12"; + private HttpWsConnectorFactory httpConnectorFactory; + private ServerConnector serverConnector; + + @BeforeClass + public void setup() throws InterruptedException { + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); + + ListenerConfiguration listenerConfiguration = getListenerConfiguration(); + serverConnector = httpConnectorFactory + .createServerConnector(TestUtil.getDefaultServerBootstrapConfig(), listenerConfiguration); + ServerConnectorFuture future = serverConnector.start(); + future.setWebSocketConnectorListener(new WebSocketTestServerConnectorListener()); + future.sync(); + } + + private ListenerConfiguration getListenerConfiguration() { + ListenerConfiguration listenerConfiguration = ListenerConfiguration.getDefault(); + listenerConfiguration.setPort(WEBSOCKET_REMOTE_SERVER_PORT); + String keyStoreFile = "/simple-test-config/wso2carbon.p12"; + listenerConfiguration.setKeyStoreFile(TestUtil.getAbsolutePath(keyStoreFile)); + listenerConfiguration.setScheme("https"); + listenerConfiguration.setKeyStorePass(password); + listenerConfiguration.setCertPass(password); + listenerConfiguration.setTLSStoreType(tlsStoreType); + return listenerConfiguration; + } + + private WebSocketClientConnectorConfig getWebSocketClientConnectorConfigWithSSL() { + WebSocketClientConnectorConfig senderConfiguration = + new WebSocketClientConnectorConfig(WEBSOCKET_SECURE_REMOTE_SERVER_URL); + String trustStoreFile = "/simple-test-config/client-truststore.p12"; + senderConfiguration.setTrustStoreFile(TestUtil.getAbsolutePath(trustStoreFile)); + senderConfiguration.setTrustStorePass(password); + senderConfiguration.setTLSStoreType(tlsStoreType); + return senderConfiguration; + } + + @Test + public void testClientConnectionWithSSL() throws Throwable { + WebSocketClientConnector webSocketClientConnector = + httpConnectorFactory.createWsClientConnector(getWebSocketClientConnectorConfigWithSSL()); + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference webSocketConnectionAtomicReference = new AtomicReference<>(); + AtomicReference throwableAtomicReference = new AtomicReference<>(); + ClientHandshakeFuture handshakeFuture = webSocketClientConnector.connect(); + WebSocketTestClientConnectorListener clientConnectorListener = new WebSocketTestClientConnectorListener(); + handshakeFuture.setWebSocketConnectorListener(clientConnectorListener); + handshakeFuture.setClientHandshakeListener(new ClientHandshakeListener() { + @Override public void onSuccess(WebSocketConnection webSocketConnection, HttpCarbonResponse response) { + webSocketConnectionAtomicReference.set(webSocketConnection); + countDownLatch.countDown(); + } + + @Override public void onError(Throwable throwable, HttpCarbonResponse response) { + throwableAtomicReference.set(throwable); + countDownLatch.countDown(); + } + }); + countDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + WebSocketConnection webSocketConnection = webSocketConnectionAtomicReference.get(); + + Assert.assertNull(throwableAtomicReference.get()); + Assert.assertNotNull(webSocketConnection); + Assert.assertTrue(webSocketConnection.getSession().isSecure()); + + // Test whether message can be received after a successful handshake. + webSocketConnection.startReadingFrames(); + String testText = "testText"; + CountDownLatch msgCountDownLatch = new CountDownLatch(1); + clientConnectorListener.setCountDownLatch(msgCountDownLatch); + webSocketConnection.pushText(testText); + msgCountDownLatch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); + + Assert.assertEquals(clientConnectorListener.getReceivedTextToClient(), testText); + } + + + + @AfterClass + public void cleanup() throws InterruptedException { + serverConnector.stop(); + httpConnectorFactory.shutdown(); + } + +} diff --git a/components/org.wso2.transport.http.netty/src/test/resources/testng.xml b/components/org.wso2.transport.http.netty/src/test/resources/testng.xml index c7282b84b..e23515d52 100644 --- a/components/org.wso2.transport.http.netty/src/test/resources/testng.xml +++ b/components/org.wso2.transport.http.netty/src/test/resources/testng.xml @@ -82,6 +82,8 @@ + +