Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,27 @@

package org.springframework.boot.web.embedded.netty;

import java.net.Socket;
import java.net.URL;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.UnrecoverableKeyException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.stream.Collectors;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.KeyManagerFactorySpi;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager;

import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContextBuilder;
Expand Down Expand Up @@ -92,8 +107,10 @@ else if (this.ssl.getClientAuth() == Ssl.ClientAuth.WANT) {
protected KeyManagerFactory getKeyManagerFactory(Ssl ssl, SslStoreProvider sslStoreProvider) {
try {
KeyStore keyStore = getKeyStore(ssl, sslStoreProvider);
KeyManagerFactory keyManagerFactory = KeyManagerFactory
.getInstance(KeyManagerFactory.getDefaultAlgorithm());
KeyManagerFactory keyManagerFactory = (ssl.getKeyAlias() == null)
? KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
: ConfigurableAliasKeyManagerFactory.instance(ssl.getKeyAlias(),
KeyManagerFactory.getDefaultAlgorithm());
char[] keyPassword = (ssl.getKeyPassword() != null) ? ssl.getKeyPassword().toCharArray() : null;
if (keyPassword == null && ssl.getKeyStorePassword() != null) {
keyPassword = ssl.getKeyStorePassword().toCharArray();
Expand Down Expand Up @@ -161,4 +178,120 @@ private KeyStore loadStore(String type, String provider, String resource, String

}

/**
* A {@link KeyManagerFactory} that allows a configurable key alias to be used. Due to
* the fact that the actual calls to retrieve the key by alias are done at request
* time the approach is to wrap the actual key managers with a
* {@link ConfigurableAliasKeyManager}. The actual SPI has to be wrapped as well due
* to the fact that {@link KeyManagerFactory#getKeyManagers()} is final.
*/
private static class ConfigurableAliasKeyManagerFactory extends KeyManagerFactory {

static final ConfigurableAliasKeyManagerFactory instance(String alias, String algorithm)
throws NoSuchAlgorithmException {
KeyManagerFactory originalFactory = KeyManagerFactory.getInstance(algorithm);
ConfigurableAliasKeyManagerFactorySpi spi = new ConfigurableAliasKeyManagerFactorySpi(originalFactory,
alias);
return new ConfigurableAliasKeyManagerFactory(spi, originalFactory.getProvider(), algorithm);
}

ConfigurableAliasKeyManagerFactory(ConfigurableAliasKeyManagerFactorySpi spi, Provider provider,
String algorithm) {
super(spi, provider, algorithm);
}

}

private static class ConfigurableAliasKeyManagerFactorySpi extends KeyManagerFactorySpi {

private KeyManagerFactory originalFactory;

private String alias;

ConfigurableAliasKeyManagerFactorySpi(KeyManagerFactory originalFactory, String alias) {
this.originalFactory = originalFactory;
this.alias = alias;
}

@Override
protected void engineInit(KeyStore keyStore, char[] chars)
throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
this.originalFactory.init(keyStore, chars);
}

@Override
protected void engineInit(ManagerFactoryParameters managerFactoryParameters)
throws InvalidAlgorithmParameterException {
throw new InvalidAlgorithmParameterException("Unsupported ManagerFactoryParameters");
}

@Override
protected KeyManager[] engineGetKeyManagers() {
return Arrays.stream(this.originalFactory.getKeyManagers()).filter(X509ExtendedKeyManager.class::isInstance)
.map(X509ExtendedKeyManager.class::cast).map(this::wrapKeyManager).collect(Collectors.toList())
.toArray(new KeyManager[0]);
}

private ConfigurableAliasKeyManager wrapKeyManager(X509ExtendedKeyManager km) {
return new ConfigurableAliasKeyManager(km, this.alias);
}

}

private static class ConfigurableAliasKeyManager extends X509ExtendedKeyManager {

private final X509ExtendedKeyManager keyManager;

private final String alias;

ConfigurableAliasKeyManager(X509ExtendedKeyManager keyManager, String alias) {
this.keyManager = keyManager;
this.alias = alias;
}

@Override
public String chooseEngineClientAlias(String[] strings, Principal[] principals, SSLEngine sslEngine) {
return this.keyManager.chooseEngineClientAlias(strings, principals, sslEngine);
}

@Override
public String chooseEngineServerAlias(String s, Principal[] principals, SSLEngine sslEngine) {
if (this.alias == null) {
return this.keyManager.chooseEngineServerAlias(s, principals, sslEngine);
}
return this.alias;
}

@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return this.keyManager.chooseClientAlias(keyType, issuers, socket);
}

@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return this.keyManager.chooseServerAlias(keyType, issuers, socket);
}

@Override
public X509Certificate[] getCertificateChain(String alias) {
return this.keyManager.getCertificateChain(alias);
}

@Override
public String[] getClientAliases(String keyType, Principal[] issuers) {
return this.keyManager.getClientAliases(keyType, issuers);
}

@Override
public PrivateKey getPrivateKey(String alias) {
return this.keyManager.getPrivateKey(alias);
}

@Override
public String[] getServerAliases(String keyType, Principal[] issuers) {
return this.keyManager.getServerAliases(keyType, issuers);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,25 @@

package org.springframework.boot.web.embedded.netty;

import java.time.Duration;
import java.util.Arrays;

import javax.net.ssl.SSLHandshakeException;

import org.junit.Test;
import org.mockito.InOrder;
import reactor.core.publisher.Mono;
import reactor.netty.http.server.HttpServer;
import reactor.test.StepVerifier;

import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests;
import org.springframework.boot.web.server.PortInUseException;
import org.springframework.boot.web.server.Ssl;
import org.springframework.http.MediaType;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
Expand Down Expand Up @@ -83,4 +93,38 @@ public void useForwardedHeaders() {
assertForwardHeaderIsUsed(factory);
}

@Test
public void testSslWithValidAlias() {
Mono<String> result = testSslWithAlias("test-alias");
StepVerifier.setDefaultTimeout(Duration.ofSeconds(30));
StepVerifier.create(result).expectNext("Hello World").verifyComplete();
}

@Test
public void testSslWithInvalidAlias() {
Mono<String> result = testSslWithAlias("test-alias-bad");
StepVerifier.setDefaultTimeout(Duration.ofSeconds(30));
StepVerifier.create(result).expectErrorMatches((throwable) -> throwable instanceof SSLHandshakeException
&& throwable.getMessage().contains("HANDSHAKE_FAILURE")).verify();
}

protected Mono<String> testSslWithAlias(String alias) {
String keyStore = "classpath:test.jks";
String keyPassword = "password";
NettyReactiveWebServerFactory factory = getFactory();
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
ssl.setKeyAlias(alias);
factory.setSsl(ssl);
this.webServer = factory.getWebServer(new EchoHandler());
this.webServer.start();
ReactorClientHttpConnector connector = buildTrustAllSslConnector();
WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort())
.clientConnector(connector).build();
return client.post().uri("/test").contentType(MediaType.TEXT_PLAIN)
.body(BodyInserters.fromObject("Hello World")).exchange()
.flatMap((response) -> response.bodyToMono(String.class));
}

}