diff --git a/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java
new file mode 100644
index 000000000000..4c39a5d2a3de
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.ssl;
+
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509TrustManager;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.concurrent.atomic.AtomicReference;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * A {@link TrustManager} implementation that reloads its configuration when
+ * the truststore file on disk changes.
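+ * This implementation is based on the
+ * org.apache.hadoop.security.ssl.ReloadingX509TrustManager class in the Apache Hadoop Encrypted
+ * Shuffle implementation.
+ *
+ * @see <a href="https://hadoop.apache.org/docs/current/hadoop-mapreduce-client/hadoop-mapreduce-client-core/EncryptedShuffle.html">Hadoop MapReduce Next Generation - Encrypted Shuffle</a>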
+ */
+public final class ReloadingX509TrustManager
+ implements X509TrustManager, Runnable {
+
+ private static final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManager.class);
+
+ private final String type;
+ private final File file;
+ // Canonical path that `file` resolved to at the last load (differs from `file` for links)
+ private String canonicalPath;
+ private final String password;
+ private long lastLoaded;
+ private final long reloadInterval;
+ @VisibleForTesting
+ protected volatile int reloadCount;
+ @VisibleForTesting
+ protected volatile int needsReloadCheckCounts;
+ // Currently active delegate; replaced by the reloader thread after a successful reload
+ private final AtomicReference<X509TrustManager> trustManagerRef;
+ private volatile boolean running;
+ private Thread reloader;
+
+ /**
+  * Creates a reloadable trust manager and performs the initial load of the truststore.
+  * The trust manager reloads itself if the underlying truststore file has changed.
+  *
+  * @param type type of truststore file, typically 'jks'.
+  * @param trustStore the truststore file.
+  * @param password password of the truststore file.
+  * @param reloadInterval interval to check if the truststore file has
+  *                       changed, in milliseconds.
+  * @throws IOException thrown if the truststore could not be initialized due
+  *                     to an IO error.
+  * @throws GeneralSecurityException thrown if the truststore could not be
+  *                                  initialized due to a security error.
+  */
+ public ReloadingX509TrustManager(
+     String type, File trustStore, String password, long reloadInterval)
+     throws IOException, GeneralSecurityException {
+   this.type = type;
+   this.file = trustStore;
+   this.canonicalPath = this.file.getCanonicalPath();
+   this.password = password;
+   this.reloadInterval = reloadInterval;
+   this.reloadCount = 0;
+   this.needsReloadCheckCounts = 0;
+   // Initial load; the diamond keeps the reference typed instead of using a raw AtomicReference
+   this.trustManagerRef = new AtomicReference<>(loadTrustManager());
+ }
+
+ /** Creates and starts the daemon thread that runs this object's reload loop. */
+ public void init() {
+   Thread worker = new Thread(this, "Truststore reloader thread");
+   worker.setDaemon(true);
+   reloader = worker;
+   // Raise the flag before starting so the loop in run() does not exit immediately
+   running = true;
+   worker.start();
+ }
+
+ /** Stops the reloader thread and waits for it to exit; a no-op if {@link #init()} never ran. */
+ public void destroy() throws InterruptedException {
+   running = false;
+   if (reloader != null) {
+     reloader.interrupt();
+     reloader.join();
+   }
+ }
+
+ /**
+  * Reports how long the reloader thread sleeps between truststore checks.
+  *
+  * @return the check interval in milliseconds, as passed to the constructor.
+  */
+ public long getReloadInterval() {
+   return this.reloadInterval;
+ }
+
+ /** Delegates to the currently loaded trust manager; fails if none is available. */
+ @Override
+ public void checkClientTrusted(X509Certificate[] chain, String authType)
+     throws CertificateException {
+   X509TrustManager delegate = trustManagerRef.get();
+   if (delegate == null) {
+     throw new CertificateException("Unknown client chain certificate: " +
+       chain[0].toString() + ". Please ensure the correct trust store is specified in the config");
+   }
+   delegate.checkClientTrusted(chain, authType);
+ }
+
+ /** Delegates to the currently loaded trust manager; fails if none is available. */
+ @Override
+ public void checkServerTrusted(X509Certificate[] chain, String authType)
+     throws CertificateException {
+   X509TrustManager delegate = trustManagerRef.get();
+   if (delegate == null) {
+     throw new CertificateException("Unknown server chain certificate: " +
+       chain[0].toString() + ". Please ensure the correct trust store is specified in the config");
+   }
+   delegate.checkServerTrusted(chain, authType);
+ }
+
+ private static final X509Certificate[] EMPTY = new X509Certificate[0];
+
+ /**
+  * Returns the issuers accepted by the currently loaded trust manager,
+  * or an empty array when no trust manager is available.
+  */
+ @Override
+ public X509Certificate[] getAcceptedIssuers() {
+   X509TrustManager delegate = trustManagerRef.get();
+   return delegate == null ? EMPTY : delegate.getAcceptedIssuers();
+ }
+
+ /**
+  * Decides whether the truststore must be loaded again: {@code file} now resolves to a
+  * different canonical path, the canonical file's modification time differs from the one
+  * recorded at the last load, or the file is currently missing.
+  */
+ boolean needsReload() throws IOException {
+   File target = file.getCanonicalFile();
+   if (!file.exists() || !target.exists()) {
+     // Reset the recorded timestamp (presumably so a reappearing file counts as changed)
+     lastLoaded = 0;
+     return true;
+   }
+   // `file` can be a symbolic link: compare the resolved path first, then the timestamp
+   return !(target.getPath().equals(canonicalPath) && target.lastModified() == lastLoaded);
+ }
+
+ /**
+  * Loads the truststore from disk, recording the canonical path and modification time that
+  * needsReload() compares against. Returns the first X509TrustManager found, or null if none.
+  */
+ X509TrustManager loadTrustManager()
+     throws IOException, GeneralSecurityException {
+   KeyStore ks = KeyStore.getInstance(type);
+   File latestCanonicalFile = file.getCanonicalFile();
+   canonicalPath = latestCanonicalFile.getPath();
+   lastLoaded = latestCanonicalFile.lastModified();
+   try (FileInputStream in = new FileInputStream(latestCanonicalFile)) {
+     ks.load(in, password.toCharArray());
+     logger.debug("Loaded truststore '{}'", file);
+   }
+   TrustManagerFactory factory =
+     TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+   factory.init(ks);
+   for (TrustManager candidate : factory.getTrustManagers()) {
+     if (candidate instanceof X509TrustManager) {
+       return (X509TrustManager) candidate;
+     }
+   }
+   return null;
+ }
+
+ /** Reloader loop: sleeps for the reload interval, then reloads the truststore if it changed. */
+ @Override
+ public void run() {
+   while (running) {
+     try {
+       Thread.sleep(reloadInterval);
+     } catch (InterruptedException ignored) {
+       // Nothing to do: the `running` flag is re-checked right below
+     }
+     try {
+       if (running && needsReload()) {
+         try {
+           trustManagerRef.set(loadTrustManager());
+           this.reloadCount += 1;
+         } catch (Exception loadFailure) {
+           logger.warn(
+             "Could not load truststore (keep using existing one) : " + loadFailure.toString(),
+             loadFailure);
+         }
+       }
+     } catch (IOException checkFailure) {
+       logger.warn("Could not check whether truststore needs reloading: " + checkFailure.toString(), checkFailure);
+     }
+     needsReloadCheckCounts++;
+   }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java
new file mode 100644
index 000000000000..7e2cc38e70b3
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.ssl;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.security.KeyPair;
+import java.security.KeyStore;
+import java.security.cert.X509Certificate;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import static org.apache.spark.network.ssl.SslSampleConfigs.*;
+
+public class ReloadingX509TrustManagerSuite {
+
+ private static final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManagerSuite.class);
+
+ /**
+  * Polls the trust manager until its reload count reaches {@code count}, pausing roughly
+  * 100ms between polls. Throws an IllegalStateException if the count is already past
+  * {@code count} on entry, or if it is still not reached after {@code attempts} polls.
+  *
+  * @param tm the trust manager to wait for
+  * @param count the reload count to wait for
+  * @param attempts the maximum number of polls before giving up
+  */
+ private void waitForReloadCount(ReloadingX509TrustManager tm, int count, int attempts)
+     throws InterruptedException {
+   if (tm.reloadCount > count) {
+     throw new IllegalStateException(
+       "Passed invalid count " + count + " to waitForReloadCount, already have " + tm.reloadCount);
+   }
+   for (int remaining = attempts; remaining > 0; remaining--) {
+     if (tm.reloadCount >= count) {
+       return;
+     }
+     // Sleep in slices of at most 10ms until 100ms of wall-clock time have elapsed
+     long now = System.currentTimeMillis();
+     long deadline = now + 100;
+     while (now < deadline) {
+       Thread.sleep(Math.min(10, deadline - now));
+       now = System.currentTimeMillis();
+     }
+   }
+   throw new IllegalStateException("Trust store not reloaded after " + attempts + " attempts!");
+ }
+
+ /**
+  * Blocks until the trust manager has recorded at least {@code attempts} further
+  * needs-reload checks, then asserts that the reload count did not change meanwhile.
+  *
+  * @param tm the trust manager to observe
+  * @param attempts how many additional reload checks to wait for
+  */
+ private void waitForNoReload(ReloadingX509TrustManager tm, int attempts)
+     throws InterruptedException {
+   int reloadsBefore = tm.reloadCount;
+   int firstSeen = tm.needsReloadCheckCounts;
+   int goal = firstSeen + attempts;
+   // Re-read the check counter every 100ms until it has advanced by `attempts`
+   for (int seen = firstSeen; seen < goal; seen = tm.needsReloadCheckCounts) {
+     Thread.sleep(100);
+   }
+   assertEquals(reloadsBefore, tm.reloadCount);
+ }
+
+ /**
+  * Verifies that creating a trust manager for a truststore file that does not
+  * exist fails with an IOException.
+  *
+  * @throws Exception
+  */
+ @Test
+ public void testLoadMissingTrustStore() throws Exception {
+   File missingStore = new File("testmissing.jks");
+   assertFalse(missingStore.exists());
+
+   assertThrows(IOException.class, () -> {
+     // The constructor performs the initial load, so the failure is expected there,
+     // before init() is ever reached
+     ReloadingX509TrustManager manager = new ReloadingX509TrustManager(
+       KeyStore.getDefaultType(), missingStore, "password", 10
+     );
+     try {
+       manager.init();
+     } finally {
+       manager.destroy();
+     }
+   });
+ }
+
+ /**
+  * Tests to ensure that loading a corrupt trust-store fails
+  *
+  * @throws Exception
+  */
+ @Test
+ public void testLoadCorruptTrustStore() throws Exception {
+   File corruptStore = File.createTempFile("truststore-corrupt", "jks");
+   corruptStore.deleteOnExit();
+   // A single byte is not a valid keystore; try-with-resources closes the stream on any path
+   try (OutputStream os = new FileOutputStream(corruptStore)) {
+     os.write(1);
+   }
+   try {
+     assertThrows(IOException.class, () -> {
+       ReloadingX509TrustManager tm = new ReloadingX509TrustManager(
+         KeyStore.getDefaultType(), corruptStore, "password", 10);
+       try {
+         tm.init();
+       } finally {
+         tm.destroy();
+       }
+     });
+   } finally {
+     // Clean up out here: the constructor is expected to throw, skipping the lambda's finally
+     corruptStore.delete();
+   }
+ }
+
+ /**
+  * Tests that we successfully reload when a file is updated
+  * @throws Exception
+  */
+ @Test
+ public void testReload() throws Exception {
+   KeyPair kp = generateKeyPair("RSA");
+   X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA");
+   X509Certificate cert2 = generateCertificate("CN=Cert2", kp, 30, "SHA1withRSA");
+   File trustStore = File.createTempFile("testreload", "jks");
+   trustStore.deleteOnExit();
+   createTrustStore(trustStore, "password", "cert1", cert1);
+
+   ReloadingX509TrustManager tm =
+     new ReloadingX509TrustManager("jks", trustStore, "password", 1);
+   assertEquals(1, tm.getReloadInterval());
+   assertEquals(0, tm.reloadCount);
+   try {
+     tm.init();
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     // At this point we haven't reloaded, just the initial load
+     assertEquals(0, tm.reloadCount);
+
+     // Rewrite the store with a second cert (typed map instead of a raw Map)
+     Map<String, X509Certificate> certs = new HashMap<>();
+     certs.put("cert1", cert1);
+     certs.put("cert2", cert2);
+     createTrustStore(trustStore, "password", certs);
+
+     // Wait up to 5s (50 polls of ~100ms) until we reload
+     waitForReloadCount(tm, 1, 50);
+
+     assertEquals(2, tm.getAcceptedIssuers().length);
+   } finally {
+     tm.destroy();
+     trustStore.delete();
+   }
+ }
+
+ /**
+  * Tests that we keep old certs if the trust store goes missing
+  * @throws Exception
+  */
+ @Test
+ public void testReloadMissingTrustStore() throws Exception {
+   KeyPair kp = generateKeyPair("RSA");
+   X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA");
+   File trustStore = new File("testmissing.jks");
+   trustStore.deleteOnExit();
+   assertFalse(trustStore.exists());
+   createTrustStore(trustStore, "password", "cert1", cert1);
+
+   ReloadingX509TrustManager tm =
+     new ReloadingX509TrustManager("jks", trustStore, "password", 1);
+   assertEquals(0, tm.reloadCount);
+   try {
+     tm.init();
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     X509Certificate cert = tm.getAcceptedIssuers()[0];
+     trustStore.delete();
+
+     // Wait for 50 more reload checks - we should *not* reload
+     waitForNoReload(tm, 50);
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     assertEquals(cert, tm.getAcceptedIssuers()[0]);
+   } finally {
+     tm.destroy();
+     // testLoadMissingTrustStore asserts this same path is absent; never leave it behind
+     trustStore.delete();
+   }
+ }
+
+ /**
+  * Tests that we keep old certs if the new truststore is corrupt
+  * @throws Exception
+  */
+ @Test
+ public void testReloadCorruptTrustStore() throws Exception {
+   KeyPair kp = generateKeyPair("RSA");
+   X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA");
+   File corruptStore = File.createTempFile("truststore-corrupt", "jks");
+   corruptStore.deleteOnExit();
+   createTrustStore(corruptStore, "password", "cert1", cert1);
+
+   ReloadingX509TrustManager tm =
+     new ReloadingX509TrustManager("jks", corruptStore, "password", 1);
+   assertEquals(0, tm.reloadCount);
+   try {
+     tm.init();
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     X509Certificate cert = tm.getAcceptedIssuers()[0];
+
+     try (OutputStream os = new FileOutputStream(corruptStore)) {
+       os.write(1);
+     }
+     corruptStore.setLastModified(System.currentTimeMillis() - 1000);
+
+     // Wait for 50 more reload checks - we should *not* reload
+     waitForNoReload(tm, 50);
+
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     assertEquals(cert, tm.getAcceptedIssuers()[0]);
+   } finally {
+     tm.destroy();
+     corruptStore.delete();
+   }
+ }
+
+ /**
+  * Tests that we successfully reload when the trust store is a symlink
+  * and we update the contents of the pointed-to file or we update the file it points to.
+  * @throws Exception
+  */
+ @Test
+ public void testReloadSymlink() throws Exception {
+   KeyPair kp = generateKeyPair("RSA");
+   X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA");
+   X509Certificate cert2 = generateCertificate("CN=Cert2", kp, 30, "SHA1withRSA");
+   X509Certificate cert3 = generateCertificate("CN=Cert3", kp, 30, "SHA1withRSA");
+
+   File trustStore1 = File.createTempFile("testreload", "jks");
+   trustStore1.deleteOnExit();
+   createTrustStore(trustStore1, "password", "cert1", cert1);
+
+   File trustStore2 = File.createTempFile("testreload", "jks");
+   trustStore2.deleteOnExit();
+   Map<String, X509Certificate> certs = new HashMap<>();
+   certs.put("cert1", cert1);
+   certs.put("cert2", cert2);
+   createTrustStore(trustStore2, "password", certs);
+
+   File trustStoreSymlink = File.createTempFile("testreloadsymlink", "jks");
+   trustStoreSymlink.delete();
+   Files.createSymbolicLink(trustStoreSymlink.toPath(), trustStore1.toPath());
+
+   ReloadingX509TrustManager tm =
+     new ReloadingX509TrustManager("jks", trustStoreSymlink, "password", 1);
+   assertEquals(1, tm.getReloadInterval());
+   assertEquals(0, tm.reloadCount);
+   logger.info("TRUST STORE 1 IS " + trustStore1);
+   logger.info("TRUST STORE 2 IS " + trustStore2);
+   try {
+     tm.init();
+     assertEquals(1, tm.getAcceptedIssuers().length);
+     // At this point we haven't reloaded, just the initial load
+     assertEquals(0, tm.reloadCount);
+
+     // Repoint to trustStore2, which has another cert
+     logger.info("REPOINTING SYMLINK!!!");
+     trustStoreSymlink.delete();
+     Files.createSymbolicLink(trustStoreSymlink.toPath(), trustStore2.toPath());
+     logger.info("REPOINTED!!!");
+
+     // Wait up to 5s until we reload
+     waitForReloadCount(tm, 1, 50);
+     assertEquals(2, tm.getAcceptedIssuers().length);
+
+     // Add another cert
+     certs.put("cert3", cert3);
+     createTrustStore(trustStore2, "password", certs);
+
+     // Wait up to 5s until we reload
+     waitForReloadCount(tm, 2, 50);
+
+     assertEquals(3, tm.getAcceptedIssuers().length);
+   } finally {
+     tm.destroy();
+     trustStore1.delete();
+     trustStore2.delete();
+     trustStoreSymlink.delete();
+   }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java
index 3c81b0af3186..2a04d740e8ad 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/SslSampleConfigs.java
@@ -21,6 +21,8 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.Key;
@@ -41,9 +43,6 @@
import org.apache.spark.network.util.MapConfigProvider;
-/**
- *
- */
public class SslSampleConfigs {
public static final String keyStorePath = getAbsolutePath("/keystore");
@@ -217,9 +216,18 @@ private static KeyStore createEmptyKeyStore()
private static void saveKeyStore(
KeyStore ks, File keyStore, String password)
throws GeneralSecurityException, IOException {
- FileOutputStream out = new FileOutputStream(keyStore);
+ // Write the file atomically to ensure tests don't read a partial write
+ File tempFile = File.createTempFile("temp-key-store", "jks");
+ FileOutputStream out = new FileOutputStream(tempFile);
try {
ks.store(out, password.toCharArray());
+ out.close();
+ Files.move(
+ tempFile.toPath(),
+ keyStore.toPath(),
+ StandardCopyOption.REPLACE_EXISTING,
+ StandardCopyOption.ATOMIC_MOVE
+ );
} finally {
out.close();
}