Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.azure.core.annotation.Immutable;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.ClientOptions;
import org.apache.qpid.proton.engine.SslDomain;
import reactor.core.scheduler.Scheduler;

import java.util.Objects;
Expand All @@ -26,10 +27,13 @@ public class ConnectionOptions {
private final String fullyQualifiedNamespace;
private final CbsAuthorizationType authorizationType;
private final ClientOptions clientOptions;
private final SslDomain.VerifyMode verifyMode;

public ConnectionOptions(String fullyQualifiedNamespace, TokenCredential tokenCredential,
CbsAuthorizationType authorizationType, AmqpTransportType transport, AmqpRetryOptions retryOptions,
ProxyOptions proxyOptions, Scheduler scheduler, ClientOptions clientOptions) {
ProxyOptions proxyOptions, Scheduler scheduler, ClientOptions clientOptions,
SslDomain.VerifyMode verifyMode) {

this.fullyQualifiedNamespace = Objects.requireNonNull(fullyQualifiedNamespace,
"'fullyQualifiedNamespace' is required.");
this.tokenCredential = Objects.requireNonNull(tokenCredential, "'tokenCredential' is required.");
Expand All @@ -38,7 +42,12 @@ public ConnectionOptions(String fullyQualifiedNamespace, TokenCredential tokenCr
this.retryOptions = Objects.requireNonNull(retryOptions, "'retryOptions' is required.");
this.proxyOptions = Objects.requireNonNull(proxyOptions, "'proxyConfiguration' is required.");
this.scheduler = Objects.requireNonNull(scheduler, "'scheduler' is required.");
this.clientOptions = clientOptions;
this.clientOptions = Objects.requireNonNull(clientOptions, "'clientOptions' is required.");
this.verifyMode = Objects.requireNonNull(verifyMode, "'verifyMode' is required.");
}

public CbsAuthorizationType getAuthorizationType() {
return authorizationType;
}

public ClientOptions getClientOptions() {
Expand All @@ -49,18 +58,6 @@ public String getFullyQualifiedNamespace() {
return fullyQualifiedNamespace;
}

public TokenCredential getTokenCredential() {
return tokenCredential;
}

public CbsAuthorizationType getAuthorizationType() {
return authorizationType;
}

public AmqpTransportType getTransportType() {
return transport;
}

public AmqpRetryOptions getRetry() {
return retryOptions;
}
Expand All @@ -72,4 +69,16 @@ public ProxyOptions getProxyOptions() {
public Scheduler getScheduler() {
return scheduler;
}

public SslDomain.VerifyMode getSslVerifyMode() {
return verifyMode;
}

public TokenCredential getTokenCredential() {
return tokenCredential;
}

public AmqpTransportType getTransportType() {
return transport;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ public ReactorConnection(String connectionId, ConnectionOptions connectionOption
this.messageSerializer = messageSerializer;
this.handler = handlerProvider.createConnectionHandler(connectionId,
connectionOptions.getFullyQualifiedNamespace(), connectionOptions.getTransportType(),
connectionOptions.getProxyOptions(), product, clientVersion, connectionOptions.getClientOptions());
connectionOptions.getProxyOptions(), product, clientVersion, connectionOptions.getSslVerifyMode(),
connectionOptions.getClientOptions());

this.retryPolicy = RetryUtil.getRetryPolicy(connectionOptions.getRetry());
this.senderSettleMode = senderSettleMode;
this.receiverSettleMode = receiverSettleMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.azure.core.amqp.implementation.handler.WebSocketsProxyConnectionHandler;
import com.azure.core.util.ClientOptions;
import com.azure.core.util.logging.ClientLogger;
import org.apache.qpid.proton.engine.SslDomain;
import org.apache.qpid.proton.reactor.Reactor;

import java.time.Duration;
Expand Down Expand Up @@ -49,20 +50,20 @@ public ReactorHandlerProvider(ReactorProvider provider) {
*/
public ConnectionHandler createConnectionHandler(String connectionId, String hostname,
AmqpTransportType transportType, ProxyOptions proxyOptions, String product, String clientVersion,
ClientOptions clientOptions) {
SslDomain.VerifyMode verifyMode, ClientOptions clientOptions) {
switch (transportType) {
case AMQP:
return new ConnectionHandler(connectionId, hostname, product, clientVersion, clientOptions);
return new ConnectionHandler(connectionId, hostname, product, clientVersion, verifyMode, clientOptions);
case AMQP_WEB_SOCKETS:
if (proxyOptions != null && proxyOptions.isProxyAddressConfigured()) {
return new WebSocketsProxyConnectionHandler(connectionId, hostname, proxyOptions, product,
clientVersion, clientOptions);
clientVersion, verifyMode, clientOptions);
} else if (WebSocketsProxyConnectionHandler.shouldUseProxy(hostname)) {
logger.info("System default proxy configured for hostname '{}'. Using proxy.", hostname);
return new WebSocketsProxyConnectionHandler(connectionId, hostname,
ProxyOptions.SYSTEM_DEFAULTS, product, clientVersion, clientOptions);
ProxyOptions.SYSTEM_DEFAULTS, product, clientVersion, verifyMode, clientOptions);
} else {
return new WebSocketsConnectionHandler(connectionId, hostname, product, clientVersion,
return new WebSocketsConnectionHandler(connectionId, hostname, product, clientVersion, verifyMode,
clientOptions);
}
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.azure.core.amqp.implementation.ClientConstants;
import com.azure.core.amqp.implementation.ExceptionUtil;
import com.azure.core.util.ClientOptions;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.UserAgentUtil;
import com.azure.core.util.logging.ClientLogger;
import org.apache.qpid.proton.Proton;
Expand All @@ -16,12 +17,16 @@
import org.apache.qpid.proton.engine.EndpointState;
import org.apache.qpid.proton.engine.Event;
import org.apache.qpid.proton.engine.SslDomain;
import org.apache.qpid.proton.engine.SslPeerDetails;
import org.apache.qpid.proton.engine.Transport;
import org.apache.qpid.proton.engine.impl.TransportInternal;
import org.apache.qpid.proton.reactor.Handshaker;

import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
* Creates an AMQP connection using sockets and the default AMQP protocol port 5671.
Expand All @@ -38,6 +43,7 @@ public class ConnectionHandler extends Handler {

private final Map<String, Object> connectionProperties;
private final ClientLogger logger = new ClientLogger(ConnectionHandler.class);
private final SslDomain.VerifyMode verifyMode;

/**
* Creates a handler that handles proton-j's connection events.
Expand All @@ -48,19 +54,29 @@ public class ConnectionHandler extends Handler {
* @param clientVersion The version of the client library creating the connection handler.
* @param clientOptions provided by user.
*/
public ConnectionHandler(final String connectionId, final String hostname,
String product, String clientVersion, final ClientOptions clientOptions) {
public ConnectionHandler(final String connectionId, final String hostname, final String product,
final String clientVersion, final SslDomain.VerifyMode verifyMode, final ClientOptions clientOptions) {
super(connectionId, hostname);

add(new Handshaker());

Objects.requireNonNull(connectionId, "'connectionId' cannot be null.");
Objects.requireNonNull(hostname, "'hostname' cannot be null.");
Objects.requireNonNull(product, "'product' cannot be null.");
Objects.requireNonNull(clientVersion, "'clientVersion' cannot be null.");
Objects.requireNonNull(verifyMode, "'verifyMode' cannot be null.");
Objects.requireNonNull(clientOptions, "'clientOptions' cannot be null.");

this.verifyMode = Objects.requireNonNull(verifyMode, "'verifyMode' cannot be null");
this.connectionProperties = new HashMap<>();
this.connectionProperties.put(PRODUCT.toString(), product);
this.connectionProperties.put(VERSION.toString(), clientVersion);
this.connectionProperties.put(PLATFORM.toString(), ClientConstants.PLATFORM_INFO);
this.connectionProperties.put(FRAMEWORK.toString(), ClientConstants.FRAMEWORK_INFO);

final String applicationId = clientOptions != null ? clientOptions.getApplicationId() : null;
final String applicationId = !CoreUtils.isNullOrEmpty(clientOptions.getApplicationId())
? clientOptions.getApplicationId()
: null;
String userAgent = UserAgentUtil.toUserAgentString(applicationId, product, clientVersion, null);
this.connectionProperties.put(USER_AGENT.toString(), userAgent);
}
Expand Down Expand Up @@ -93,8 +109,44 @@ public int getMaxFrameSize() {
}

protected void addTransportLayers(final Event event, final TransportInternal transport) {
final SslDomain domain = createSslDomain(SslDomain.Mode.CLIENT);
transport.ssl(domain);
final SslDomain sslDomain = Proton.sslDomain();
sslDomain.init(SslDomain.Mode.CLIENT);

final SSLContext defaultSslContext;

if (verifyMode == SslDomain.VerifyMode.ANONYMOUS_PEER) {
defaultSslContext = null;
} else {
try {
defaultSslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw logger.logExceptionAsError(new RuntimeException(
"Default SSL algorithm not found in JRE. Please check your JRE setup.", e));
}
}

if (verifyMode == SslDomain.VerifyMode.VERIFY_PEER_NAME) {
final StrictTlsContextSpi serviceProvider = new StrictTlsContextSpi(defaultSslContext);
final SSLContext context = new StrictTlsContext(serviceProvider, defaultSslContext.getProvider(),
defaultSslContext.getProtocol());
final SslPeerDetails peerDetails = Proton.sslPeerDetails(getHostname(), getProtocolPort());

sslDomain.setSslContext(context);
transport.ssl(sslDomain, peerDetails);
return;
}

if (verifyMode == SslDomain.VerifyMode.VERIFY_PEER) {
sslDomain.setSslContext(defaultSslContext);
} else if (verifyMode == SslDomain.VerifyMode.ANONYMOUS_PEER) {
logger.warning("{} is not secure.", verifyMode);
} else {
throw logger.logExceptionAsError(new UnsupportedOperationException(
"verifyMode is not supported: " + verifyMode));
}

sslDomain.setPeerAuthentication(verifyMode);
transport.ssl(sslDomain);
}

@Override
Expand Down Expand Up @@ -240,15 +292,6 @@ public AmqpErrorContext getErrorContext() {
return new AmqpErrorContext(getHostname());
}

private static SslDomain createSslDomain(SslDomain.Mode mode) {
final SslDomain domain = Proton.sslDomain();
domain.init(mode);

// TODO: VERIFY_PEER_NAME support
domain.setPeerAuthentication(SslDomain.VerifyMode.ANONYMOUS_PEER);
return domain;
}

private void notifyErrorContext(Connection connection, ErrorCondition condition) {
if (connection == null || connection.getRemoteState() == EndpointState.CLOSED) {
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.amqp.implementation.handler;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLContextSpi;
import java.security.Provider;

/**
* Context that removes SSLv2Hello protocol from every SSLEngine created.
*/
public class StrictTlsContext extends SSLContext {
/**
* Creates an SSLContext object.
*
* @param contextSpi The service provider for SSL context.
* @param provider The security provider.
* @param protocol The SSL protocol.
*/
protected StrictTlsContext(SSLContextSpi contextSpi, Provider provider, String protocol) {
super(contextSpi, provider, protocol);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.amqp.implementation.handler;

import com.azure.core.util.logging.ClientLogger;

import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLContextSpi;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import java.security.KeyManagementException;
import java.security.SecureRandom;
import java.util.Objects;
import java.util.stream.Stream;

/**
* SSL context service provider that takes a standard SSLContext and disables the SSLv2Hello protocol.
*/
class StrictTlsContextSpi extends SSLContextSpi {
private static final String SSL_V2_HELLO = "SSLv2Hello";

private final ClientLogger logger = new ClientLogger(StrictTlsContextSpi.class);
private final SSLContext sslContext;

/**
* Creates an instance with the given SSL context.
*
* @param sslContext SSL context to use.
*
* @throws NullPointerException if {@code sslContext} is null.
*/
StrictTlsContextSpi(SSLContext sslContext) {
this.sslContext = Objects.requireNonNull(sslContext, "'sslContext' cannot be null.");
}

@Override
protected void engineInit(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom)
throws KeyManagementException {

sslContext.init(keyManagers, trustManagers, secureRandom);
}

@Override
protected SSLSocketFactory engineGetSocketFactory() {
return sslContext.getSocketFactory();
}

@Override
protected SSLServerSocketFactory engineGetServerSocketFactory() {
return sslContext.getServerSocketFactory();
}

/**
* Creates an SSLEngine from the context without SSLv2Hello protocol enabled.
*
* @return An {@code SSLEngine} object.
*/
@Override
protected SSLEngine engineCreateSSLEngine() {
final SSLEngine sslEngine = sslContext.createSSLEngine();
final String[] protocols = getAllowedProtocols(sslEngine.getEnabledProtocols());

sslEngine.setEnabledProtocols(protocols);
return sslEngine;
}

/**
* Creates an SSLEngine from the context without SSLv2Hello protocol enabled.
*
* @param host the non-authoritative name of the host
* @param port the non-authoritative port
*
* @return An {@code SSLEngine} object.
*/
@Override
protected SSLEngine engineCreateSSLEngine(String host, int port) {
final SSLEngine sslEngine = sslContext.createSSLEngine(host, port);
final String[] protocols = getAllowedProtocols(sslEngine.getEnabledProtocols());

sslEngine.setEnabledProtocols(protocols);
return sslEngine;
}

@Override
protected SSLSessionContext engineGetServerSessionContext() {
return sslContext.getServerSessionContext();
}

@Override
protected SSLSessionContext engineGetClientSessionContext() {
return sslContext.getClientSessionContext();
}

/**
* Removes {@link #SSL_V2_HELLO} protocol if it is available.
*
* @return Enabled protocols.
*/
private String[] getAllowedProtocols(String[] protocols) {
return Stream.of(protocols)
.filter(protocol -> {
final boolean isSSLv2Hello = protocol.equalsIgnoreCase(SSL_V2_HELLO);
if (isSSLv2Hello) {
logger.info("{} was an enabled protocol. Filtering out.", SSL_V2_HELLO);
}

return !isSSLv2Hello;
}).toArray(String[]::new);
}
}
Loading