Skip to content

Commit a6e9bbc

Browse files
zsxwingcmonkey
authored andcommitted
[SPARK-19365][CORE] Optimize RequestMessage serialization
## What changes were proposed in this pull request? Right now Netty PRC serializes `RequestMessage` using Java serialization, and the size of a single message (e.g., RequestMessage(..., "hello")`) is almost 1KB. This PR optimizes it by serializing `RequestMessage` manually (eliminate unnecessary information from most messages, e.g., class names of `RequestMessage`, `NettyRpcEndpointRef`, ...), and reduces the above message size to 100+ bytes. ## How was this patch tested? Jenkins I did a simple test to measure the improvement: Before ``` $ bin/spark-shell --master local-cluster[1,4,1024] ... scala> for (i <- 1 to 10) { | val start = System.nanoTime | val s = sc.parallelize(1 to 1000000, 10 * 1000).count() | val end = System.nanoTime | println(s"$i\t" + ((end - start)/1000/1000)) | } 1 6830 2 4353 3 3322 4 3107 5 3235 6 3139 7 3156 8 3166 9 3091 10 3029 ``` After: ``` $ bin/spark-shell --master local-cluster[1,4,1024] ... scala> for (i <- 1 to 10) { | val start = System.nanoTime | val s = sc.parallelize(1 to 1000000, 10 * 1000).count() | val end = System.nanoTime | println(s"$i\t" + ((end - start)/1000/1000)) | } 1 6431 2 3643 3 2913 4 2679 5 2760 6 2710 7 2747 8 2793 9 2679 10 2651 ``` I also captured the TCP packets for this test. Before this patch, the total size of TCP packets is ~1.5GB. After it, it reduces to ~1.2GB. Author: Shixiong Zhu <[email protected]> Closes apache#16706 from zsxwing/rpc-opt.
1 parent b5df92d commit a6e9bbc

File tree

4 files changed

+132
-27
lines changed

4 files changed

+132
-27
lines changed

core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ import org.apache.spark.SparkException
2525
* The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
2626
* connection and can only be reached via the client that sent the endpoint reference.
2727
*
28-
* @param rpcAddress The socket address of the endpoint.
28+
* @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to
29+
* an endpoint in a client `NettyRpcEnv`.
2930
* @param name Name of the endpoint.
3031
*/
31-
private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
32+
private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {
3233

3334
require(name != null, "RpcEndpoint name must be provided.")
3435

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 96 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
3737
import org.apache.spark.network.netty.SparkTransportConf
3838
import org.apache.spark.network.server._
3939
import org.apache.spark.rpc._
40-
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
41-
import org.apache.spark.util.{ThreadUtils, Utils}
40+
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
41+
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
4242

4343
private[netty] class NettyRpcEnv(
4444
val conf: SparkConf,
@@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv(
189189
}
190190
} else {
191191
// Message to a remote RPC endpoint.
192-
postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
192+
postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
193193
}
194194
}
195195

@@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv(
224224
}(ThreadUtils.sameThread)
225225
dispatcher.postLocalMessage(message, p)
226226
} else {
227-
val rpcMessage = RpcOutboxMessage(serialize(message),
227+
val rpcMessage = RpcOutboxMessage(message.serialize(this),
228228
onFailure,
229229
(client, response) => onSuccess(deserialize[Any](client, response)))
230230
postToOutbox(message.receiver, rpcMessage)
@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
253253
javaSerializerInstance.serialize(content)
254254
}
255255

256+
/**
257+
* Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
258+
*/
259+
private[netty] def serializeStream(out: OutputStream): SerializationStream = {
260+
javaSerializerInstance.serializeStream(out)
261+
}
262+
256263
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
257264
NettyRpcEnv.currentClient.withValue(client) {
258265
deserialize { () =>
@@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
480487
*/
481488
private[netty] class NettyRpcEndpointRef(
482489
@transient private val conf: SparkConf,
483-
endpointAddress: RpcEndpointAddress,
484-
@transient @volatile private var nettyEnv: NettyRpcEnv)
485-
extends RpcEndpointRef(conf) with Serializable with Logging {
490+
private val endpointAddress: RpcEndpointAddress,
491+
@transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
486492

487493
@transient @volatile var client: TransportClient = _
488494

489-
private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
490-
private val _name = endpointAddress.name
491-
492-
override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
495+
override def address: RpcAddress =
496+
if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null
493497

494498
private def readObject(in: ObjectInputStream): Unit = {
495499
in.defaultReadObject()
@@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef(
501505
out.defaultWriteObject()
502506
}
503507

504-
override def name: String = _name
508+
override def name: String = endpointAddress.name
505509

506510
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
507-
nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
511+
nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
508512
}
509513

510514
override def send(message: Any): Unit = {
511515
require(message != null, "Message is null")
512-
nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
516+
nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
513517
}
514518

515-
override def toString: String = s"NettyRpcEndpointRef(${_address})"
516-
517-
def toURI: URI = new URI(_address.toString)
519+
override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"
518520

519521
final override def equals(that: Any): Boolean = that match {
520-
case other: NettyRpcEndpointRef => _address == other._address
522+
case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
521523
case _ => false
522524
}
523525

524-
final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
526+
final override def hashCode(): Int =
527+
if (endpointAddress == null) 0 else endpointAddress.hashCode()
525528
}
526529

527530
/**
528531
* The message that is sent from the sender to the receiver.
532+
*
533+
* @param senderAddress the sender address. It's `null` if this message is from a client
534+
* `NettyRpcEnv`.
535+
* @param receiver the receiver of this message.
536+
* @param content the message content.
529537
*/
530-
private[netty] case class RequestMessage(
531-
senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
538+
private[netty] class RequestMessage(
539+
val senderAddress: RpcAddress,
540+
val receiver: NettyRpcEndpointRef,
541+
val content: Any) {
542+
543+
/** Manually serialize [[RequestMessage]] to minimize the size. */
544+
def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
545+
val bos = new ByteBufferOutputStream()
546+
val out = new DataOutputStream(bos)
547+
try {
548+
writeRpcAddress(out, senderAddress)
549+
writeRpcAddress(out, receiver.address)
550+
out.writeUTF(receiver.name)
551+
val s = nettyEnv.serializeStream(out)
552+
try {
553+
s.writeObject(content)
554+
} finally {
555+
s.close()
556+
}
557+
} finally {
558+
out.close()
559+
}
560+
bos.toByteBuffer
561+
}
562+
563+
private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
564+
if (rpcAddress == null) {
565+
out.writeBoolean(false)
566+
} else {
567+
out.writeBoolean(true)
568+
out.writeUTF(rpcAddress.host)
569+
out.writeInt(rpcAddress.port)
570+
}
571+
}
572+
573+
override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
574+
}
575+
576+
private[netty] object RequestMessage {
577+
578+
private def readRpcAddress(in: DataInputStream): RpcAddress = {
579+
val hasRpcAddress = in.readBoolean()
580+
if (hasRpcAddress) {
581+
RpcAddress(in.readUTF(), in.readInt())
582+
} else {
583+
null
584+
}
585+
}
586+
587+
def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
588+
val bis = new ByteBufferInputStream(bytes)
589+
val in = new DataInputStream(bis)
590+
try {
591+
val senderAddress = readRpcAddress(in)
592+
val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
593+
val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
594+
ref.client = client
595+
new RequestMessage(
596+
senderAddress,
597+
ref,
598+
// The remaining bytes in `bytes` are the message content.
599+
nettyEnv.deserialize(client, bytes))
600+
} finally {
601+
in.close()
602+
}
603+
}
604+
}
532605

533606
/**
534607
* A response that indicates some failure happens in the receiver side.
@@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler(
574647
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
575648
assert(addr != null)
576649
val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
577-
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
650+
val requestMessage = RequestMessage(nettyEnv, client, message)
578651
if (requestMessage.senderAddress == null) {
579652
// Create a new message with the socket address of the client as the sender.
580-
RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
653+
new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
581654
} else {
582655
// The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
583656
// the listening address

core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
package org.apache.spark.rpc.netty
1919

20+
import org.scalatest.mock.MockitoSugar
21+
2022
import org.apache.spark._
23+
import org.apache.spark.network.client.TransportClient
2124
import org.apache.spark.rpc._
2225

23-
class NettyRpcEnvSuite extends RpcEnvSuite {
26+
class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
2427

2528
override def createRpcEnv(
2629
conf: SparkConf,
@@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite {
5356
}
5457
}
5558

59+
test("RequestMessage serialization") {
60+
def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = {
61+
assert(expected.senderAddress === actual.senderAddress)
62+
assert(expected.receiver === actual.receiver)
63+
assert(expected.content === actual.content)
64+
}
65+
66+
val nettyEnv = env.asInstanceOf[NettyRpcEnv]
67+
val client = mock[TransportClient]
68+
val senderAddress = RpcAddress("locahost", 12345)
69+
val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
70+
val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv)
71+
72+
val msg = new RequestMessage(senderAddress, receiver, "foo")
73+
assertRequestMessageEquals(
74+
msg,
75+
RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
76+
77+
val msg2 = new RequestMessage(null, receiver, "foo")
78+
assertRequestMessageEquals(
79+
msg2,
80+
RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
81+
82+
val msg3 = new RequestMessage(senderAddress, receiver, null)
83+
assertRequestMessageEquals(
84+
msg3,
85+
RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
86+
}
5687
}

core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
3434
val env = mock(classOf[NettyRpcEnv])
3535
val sm = mock(classOf[StreamManager])
3636
when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
37-
.thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
37+
.thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))
3838

3939
test("receive") {
4040
val dispatcher = mock(classOf[Dispatcher])

0 commit comments

Comments
 (0)