Commit eb9f065

Improve documentation and add end-to-end test at Spark-level
1 parent a6b95f1 commit eb9f065

File tree

3 files changed: +172 −7 lines changed

core/src/main/scala/org/apache/spark/SecurityManager.scala
core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 7 additions & 2 deletions
@@ -85,7 +85,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder
  * Authenticator installed in the SecurityManager to how it does the authentication
  * and in this case gets the user name and password from the request.
  *
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously
  * exchange messages. For this we use the Java SASL
  * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
  * as the authentication mechanism. This means the shared secret is not passed
@@ -99,7 +99,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder
  * of protection they want. If we support those, the messages will also have to
  * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
  *
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous messages passing, the SASL
  * authentication is a bit more complex. A ConnectionManager can be both a client
  * and a Server, so for a particular connection is has to determine what to do.
  * A ConnectionId was added to be able to track connections and is used to
@@ -108,6 +108,10 @@ import org.apache.spark.network.sasl.SecretKeyHolder
  * and waits for the response from the server and does the handshake before sending
  * the real message.
  *
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
  * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
  * can be used. Yarn requires a specific AmIpFilter be installed for security to work
  * properly. For non-Yarn deployments, users can write a filter to go through a
@@ -347,6 +351,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
 
   override def getSecretKey(appId: String): String = {
     val myAppId = sparkConf.getAppId
+    println("App id: " + appId + " / " + myAppId)
     require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
     getSecretKey()
   }
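
For context, the configuration surface this commit documents and tests is small. A minimal sketch of enabling SASL authentication for block transfers (the secret and app id values here are placeholders, and spark.app.id is normally assigned by the cluster manager rather than set by hand):

    import org.apache.spark.SparkConf

    // Minimal sketch: the three settings the new test suite below exercises.
    // "my-secret" and "my-app" are placeholder values.
    val conf = new SparkConf()
      .set("spark.authenticate", "true")             // enable SASL authentication
      .set("spark.authenticate.secret", "my-secret") // shared secret, must match on both ends
      .set("spark.app.id", "my-app")                 // validated in getSecretKey during the handshake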
core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+  test("security default off") {
+    testConnection(new SparkConf, new SparkConf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on same password") {
+    val conf = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    testConnection(conf, conf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on mismatch password") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("Mismatched response")
+    }
+  }
+
+  test("security mismatch auth off on server") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate", "false")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => // any funny error may occur, server will interpret SASL token as RPC
+    }
+  }
+
+  test("security mismatch auth off on client") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "false")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate", "true")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("Expected SaslMessage")
+    }
+  }
+
+  test("security mismatch app ids") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.app.id", "other-id")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
+    }
+  }
+
+  /**
+   * Creates two servers with different configurations and sees if they can talk.
+   * Returns Success() if they can transfer a block, and Failure() if the block transfer failed
+   * properly. We will throw an out-of-band exception if something other than that goes wrong.
+   */
+  private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
+    val blockManager = mock[BlockDataManager]
+    val blockId = ShuffleBlockId(0, 1, 2)
+    val blockString = "Hello, world!"
+    val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes))
+    when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
+
+    val securityManager0 = new SecurityManager(conf0)
+    val exec0 = new NettyBlockTransferService(conf0, securityManager0)
+    exec0.init(blockManager)
+
+    val securityManager1 = new SecurityManager(conf1)
+    val exec1 = new NettyBlockTransferService(conf1, securityManager1)
+    exec1.init(blockManager)
+
+    val result = fetchBlock(exec0, exec1, "1", blockId) match {
+      case Success(buf) =>
+        IOUtils.toString(buf.createInputStream()) should equal(blockString)
+        buf.release()
+        Success()
+      case Failure(t) =>
+        Failure(t)
+    }
+    exec0.close()
+    exec1.close()
+    result
+  }
+
+  /** Synchronously fetches a single block, acting as the given executor fetching from another. */
+  private def fetchBlock(
+      self: BlockTransferService,
+      from: BlockTransferService,
+      execId: String,
+      blockId: BlockId): Try[ManagedBuffer] = {
+
+    val promise = Promise[ManagedBuffer]()
+
+    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+      new BlockFetchingListener {
+        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+          promise.failure(exception)
+        }
+
+        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+          promise.success(data.retain())
+        }
+      })
+
+    Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+    promise.future.value.get
+  }
+}
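
Since fetchBlocks is callback-based, the suite bridges it to a synchronous result with a Promise. A minimal standalone sketch of that bridge pattern (the awaitCallback name and signature are illustrative, not part of the commit):

    import java.util.concurrent.TimeUnit
    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration.FiniteDuration
    import scala.util.Try

    // Sketch of the Promise/Await bridge used by fetchBlock above: an asynchronous
    // success/failure callback API is turned into a synchronous Try result.
    def awaitCallback[T](register: (T => Unit, Throwable => Unit) => Unit): Try[T] = {
      val promise = Promise[T]()
      register(t => promise.success(t), e => promise.failure(e))
      Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
      promise.future.value.get  // future is complete after Await.ready, so .value is Some(Try)
    }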

network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java

Lines changed: 4 additions & 5 deletions
@@ -31,6 +31,7 @@
 import java.util.Map;
 
 import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.io.BaseEncoding;
@@ -157,15 +158,13 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
 
   /* Encode a byte[] identifier as a Base64-encoded string. */
   public static String encodeIdentifier(String identifier) {
+    Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
     return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
   }
 
   /** Encode a password as a base64-encoded char[] array. */
   public static char[] encodePassword(String password) {
-    if (password != null) {
-      return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
-    } else {
-      return new char[0];
-    }
+    Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+    return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
   }
 }
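
Guava's Preconditions.checkNotNull throws a NullPointerException carrying the supplied message, so a missing user or secret now fails fast with a clear error instead of being silently encoded as an empty password. A small Scala illustration of that contract (assumes Guava on the classpath):

    import com.google.common.base.Preconditions

    val password: String = null
    try {
      Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled")
    } catch {
      // prints "Password cannot be null if SASL is enabled"
      case e: NullPointerException => println(e.getMessage)
    }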
