diff --git a/hadoop-hdds/server-scm/src/main/java/org/apache/hadoop/hdds/scm/ha/InterSCMGrpcProtocolService.java b/hadoop-hdds/server-scm/src/main/java/org/apache/hadoop/hdds/scm/ha/InterSCMGrpcProtocolService.java
index 7197bb1cd601..92f28d07e973 100644
--- a/hadoop-hdds/server-scm/src/main/java/org/apache/hadoop/hdds/scm/ha/InterSCMGrpcProtocolService.java
+++ b/hadoop-hdds/server-scm/src/main/java/org/apache/hadoop/hdds/scm/ha/InterSCMGrpcProtocolService.java
@@ -26,16 +26,20 @@
import org.apache.hadoop.hdds.scm.ScmConfigKeys;
import org.apache.hadoop.hdds.scm.server.StorageContainerManager;
import org.apache.hadoop.hdds.security.SecurityConfig;
+import org.apache.hadoop.hdds.security.ssl.KeyStoresFactory;
import org.apache.hadoop.hdds.security.x509.certificate.client.CertificateClient;
import org.apache.hadoop.ozone.OzoneConsts;
import org.apache.ratis.thirdparty.io.grpc.Server;
import org.apache.ratis.thirdparty.io.grpc.ServerBuilder;
import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder.forServer;
+
/**
* Service to serve SCM DB checkpoints available for SCM HA.
* Ideally should only be run on a ratis leader.
@@ -66,11 +70,13 @@ public class InterSCMGrpcProtocolService {
&& securityConfig.isGrpcTlsEnabled()) {
try {
CertificateClient certClient = scm.getScmCertificateClient();
+ KeyStoresFactory keyStores = certClient.getServerKeyStoresFactory();
SslContextBuilder sslServerContextBuilder =
- SslContextBuilder.forServer(
- certClient.getServerKeyStoresFactory().getKeyManagers()[0]);
+ forServer(keyStores.getKeyManagers()[0])
+ .trustManager(keyStores.getTrustManagers()[0]);
SslContextBuilder sslContextBuilder = GrpcSslContexts.configure(
sslServerContextBuilder, securityConfig.getGrpcSslProvider());
+ sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
nettyServerBuilder.sslContext(sslContextBuilder.build());
} catch (Exception ex) {
LOG.error("Unable to setup TLS for secure " +
diff --git a/hadoop-hdds/server-scm/src/test/java/org/apache/hadoop/hdds/scm/ha/TestInterSCMGrpcProtocolService.java b/hadoop-hdds/server-scm/src/test/java/org/apache/hadoop/hdds/scm/ha/TestInterSCMGrpcProtocolService.java
new file mode 100644
index 000000000000..998839b2d835
--- /dev/null
+++ b/hadoop-hdds/server-scm/src/test/java/org/apache/hadoop/hdds/scm/ha/TestInterSCMGrpcProtocolService.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.apache.hadoop.hdds.scm.ha;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.hadoop.hdds.HddsConfigKeys;
+import org.apache.hadoop.hdds.conf.ConfigurationSource;
+import org.apache.hadoop.hdds.conf.OzoneConfiguration;
+import org.apache.hadoop.hdds.scm.ScmConfigKeys;
+import org.apache.hadoop.hdds.scm.metadata.SCMMetadataStore;
+import org.apache.hadoop.hdds.scm.server.StorageContainerManager;
+import org.apache.hadoop.hdds.security.ssl.KeyStoresFactory;
+import org.apache.hadoop.hdds.security.SecurityConfig;
+import org.apache.hadoop.hdds.security.x509.certificate.client.SCMCertificateClient;
+import org.apache.hadoop.hdds.security.x509.keys.HDDSKeyGenerator;
+import org.apache.hadoop.hdds.utils.TransactionInfo;
+import org.apache.hadoop.hdds.utils.db.DBCheckpoint;
+import org.apache.hadoop.hdds.utils.db.DBStore;
+import org.apache.hadoop.hdds.utils.db.Table;
+import org.apache.hadoop.ozone.OzoneConfigKeys;
+import org.bouncycastle.asn1.oiw.OIWObjectIdentifiers;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
+import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier;
+import org.bouncycastle.asn1.x509.BasicConstraints;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.asn1.x509.SubjectKeyIdentifier;
+import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
+import org.bouncycastle.cert.X509ExtensionUtils;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.DigestCalculator;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.bc.BcDigestCalculatorProvider;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.mockito.ArgumentCaptor;
+
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.X509KeyManager;
+import javax.net.ssl.X509TrustManager;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.KeyPair;
+import java.security.NoSuchAlgorithmException;
+import java.security.NoSuchProviderException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Date;
+import java.util.Random;
+import java.util.concurrent.CompletableFuture;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * This test checks that mTLS authentication is turned on for
+ * {@link InterSCMGrpcProtocolService}.
+ *
+ * @see HDDS-8901
+ */
+public class TestInterSCMGrpcProtocolService {
+
+ private static final String CP_FILE_NAME = "cpFile";
+ private static final String CP_CONTENTS = "Hello world!";
+
+ private X509Certificate serviceCert;
+ private X509Certificate clientCert;
+
+ private X509KeyManager serverKeyManager;
+ private X509TrustManager serverTrustManager;
+ private X509KeyManager clientKeyManager;
+ private X509TrustManager clientTrustManager;
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ @Test
+ public void testMTLSOnInterScmGrpcProtocolServiceAccess() throws Exception {
+ int port = new Random().nextInt(1000) + 45000;
+ OzoneConfiguration conf = setupConfiguration(port);
+ SCMCertificateClient
+ scmCertClient = setupCertificateClientForMTLS(conf);
+ InterSCMGrpcProtocolService service =
+ new InterSCMGrpcProtocolService(conf, scmWith(scmCertClient));
+ service.start();
+
+ InterSCMGrpcClient client =
+ new InterSCMGrpcClient("localhost", port, conf, scmCertClient);
+ CompletableFuture res = client.download(temp.newFile().toPath());
+ Path downloaded = res.get();
+
+ verifyServiceUsedItsCertAndValidatedClientCert();
+ verifyClientUsedItsCertAndValidatedServerCert();
+ verifyDownloadedCheckPoint(downloaded);
+
+ client.close();
+ service.stop();
+ }
+
+ private void verifyServiceUsedItsCertAndValidatedClientCert()
+ throws CertificateException {
+ ArgumentCaptor capturedCerts =
+ ArgumentCaptor.forClass(X509Certificate[].class);
+ verify(serverKeyManager, times(1)).getCertificateChain(any());
+ verify(serverTrustManager, never()).checkServerTrusted(any(), any());
+ verify(serverTrustManager, times(1))
+ .checkClientTrusted(capturedCerts.capture(), any());
+ assertThat(capturedCerts.getValue().length, is(1));
+ assertThat(capturedCerts.getValue()[0], is(clientCert));
+ }
+
+ private void verifyClientUsedItsCertAndValidatedServerCert()
+ throws CertificateException {
+ ArgumentCaptor capturedCerts =
+ ArgumentCaptor.forClass(X509Certificate[].class);
+ verify(clientKeyManager, times(1)).getCertificateChain(any());
+ verify(clientTrustManager, times(1))
+ .checkServerTrusted(capturedCerts.capture(), any());
+ verify(clientTrustManager, never()).checkClientTrusted(any(), any());
+ assertThat(capturedCerts.getValue().length, is(1));
+ assertThat(capturedCerts.getValue()[0], is(serviceCert));
+ }
+
+ private void verifyDownloadedCheckPoint(Path downloaded) throws IOException {
+ try (
+ TarArchiveInputStream in =
+ new TarArchiveInputStream(Files.newInputStream(downloaded));
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(in, UTF_8))
+ ) {
+ assertThat(in.getNextTarEntry().getName(), is(CP_FILE_NAME));
+ assertThat(reader.readLine(), is(CP_CONTENTS));
+ }
+ }
+
+ private StorageContainerManager scmWith(
+ SCMCertificateClient scmCertClient) throws IOException {
+ StorageContainerManager scmMock = mock(StorageContainerManager.class);
+ when(scmMock.getScmCertificateClient()).thenReturn(scmCertClient);
+ SCMMetadataStore metadataStore = metadataStore();
+ when(scmMock.getScmMetadataStore()).thenReturn(metadataStore);
+ SCMHAManager haManager = scmHAManager();
+ when(scmMock.getScmHAManager()).thenReturn(haManager);
+ when(scmMock.getClusterId()).thenReturn("clusterId");
+ return scmMock;
+ }
+
+ private SCMHAManager scmHAManager() {
+ SCMHAManager hamanager = mock(SCMHAManager.class);
+ doReturn(mock(SCMHADBTransactionBuffer.class))
+ .when(hamanager).asSCMHADBTransactionBuffer();
+ return hamanager;
+ }
+
+ private SCMMetadataStore metadataStore() throws IOException {
+ SCMMetadataStore metaStoreMock = mock(SCMMetadataStore.class);
+ DBStore dbStore = dbStore();
+ when(metaStoreMock.getStore()).thenReturn(dbStore);
+ return metaStoreMock;
+ }
+
+ private DBStore dbStore() throws IOException {
+ DBStore dbStoreMock = mock(DBStore.class);
+ doReturn(trInfoTable()).when(dbStoreMock).getTable(any(), any(), any());
+ doReturn(checkPoint()).when(dbStoreMock).getCheckpoint(anyBoolean());
+ return dbStoreMock;
+ }
+
+ private DBCheckpoint checkPoint() throws IOException {
+ Path checkPointLocation = temp.newFolder().toPath();
+ Path cpFile = Paths.get(checkPointLocation.toString(), CP_FILE_NAME);
+ Files.write(cpFile, CP_CONTENTS.getBytes(UTF_8));
+ DBCheckpoint checkpoint = mock(DBCheckpoint.class);
+ when(checkpoint.getCheckpointLocation()).thenReturn(checkPointLocation);
+ return checkpoint;
+ }
+
+ private Table trInfoTable()
+ throws IOException {
+ Table tableMock = mock(Table.class);
+ doReturn(mock(TransactionInfo.class)).when(tableMock).get(any());
+ return tableMock;
+ }
+
+ private SCMCertificateClient setupCertificateClientForMTLS(
+ OzoneConfiguration conf
+ ) throws Exception {
+ KeyPair serviceKeys = aKeyPair(conf);
+ KeyPair clientKeys = aKeyPair(conf);
+
+ serviceCert = createSelfSignedCert(serviceKeys, "service");
+ clientCert = createSelfSignedCert(clientKeys, "client");
+
+ serverKeyManager = aKeyManagerWith(serviceKeys, serviceCert);
+ serverTrustManager = aTrustManagerThatTrusts(clientCert);
+ KeyStoresFactory serverKeyStores =
+ aKeyStoresFactoryWith(serverKeyManager, serverTrustManager);
+
+ clientKeyManager = aKeyManagerWith(clientKeys, clientCert);
+ clientTrustManager = aTrustManagerThatTrusts(serviceCert);
+ KeyStoresFactory clientKeyStores =
+ aKeyStoresFactoryWith(clientKeyManager, clientTrustManager);
+
+ SCMCertificateClient scmCertClient = mock(SCMCertificateClient.class);
+ doReturn(serverKeyStores).when(scmCertClient).getServerKeyStoresFactory();
+ doReturn(clientKeyStores).when(scmCertClient).getClientKeyStoresFactory();
+ return scmCertClient;
+ }
+
+ private KeyStoresFactory aKeyStoresFactoryWith(
+ X509KeyManager keyManager,
+ X509TrustManager trustManager
+ ) {
+ KeyStoresFactory serverKeyStores = mock(KeyStoresFactory.class);
+ doReturn(new KeyManager[]{keyManager})
+ .when(serverKeyStores).getKeyManagers();
+ doReturn(new TrustManager[]{trustManager})
+ .when(serverKeyStores).getTrustManagers();
+ return serverKeyStores;
+ }
+
+ private X509TrustManager aTrustManagerThatTrusts(X509Certificate certificate)
+ throws CertificateException {
+ X509TrustManager trustManager = mock(X509TrustManager.class);
+ doNothing().when(trustManager).checkServerTrusted(any(), any());
+ doNothing().when(trustManager).checkClientTrusted(any(), any());
+ doReturn(new X509Certificate[] {certificate})
+ .when(trustManager).getAcceptedIssuers();
+ return trustManager;
+ }
+
+ private X509KeyManager aKeyManagerWith(KeyPair keyPair,
+ X509Certificate certificate) {
+ X509KeyManager keyManager = mock(X509KeyManager.class);
+ doReturn("server")
+ .when(keyManager).chooseServerAlias(any(), any(), any());
+ doReturn("client")
+ .when(keyManager).chooseClientAlias(any(), any(), any());
+ doReturn(new String[] {"server"})
+ .when(keyManager).getServerAliases(any(), any());
+ doReturn(new String[] {"client"})
+ .when(keyManager).getClientAliases(any(), any());
+ doReturn(new X509Certificate[] {certificate})
+ .when(keyManager).getCertificateChain(any());
+ doReturn(keyPair.getPrivate())
+ .when(keyManager).getPrivateKey(any());
+ return keyManager;
+ }
+
+ private KeyPair aKeyPair(ConfigurationSource conf)
+ throws NoSuchProviderException, NoSuchAlgorithmException {
+ return new HDDSKeyGenerator(new SecurityConfig(conf)).generateKey();
+ }
+
+ private OzoneConfiguration setupConfiguration(int port) {
+ OzoneConfiguration conf = new OzoneConfiguration();
+ conf.setInt(ScmConfigKeys.OZONE_SCM_GRPC_PORT_KEY, port);
+ conf.setBoolean(OzoneConfigKeys.OZONE_SECURITY_ENABLED_KEY, true);
+ conf.setBoolean(HddsConfigKeys.HDDS_GRPC_TLS_ENABLED, true);
+ return conf;
+ }
+
+
+ private static final String HASH_ALGO = "SHA256WithRSA";
+
+ private X509Certificate createSelfSignedCert(KeyPair keys, String commonName)
+ throws Exception {
+ final Instant now = Instant.now();
+ final Date notBefore = Date.from(now);
+ final Date notAfter = Date.from(now.plus(Duration.ofDays(1)));
+ final ContentSigner contentSigner =
+ new JcaContentSignerBuilder(HASH_ALGO).build(keys.getPrivate());
+ final X500Name x500Name = new X500Name("CN=" + commonName);
+
+ SubjectKeyIdentifier keyId = subjectKeyIdOf(keys);
+ AuthorityKeyIdentifier authorityKeyId = authorityKeyIdOf(keys);
+ BasicConstraints constraints = new BasicConstraints(true);
+
+ final X509v3CertificateBuilder certificateBuilder =
+ new JcaX509v3CertificateBuilder(
+ x500Name,
+ BigInteger.valueOf(keys.getPublic().hashCode()),
+ notBefore,
+ notAfter,
+ x500Name,
+ keys.getPublic()
+ );
+ certificateBuilder
+ .addExtension(Extension.subjectKeyIdentifier, false, keyId)
+ .addExtension(Extension.authorityKeyIdentifier, false, authorityKeyId)
+ .addExtension(Extension.basicConstraints, true, constraints);
+
+ return new JcaX509CertificateConverter()
+ .setProvider(new BouncyCastleProvider())
+ .getCertificate(certificateBuilder.build(contentSigner));
+ }
+
+ private SubjectKeyIdentifier subjectKeyIdOf(KeyPair keys) throws Exception {
+ return extensionUtil().createSubjectKeyIdentifier(pubKeyInfo(keys));
+ }
+
+ private AuthorityKeyIdentifier authorityKeyIdOf(KeyPair keys)
+ throws Exception {
+ return extensionUtil().createAuthorityKeyIdentifier(pubKeyInfo(keys));
+ }
+
+ private SubjectPublicKeyInfo pubKeyInfo(KeyPair keys) {
+ return SubjectPublicKeyInfo.getInstance(keys.getPublic().getEncoded());
+ }
+
+ private X509ExtensionUtils extensionUtil()
+ throws OperatorCreationException {
+ DigestCalculator digest =
+ new BcDigestCalculatorProvider()
+ .get(new AlgorithmIdentifier(OIWObjectIdentifiers.idSHA1));
+
+ return new X509ExtensionUtils(digest);
+ }
+
+}