diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index f8961fff8e17..ee9051d024c5 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -292,6 +292,12 @@ private[spark] class SecurityManager( */ def isSslRpcEnabled(): Boolean = sslRpcEnabled + /** + * Returns the SSLOptions object for the RPC namespace + * @return the SSLOptions object for the RPC namespace + */ + def getRpcSSLOptions(): SSLOptions = rpcSSLOptions + /** * Gets the user used for authenticating SASL connections. * For now use a single hardcoded user. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 310dc8284401..c2bae41d34ee 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -374,7 +374,12 @@ object SparkEnv extends Logging { } val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + val transConf = SparkTransportConf.fromSparkConf( + conf, + "shuffle", + numUsableCores, + sslOptions = Some(securityManager.getRpcSSLOptions()) + ) Some(new ExternalBlockStoreClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 466c1f2e14b1..a56fbd5a644a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -53,7 +53,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val registeredExecutorsDB = "registeredExecutors" private val transportConf = - SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) + SparkTransportConf.fromSparkConf( + sparkConf, + "shuffle", + numUsableCores = 0, + sslOptions = Some(securityManager.getRpcSSLOptions())) private val blockHandler = newShuffleBlockHandler(transportConf) private var transportContext: TransportContext = _ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b074ac814a96..f964e2b50b57 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -90,7 +90,9 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) try { - val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(env.conf, "shuffle") + val securityManager = new SecurityManager(env.conf) + val shuffleClientTransportConf = SparkTransportConf.fromSparkConf( + env.conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) if (NettyUtils.preferDirectBufs(shuffleClientTransportConf) && PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) { throw new SparkException(s"Netty direct memory should at least be bigger than " + diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index f54383db4c0e..6b785a07c7f4 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -70,7 +70,11 @@ private[spark] class NettyBlockTransferService( val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None - this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) + this.transportConf = SparkTransportConf.fromSparkConf( + conf, + "shuffle", + numCores, + sslOptions = Some(securityManager.getRpcSSLOptions())) if (authEnabled) { serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 464b6cbc6b0a..7909f2327cdf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -56,7 +56,9 @@ private[netty] class NettyRpcEnv( conf.clone.set(RPC_IO_NUM_CONNECTIONS_PER_PEER, 1), "rpc", conf.get(RPC_IO_THREADS).getOrElse(numUsableCores), - role) + role, + sslOptions = Some(securityManager.getRpcSSLOptions()) + ) private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) @@ -391,7 +393,11 @@ private[netty] class NettyRpcEnv( } val ioThreads = clone.getInt("spark.files.io.threads", 1) - val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadConf = SparkTransportConf.fromSparkConf( + clone, + module, + ioThreads, + sslOptions = Some(securityManager.getRpcSSLOptions())) val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 919b0f5f7c13..ab34bae996cd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -24,7 +24,7 @@ import java.nio.file.Files import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException} import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream @@ -58,7 +58,11 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private val transportConf = { + val securityManager = new SecurityManager(conf) + SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) + } private val remoteShuffleMaxDisk: Option[Long] = conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index ac43ba8b56fc..252f929da282 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ExecutorService import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.util.control.NonFatal -import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{SecurityManager, ShuffleDependency, SparkConf, SparkContext, SparkEnv} import org.apache.spark.annotation.Since import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend} import org.apache.spark.internal.Logging @@ -108,7 +108,9 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { dep: ShuffleDependency[_, _, _], mapIndex: Int): Unit = { val numPartitions = dep.partitioner.numPartitions - val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + val securityManager = new SecurityManager(conf) + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) this.shuffleId = dep.shuffleId this.shuffleMergeId = dep.shuffleMergeId this.mapIndex = mapIndex diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aa9ba7c34f69..f77fda461493 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1252,7 +1252,8 @@ private[spark] class BlockManager( new EncryptedBlockData(file, blockSize, conf, key)) case _ => - val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } Some(managedBuffer) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 1d1bb9e9eee8..f0a63247e64b 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -49,8 +49,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi var transportContext: TransportContext = _ var rpcHandler: ExternalBlockHandler = _ - override def beforeAll(): Unit = { - super.beforeAll() + protected def initializeHandlers(): Unit = { val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalBlockHandler(transportConf, null) transportContext = new TransportContext(transportConf, rpcHandler) @@ -61,6 +60,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort) } + override def beforeAll(): Unit = { + super.beforeAll() + initializeHandlers() + } + override def afterAll(): Unit = { Utils.tryLogNonFatalError{ server.close() diff --git a/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala new file mode 100644 index 000000000000..3ce1f11a7acd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.config +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExternalBlockHandler + +/** + * This suite creates an external shuffle server and routes all shuffle fetches through it. + * Note that failures in this suite may arise due to changes in Spark that invalidate expectations + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how + * we hash files into folders. + */ +class SslExternalShuffleServiceSuite extends ExternalShuffleServiceSuite { + + override def initializeHandlers(): Unit = { + SslTestUtils.updateWithSSLConfig(conf) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf); + // Show that we can successfully inherit options defined in the `spark.ssl` namespace + val defaultSslOptions = SSLOptions.parse(conf, hadoopConf, "spark.ssl") + val sslOptions = SSLOptions.parse( + conf, hadoopConf, "spark.ssl.rpc", defaults = Some(defaultSslOptions)) + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", numUsableCores = 2, sslOptions = Some(sslOptions)) + + rpcHandler = new ExternalBlockHandler(transportConf, null) + transportContext = new TransportContext(transportConf, rpcHandler) + server = transportContext.createServer() + + conf.set(config.SHUFFLE_MANAGER, "sort") + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort) + } +} diff --git a/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala new file mode 100644 index 000000000000..7eaff7d37a81 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +class SslShuffleNettySuite extends ShuffleNettySuite { + + override def beforeAll(): Unit = { + super.beforeAll() + SslTestUtils.updateWithSSLConfig(conf) + } +} diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 7a7021357eda..3ef4da6d3d3f 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -52,8 +52,12 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite implicit val formats = DefaultFormats + def createSparkConf(): SparkConf = { + new SparkConf() + } + test("parsing no resources") { - val conf = new SparkConf + val conf = createSparkConf() val resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf) val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -75,7 +79,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("parsing one resource") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -100,11 +104,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(GPU, 2) ereqs.resource(FPGA, 3) val rp = rpBuilder.require(ereqs).build() - testParsingMultipleResources(new SparkConf, rp) + testParsingMultipleResources(createSparkConf(), rp) } test("parsing multiple resources") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") conf.set(EXECUTOR_FPGA_ID.amountConf, "3") testParsingMultipleResources(conf, ResourceProfile.getOrCreateDefaultProfile(conf)) @@ -136,7 +140,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("error checking parsing resources and executor and task configs") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -178,11 +182,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(GPU, 4) val treqs = new TaskResourceRequests().resource(GPU, 1) val rp = rpBuilder.require(ereqs).require(treqs).build() - testExecutorResourceFoundLessThanRequired(new SparkConf, rp) + testExecutorResourceFoundLessThanRequired(createSparkConf(), rp) } test("executor resource found less than required") { - val conf = new SparkConf() + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "4") conf.set(TASK_GPU_ID.amountConf, "1") testExecutorResourceFoundLessThanRequired(conf, ResourceProfile.getOrCreateDefaultProfile(conf)) @@ -213,7 +217,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("use resource discovery") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_FPGA_ID.amountConf, "3") assume(!(Utils.isWindows)) withTempDir { dir => @@ -246,7 +250,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(FPGA, 3, scriptPath) ereqs.resource(GPU, 2) val rp = rpBuilder.require(ereqs).build() - allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, new SparkConf, rp) + allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, createSparkConf(), rp) } } @@ -255,7 +259,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite withTempDir { dir => val scriptPath = createTempScriptWithExpectedOutput(dir, "fpgaDiscoverScript", """{"name": "fpga","addresses":["f1", "f2", "f3"]}""") - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_FPGA_ID.amountConf, "3") conf.set(EXECUTOR_FPGA_ID.discoveryScriptConf, scriptPath) conf.set(EXECUTOR_GPU_ID.amountConf, "2") @@ -289,7 +293,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("track allocated resources by taskId") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) var backend: CoarseGrainedExecutorBackend = null @@ -389,7 +393,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * being executed in [[Executor.TaskRunner]]. */ test(s"Tasks launched should always be cancelled.") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor") @@ -478,7 +482,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * it has not been launched yet. */ test(s"Tasks not launched should always be cancelled.") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor") @@ -567,7 +571,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * [[SparkUncaughtExceptionHandler]] and [[Executor]] can exit by itself. */ test("SPARK-40320 Executor should exit when initialization failed for fatal error") { - val conf = new SparkConf() + val conf = createSparkConf() .setMaster("local-cluster[1, 1, 1024]") .set(PLUGINS, Seq(classOf[TestFatalErrorPlugin].getName)) .setAppName("test") @@ -628,3 +632,11 @@ private class TestErrorExecutorPlugin extends ExecutorPlugin { // scalastyle:on throwerror } } + +class SslCoarseGrainedExecutorBackendSuite extends CoarseGrainedExecutorBackendSuite + with LocalSparkContext with MockitoSugar { + + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 85b05cd5f98f..5c234ef95500 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -32,7 +32,7 @@ import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite, SslTestUtils} import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -43,8 +43,15 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.util.ThreadUtils class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with Matchers { + + def createSparkConf(): SparkConf = { + new SparkConf() + } + + def isRunningWithSSL(): Boolean = false + test("security default off") { - val conf = new SparkConf() + val conf = createSparkConf() .set("spark.app.id", "app-id") testConnection(conf, conf) match { case Success(_) => // expected @@ -53,7 +60,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security on same password") { - val conf = new SparkConf() + val conf = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -64,7 +71,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security on mismatch password") { - val conf0 = new SparkConf() + val conf0 = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -76,7 +83,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security mismatch auth off on server") { - val conf0 = new SparkConf() + val conf0 = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -100,15 +107,17 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security with aes encryption") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(AUTH_SECRET, "good") - .set("spark.app.id", "app-id") - .set(Network.NETWORK_CRYPTO_ENABLED, true) - .set(Network.NETWORK_CRYPTO_SASL_FALLBACK, false) - testConnection(conf, conf) match { - case Success(_) => // expected - case Failure(t) => fail(t) + if (!isRunningWithSSL()) { + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(AUTH_SECRET, "good") + .set("spark.app.id", "app-id") + .set(Network.NETWORK_CRYPTO_ENABLED, true) + .set(Network.NETWORK_CRYPTO_SASL_FALLBACK, false) + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } } } @@ -179,3 +188,11 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } } +class SslNettyBlockTransferSecuritySuite extends NettyBlockTransferSecuritySuite { + + override def isRunningWithSSL(): Boolean = true + + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a88be983b804..3ef382573517 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -44,9 +44,13 @@ abstract class RpcEnvSuite extends SparkFunSuite { var env: RpcEnv = _ + def createSparkConf(): SparkConf = { + new SparkConf() + } + override def beforeAll(): Unit = { super.beforeAll() - val conf = new SparkConf() + val conf = createSparkConf() env = createRpcEnv(conf, "local", 0) val sparkEnv = mock(classOf[SparkEnv]) @@ -93,7 +97,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { @@ -145,7 +149,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { @@ -168,7 +172,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val conf = new SparkConf() + val conf = createSparkConf() val shortProp = "spark.rpc.short.timeout" val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef @@ -198,7 +202,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val conf = new SparkConf() + val conf = createSparkConf() val shortProp = "spark.rpc.short.timeout" val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef @@ -467,7 +471,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { @@ -507,7 +511,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { @@ -556,8 +560,8 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in sever RpcEnv when another RpcEnv is in server mode") { - val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) - val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val serverEnv1 = createRpcEnv(createSparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(createSparkConf(), "server2", 0, clientMode = false) val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") try { @@ -585,9 +589,9 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in sever RpcEnv when another RpcEnv is in client mode") { - val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val serverEnv = createRpcEnv(createSparkConf(), "server", 0, clientMode = false) val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") - val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val clientEnv = createRpcEnv(createSparkConf(), "client", 0, clientMode = true) try { val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection @@ -615,8 +619,8 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in client RpcEnv when another RpcEnv is in server mode") { - val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) - val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val clientEnv = createRpcEnv(createSparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(createSparkConf(), "server", 0, clientMode = false) val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") try { @@ -652,7 +656,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") @@ -669,7 +673,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("port conflict") { - val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", env.address.port) try { assert(anotherEnv.address.port != env.address.port) } finally { @@ -729,20 +733,20 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("send with authentication") { - testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good")) } test("send with SASL encryption") { - testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(SASL_ENCRYPTION_ENABLED, true)) } test("send with AES encryption") { - testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(Network.NETWORK_CRYPTO_ENABLED, true) @@ -750,20 +754,20 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("ask with authentication") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good")) } test("ask with SASL encryption") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(SASL_ENCRYPTION_ENABLED, true)) } test("ask with AES encryption") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(Network.NETWORK_CRYPTO_ENABLED, true) @@ -861,7 +865,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { test("file server") { withTempDir { tempDir => withTempDir { destDir => - val conf = new SparkConf() + val conf = createSparkConf() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) @@ -940,7 +944,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0) val endpoint = mock(classOf[RpcEndpoint]) anotherEnv.setupEndpoint("SPARK-14699", endpoint) @@ -960,7 +964,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { test("isolated endpoints") { val latch = new CountDownLatch(1) val singleThreadedEnv = createRpcEnv( - new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) + createSparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) try { val blockingEndpoint = singleThreadedEnv .setupEndpoint("blocking", new IsolatedThreadSafeRpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index fe6d0db837bd..dcd40b6afd56 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -53,7 +53,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { } test("advertise address different from bind address") { - val sparkConf = new SparkConf() + val sparkConf = createSparkConf() val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0, new SecurityManager(sparkConf), 0, false) val env = new NettyRpcEnvFactory().create(config) @@ -95,7 +95,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { test("StackOverflowError should be sent back and Dispatcher should survive") { val numUsableCores = 2 - val conf = new SparkConf + val conf = createSparkConf() val config = RpcEnvConfig( conf, "test", @@ -150,7 +150,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { context.reply(msg) } }) - val conf = new SparkConf() + val conf = createSparkConf() val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely-server") @@ -180,3 +180,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { } } } + +class SslNettyRpcEnvSuite extends NettyRpcEnvSuite with MockitoSugar with TimeLimits { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala index 18c27ff12699..99f113ec16ac 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala @@ -53,9 +53,13 @@ class ShuffleBlockPusherSuite extends SparkFunSuite { private var conf: SparkConf = _ private var pushedBlocks = new ArrayBuffer[String] + def createSparkConf(): SparkConf = { + new SparkConf(loadDefaults = false) + } + override def beforeEach(): Unit = { super.beforeEach() - conf = new SparkConf(loadDefaults = false) + conf = createSparkConf() MockitoAnnotations.openMocks(this).close() when(dependency.shuffleId).thenReturn(0) when(dependency.partitioner).thenReturn(new HashPartitioner(8)) @@ -480,3 +484,9 @@ class ShuffleBlockPusherSuite extends SparkFunSuite { } } } + +class SslShuffleBlockPusherSuite extends ShuffleBlockPusherSuite { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 31b255cff728..8a9537b4f18d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -26,7 +26,7 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, SslTestUtils} import org.apache.spark.internal.config import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.storage._ @@ -37,8 +37,12 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite { @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + def createSparkConf(): SparkConf = { + new SparkConf(loadDefaults = false) + } + private var tempDir: File = _ - private val conf: SparkConf = new SparkConf(loadDefaults = false) + private val conf: SparkConf = createSparkConf() private val appId = "TESTAPP" override def beforeEach(): Unit = { @@ -275,3 +279,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite { assert(checksumsInMemory === checksumsFromFile) } } + +class SslIndexShuffleBlockResolverSuite extends IndexShuffleBlockResolverSuite { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 38a669bc8574..1fbc900727c4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { - val conf: SparkConf + val conf: SparkConf = createConf() + protected def createConf(): SparkConf protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null @@ -459,15 +460,21 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } class BlockManagerReplicationSuite extends BlockManagerReplicationBehavior { - val conf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + } } class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehavior { - val conf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") - conf.set(STORAGE_REPLICATION_PROACTIVE, true) - conf.set(STORAGE_EXCEPTION_PIN_LEAK, true) + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + .set(STORAGE_REPLICATION_PROACTIVE, true) + .set(STORAGE_EXCEPTION_PIN_LEAK, true) + } (2 to 5).foreach { i => test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { @@ -539,14 +546,17 @@ class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Log } class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { - val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") - conf.set( - STORAGE_REPLICATION_POLICY, - classOf[BasicBlockReplicationPolicy].getName) - conf.set( - STORAGE_REPLICATION_TOPOLOGY_MAPPER, - classOf[DummyTopologyMapper].getName) + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + .set( + STORAGE_REPLICATION_POLICY, + classOf[BasicBlockReplicationPolicy].getName) + .set( + STORAGE_REPLICATION_TOPOLOGY_MAPPER, + classOf[DummyTopologyMapper].getName) + } } // BlockReplicationPolicy to prioritize BlockManagers based on hostnames diff --git a/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala new file mode 100644 index 000000000000..760f31de0597 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.{SparkConf, SslTestUtils} + +class SslBlockManagerReplicationSuite extends BlockManagerReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} + +class SslBlockManagerProactiveReplicationSuite extends BlockManagerProactiveReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} + +class SslBlockManagerBasicStrategyReplicationSuite + extends BlockManagerBasicStrategyReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala b/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala new file mode 100644 index 000000000000..dd71a68f6250 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.network.ssl.SslSampleConfigs + +object SslTestUtils { + + /** + * Updates a SparkConf to contain SSL configurations + * + * @param conf The config to update + * @return The passed in SparkConf with SSL configurations added + */ + def updateWithSSLConfig(conf: SparkConf): SparkConf = { + SslSampleConfigs.createDefaultConfigMap().entrySet(). + forEach(entry => conf.set(entry.getKey, entry.getValue)) + conf + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala new file mode 100644 index 000000000000..322d6bfdb7cd --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn + +import org.apache.spark.network.ssl.SslSampleConfigs + +class SslYarnShuffleServiceWithRocksDBBackendSuite + extends YarnShuffleServiceWithRocksDBBackendSuite { + + /** + * Override to add "spark.ssl.rpc.*" configuration parameters... + */ + override def beforeEach(): Unit = { + super.beforeEach() + // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. + SslSampleConfigs.createDefaultConfigMap().entrySet(). + forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) + } +}