Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ object SparkEnv extends Logging {

// NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

Expand Down
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
blockTransferService: BlockTransferService)
blockTransferService: BlockTransferService,
securityManager: SecurityManager)
extends BlockDataManager with Logging {

val diskBlockManager = new DiskBlockManager(this, conf)
Expand Down Expand Up @@ -115,7 +116,8 @@ private[spark] class BlockManager(
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTranserService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf))
new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager,
securityManager.isAuthenticationEnabled())
} else {
blockTransferService
}
Expand Down Expand Up @@ -166,9 +168,10 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
blockTransferService: BlockTransferService) = {
blockTransferService: BlockTransferService,
securityManager: SecurityManager) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
conf, mapOutputTracker, shuffleManager, blockTransferService)
conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)
}

/**
Expand Down Expand Up @@ -219,7 +222,6 @@ private[spark] class BlockManager(
return
} catch {
case e: Exception if i < MAX_ATTEMPTS =>
val attemptsRemaining =
logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}"
+ s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
Thread.sleep(SLEEP_TIME_SECS * 1000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
mapOutputTracker, shuffleManager, transfer, securityMgr)
store.initialize("app-id")
allStores += store
store
Expand Down Expand Up @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr)
failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
mapOutputTracker, shuffleManager, transfer, securityMgr)
manager.initialize("app-id")
manager
}
Expand Down Expand Up @@ -795,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
// Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr)
store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master,
new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer)
new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr)

// The put should fail since a1 is not serializable.
class UnserializableClass
Expand Down
1 change: 1 addition & 0 deletions network/shuffle/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>11.0.1</version> <!-- yarn 2.4.0's version -->
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually all supported Yarn versions (2.2+) use 11.0.2. I just verified this. Maybe we can even generalize the comment a little.

<scope>provided</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
logger.trace("SASL client callback: setting realm");
RealmCallback rc = (RealmCallback) callback;
rc.setText(rc.getDefaultText());
logger.info("Realm callback");
} else if (callback instanceof RealmChoiceCallback) {
// ignore (?)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.BaseEncoding;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.base64.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -159,12 +160,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
/* Encode a byte[] identifier as a Base64-encoded string. */
public static String encodeIdentifier(String identifier) {
Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8)))
.toString(Charsets.UTF_8);
}

/** Encode a password as a base64-encoded char[] array. */
public static char[] encodePassword(String password) {
Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8)))
.toString(Charsets.UTF_8).toCharArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

package org.apache.spark.network.shuffle;

import java.util.List;

import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
import org.apache.spark.network.util.JavaUtils;
Expand All @@ -37,18 +43,35 @@
public class ExternalShuffleClient extends ShuffleClient {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);

private final TransportClientFactory clientFactory;
private final TransportConf conf;
private final boolean saslEnabled;
private final SecretKeyHolder secretKeyHolder;

private TransportClientFactory clientFactory;
private String appId;

public ExternalShuffleClient(TransportConf conf) {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
this.clientFactory = context.createClientFactory();
/**
* Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled,
* then secretKeyHolder may be null.
*/
public ExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
boolean saslEnabled) {
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
this.saslEnabled = saslEnabled;
}

@Override
public void init(String appId) {
this.appId = appId;
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
if (saslEnabled) {
bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder));
}
clientFactory = context.createClientFactory(bootstraps);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro

final Semaphore requestsRemaining = new Semaphore(0);

ExternalShuffleClient client = new ExternalShuffleClient(conf);
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
Expand Down Expand Up @@ -267,7 +267,7 @@ public void testFetchNoServer() throws Exception {
}

private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
ExternalShuffleClient client = new ExternalShuffleClient(conf);
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.*;

import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ExternalShuffleSecuritySuite {

TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
TransportServer server;

@Before
public void beforeEach() {
RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(),
new TestSecretKeyHolder("my-app-id", "secret"));
TransportContext context = new TransportContext(conf, handler);
this.server = context.createServer();
}

@After
public void afterEach() {
if (server != null) {
server.close();
server = null;
}
}

@Test
public void testValid() {
validate("my-app-id", "secret");
}

@Test
public void testBadAppId() {
try {
validate("wrong-app-id", "secret");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
}
}

@Test
public void testBadSecret() {
try {
validate("my-app-id", "bad-secret");
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
}

/** Creates an ExternalShuffleClient and attempts to register with the server. */
private void validate(String appId, String secretKey) {
ExternalShuffleClient client =
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
client.init(appId);
// Registration either succeeds or throws an exception.
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
new ExecutorShuffleInfo(new String[0], 0, ""));
}

/** Provides a secret key holder which always returns the given secret key, for a single appId. */
static class TestSecretKeyHolder implements SecretKeyHolder {
private final String appId;
private final String secretKey;

TestSecretKeyHolder(String appId, String secretKey) {
this.appId = appId;
this.secretKey = secretKey;
}

@Override
public String getSaslUser(String appId) {
return "user";
}
@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need new line above this

public String getSecretKey(String appId) {
if (!appId.equals(this.appId)) {
throw new IllegalArgumentException("Wrong appId!");
}
return secretKey;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche

blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer,
blockManagerSize, conf, mapOutputTracker, shuffleManager,
new NioBlockTransferService(conf, securityMgr))
new NioBlockTransferService(conf, securityMgr), securityMgr)
blockManager.initialize("app-id")

tempDirectory = Files.createTempDir()
Expand Down