Commit 538f2a3

[SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages
This PR eliminates the network package's usage of the Java serializer and replaces it with Encodable, a lightweight binary protocol. Each message is preceded by a type id, which lets us evolve the protocol by adding new message types, or change the format entirely by switching to a reserved id (such as -1). This protocol has the advantage over Java serialization that messages are guaranteed to remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be worth adopting a more heavyweight serialization format such as Protobuf, Thrift, or Avro, but those all pull in several dependencies that are unnecessary at present.
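As a rough illustration of the type-id framing described above, the sketch below shows how such a protocol can encode and dispatch messages. This is a hedged sketch only: the class name FramingSketch and the type-id values are invented here, and the commit's real decoder lives in BlockTransferMessage.Decoder.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch of type-id framing: every message writes a one-byte
// type tag before its payload, so a decoder can dispatch on the tag, and new
// message types can be added later without breaking existing ones.
public class FramingSketch {
  // Illustrative type ids; the real protocol defines its own values.
  static final byte OPEN_BLOCKS = 0;
  static final byte UPLOAD_BLOCK = 1;

  static ByteBuf encodeOpenBlocks(String[] blockIds) {
    ByteBuf buf = Unpooled.buffer();
    buf.writeByte(OPEN_BLOCKS);      // type id comes first on the wire
    buf.writeInt(blockIds.length);   // then the message body
    for (String id : blockIds) {
      byte[] bytes = id.getBytes(StandardCharsets.UTF_8);
      buf.writeInt(bytes.length);
      buf.writeBytes(bytes);
    }
    return buf;
  }

  static void dispatch(ByteBuf buf) {
    byte type = buf.readByte();      // read the tag, then branch on it
    switch (type) {
      case OPEN_BLOCKS:  /* decode an OpenBlocks body */  break;
      case UPLOAD_BLOCK: /* decode an UploadBlock body */ break;
      default: throw new IllegalArgumentException("Unknown message type: " + type);
    }
  }
}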
1 parent f165b2b · commit 538f2a3

29 files changed: +642 additions, −282 deletions

core/src/main/scala/org/apache/spark/network/BlockTransferService.scala

Lines changed: 3 additions & 1 deletion

@@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
   def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit]
@@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
   def uploadBlockSync(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Unit = {
-    Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+    Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
   }
 }

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 12 additions & 19 deletions

@@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.ShuffleStreamHandle
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
-object NettyMessages {
-  /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
-  case class OpenBlocks(blockIds: Seq[BlockId])
-
-  /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
-  case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-}
-
 /**
  * Serves requests to open blocks by simply registering one chunk per block requested.
  * Handles opening and uploading arbitrary BlockManager blocks.
@@ -50,28 +42,29 @@ class NettyBlockRpcServer(
     blockManager: BlockDataManager)
   extends RpcHandler with Logging {
 
-  import NettyMessages._
-
   private val streamManager = new OneForOneStreamManager()
 
   override def receive(
       client: TransportClient,
       messageBytes: Array[Byte],
       responseContext: RpcResponseCallback): Unit = {
-    val ser = serializer.newInstance()
-    val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+    val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
     logTrace(s"Received request: $message")
 
     message match {
-      case OpenBlocks(blockIds) =>
-        val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+      case openBlocks: OpenBlocks =>
+        val blocks: Seq[ManagedBuffer] =
+          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
         val streamId = streamManager.registerStream(blocks.iterator)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(
-          ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
 
-      case UploadBlock(blockId, blockData, level) =>
-        blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+      case uploadBlock: UploadBlock =>
+        // StorageLevel is serialized as bytes using our JavaSerializer.
+        val level: StorageLevel =
+          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+        blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
         responseContext.onSuccess(new Array[Byte](0))
     }
   }

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 10 additions & 5 deletions

@@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
   private[this] var clientFactory: TransportClientFactory = _
+  private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
     val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
@@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
     transportContext = new TransportContext(transportConf, rpcHandler)
     clientFactory = transportContext.createClientFactory(bootstrap.toList)
     server = transportContext.createServer()
+    appId = conf.getAppId
     logInfo("Server created on " + server.getPort)
   }
 
@@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
     val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
       override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
         val client = clientFactory.createClient(host, port)
-        new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-          .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
       }
     }
 
@@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit] = {
     val result = Promise[Unit]()
     val client = clientFactory.createClient(hostname, port)
 
+    // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+    // using our binary protocol.
+    val levelBytes = serializer.newInstance().serialize(level).array()
+
     // Convert or copy nio buffer into array in order to serialize it.
     val nioBuffer = blockData.nioByteBuffer()
     val array = if (nioBuffer.hasArray) {
@@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       data
     }
 
-    val ser = serializer.newInstance()
-    client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
      new RpcResponseCallback {
        override def onSuccess(response: Array[Byte]): Unit = {
          logTrace(s"Successfully uploaded block $blockId")
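For intuition about what this path puts on the wire, here is a hypothetical encode method for an UploadBlock-like message, built only from the Encoders helpers added later in this commit. The UploadBlock class itself is not shown in this diff, so the class name UploadBlockSketch and the field order are assumptions.

import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;

// Hypothetical sketch of how an UploadBlock-like message could lay out its
// fields with the Encoders helpers; not the commit's actual UploadBlock class.
class UploadBlockSketch {
  final String appId, execId, blockId;
  final byte[] metadata;   // JavaSerializer-encoded StorageLevel, per the comment above
  final byte[] blockData;

  UploadBlockSketch(String appId, String execId, String blockId,
                    byte[] metadata, byte[] blockData) {
    this.appId = appId;
    this.execId = execId;
    this.blockId = blockId;
    this.metadata = metadata;
    this.blockData = blockData;
  }

  int encodedLength() {
    return Encoders.Strings.encodedLength(appId)
      + Encoders.Strings.encodedLength(execId)
      + Encoders.Strings.encodedLength(blockId)
      + Encoders.ByteArrays.encodedLength(metadata)
      + Encoders.ByteArrays.encodedLength(blockData);
  }

  void encode(ByteBuf buf) {
    Encoders.Strings.encode(buf, appId);
    Encoders.Strings.encode(buf, execId);
    Encoders.Strings.encode(buf, blockId);
    Encoders.ByteArrays.encode(buf, metadata);
    Encoders.ByteArrays.encode(buf, blockData);
  }
}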

core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala

Lines changed: 1 addition & 0 deletions

@@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel)

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 3 additions & 2 deletions

@@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
-import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.network.util.{ConfigProvider, TransportConf}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleManager
@@ -939,7 +940,7 @@ private[spark] class BlockManager(
         data.rewind()
         logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
         blockTransferService.uploadBlockSync(
-          peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+          peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
         logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
           .format(System.currentTimeMillis - onePeerStartTime))
         peersReplicatedTo += peer

network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java

Lines changed: 4 additions & 8 deletions

@@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     streamChunkId.encode(buf);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static ChunkFetchFailure decode(ByteBuf buf) {
     StreamChunkId streamChunkId = StreamChunkId.decode(buf);
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new ChunkFetchFailure(streamChunkId, errorString);
   }
 
   @Override
network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+  /** Strings are encoded with their length followed by UTF-8 bytes. */
+  public static class Strings {
+    public static int encodedLength(String s) {
+      return 4 + s.getBytes(Charsets.UTF_8).length;
+    }
+
+    public static void encode(ByteBuf buf, String s) {
+      byte[] bytes = s.getBytes(Charsets.UTF_8);
+      buf.writeInt(bytes.length);
+      buf.writeBytes(bytes);
+    }
+
+    public static String decode(ByteBuf buf) {
+      int length = buf.readInt();
+      byte[] bytes = new byte[length];
+      buf.readBytes(bytes);
+      return new String(bytes, Charsets.UTF_8);
+    }
+  }
+
+  /** Byte arrays are encoded with their length followed by bytes. */
+  public static class ByteArrays {
+    public static int encodedLength(byte[] arr) {
+      return 4 + arr.length;
+    }
+
+    public static void encode(ByteBuf buf, byte[] arr) {
+      buf.writeInt(arr.length);
+      buf.writeBytes(arr);
+    }
+
+    public static byte[] decode(ByteBuf buf) {
+      int length = buf.readInt();
+      byte[] bytes = new byte[length];
+      buf.readBytes(bytes);
+      return bytes;
+    }
+  }
+
+  /** String arrays are encoded with the number of strings followed by per-String encoding. */
+  public static class StringArrays {
+    public static int encodedLength(String[] strings) {
+      int totalLength = 4;
+      for (String s : strings) {
+        totalLength += Strings.encodedLength(s);
+      }
+      return totalLength;
+    }
+
+    public static void encode(ByteBuf buf, String[] strings) {
+      buf.writeInt(strings.length);
+      for (String s : strings) {
+        Strings.encode(buf, s);
+      }
+    }
+
+    public static String[] decode(ByteBuf buf) {
+      int numStrings = buf.readInt();
+      String[] strings = new String[numStrings];
+      for (int i = 0; i < strings.length; i ++) {
+        strings[i] = Strings.decode(buf);
+      }
+      return strings;
+    }
+  }
+}
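A quick round-trip through the new helpers (a usage sketch written for this writeup, not part of the commit; the class name EncodersRoundTrip is invented):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

// Usage sketch: size a buffer with encodedLength, encode a String[], decode it back.
public class EncodersRoundTrip {
  public static void main(String[] args) {
    String[] blockIds = { "shuffle_0_0_0", "shuffle_0_1_0" };

    // encodedLength gives the exact byte count, so the buffer never reallocates.
    ByteBuf buf = Unpooled.buffer(Encoders.StringArrays.encodedLength(blockIds));
    Encoders.StringArrays.encode(buf, blockIds);

    // Decoding reads back the same values in order.
    String[] decoded = Encoders.StringArrays.decode(buf);
    System.out.println(decoded.length + " ids, first = " + decoded[0]);  // 2 ids, first = shuffle_0_0_0
  }
}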

network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java

Lines changed: 4 additions & 8 deletions

@@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return 8 + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static RpcFailure decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new RpcFailure(requestId, errorString);
  }
 
   @Override

network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java

Lines changed: 3 additions & 6 deletions

@@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + message.length;
+    return 8 + Encoders.ByteArrays.encodedLength(message);
  }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    buf.writeInt(message.length);
-    buf.writeBytes(message);
+    Encoders.ByteArrays.encode(buf, message);
   }
 
   public static RpcRequest decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int messageLen = buf.readInt();
-    byte[] message = new byte[messageLen];
-    buf.readBytes(message);
+    byte[] message = Encoders.ByteArrays.decode(buf);
     return new RpcRequest(requestId, message);
   }
network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java

Lines changed: 3 additions & 6 deletions

@@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) {
   public Type type() { return Type.RpcResponse; }
 
   @Override
-  public int encodedLength() { return 8 + 4 + response.length; }
+  public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    buf.writeInt(response.length);
-    buf.writeBytes(response);
+    Encoders.ByteArrays.encode(buf, response);
   }
 
   public static RpcResponse decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int responseLen = buf.readInt();
-    byte[] response = new byte[responseLen];
-    buf.readBytes(response);
+    byte[] response = Encoders.ByteArrays.decode(buf);
     return new RpcResponse(requestId, response);
   }