Skip to content

Commit 05e9a0c

Browse files
committed
Address
1 parent 195adc7 commit 05e9a0c

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
3737
import org.apache.spark.network.netty.SparkTransportConf
3838
import org.apache.spark.network.server._
3939
import org.apache.spark.rpc._
40-
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
40+
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
4141
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
4242

4343
private[netty] class NettyRpcEnv(
@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
253253
javaSerializerInstance.serialize(content)
254254
}
255255

256+
/**
257+
* Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
258+
*/
259+
private[netty] def serializeStream(out: OutputStream): SerializationStream = {
260+
javaSerializerInstance.serializeStream(out)
261+
}
262+
256263
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
257264
NettyRpcEnv.currentClient.withValue(client) {
258265
deserialize { () =>
@@ -530,19 +537,23 @@ private[netty] class NettyRpcEndpointRef(
530537
*/
531538
private[netty] class RequestMessage(
532539
val senderAddress: RpcAddress,
533-
val receiver: NettyRpcEndpointRef, val content: Any) {
540+
val receiver: NettyRpcEndpointRef,
541+
val content: Any) {
534542

535-
/** Manually serialize [[RequestMessage]] to minimize the size of bytes. */
543+
/** Manually serialize [[RequestMessage]] to minimize the size. */
536544
def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
537545
val bos = new ByteBufferOutputStream()
538546
val out = new DataOutputStream(bos)
539547
try {
540548
writeRpcAddress(out, senderAddress)
541549
writeRpcAddress(out, receiver.address)
542550
out.writeUTF(receiver.name)
543-
val contentBytes = nettyEnv.serialize(content)
544-
assert(contentBytes.hasArray)
545-
out.write(contentBytes.array, contentBytes.arrayOffset, contentBytes.remaining)
551+
val s = nettyEnv.serializeStream(out)
552+
try {
553+
s.writeObject(content)
554+
} finally {
555+
s.close()
556+
}
546557
} finally {
547558
out.close()
548559
}

core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
package org.apache.spark.rpc.netty
1919

20+
import org.scalatest.mock.MockitoSugar
21+
2022
import org.apache.spark._
23+
import org.apache.spark.network.client.TransportClient
2124
import org.apache.spark.rpc._
2225

23-
class NettyRpcEnvSuite extends RpcEnvSuite {
26+
class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
2427

2528
override def createRpcEnv(
2629
conf: SparkConf,
@@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite {
5356
}
5457
}
5558

59+
test("RequestMessage serialization") {
60+
def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = {
61+
assert(expected.senderAddress === actual.senderAddress)
62+
assert(expected.receiver === actual.receiver)
63+
assert(expected.content === actual.content)
64+
}
65+
66+
val nettyEnv = env.asInstanceOf[NettyRpcEnv]
67+
val client = mock[TransportClient]
68+
val senderAddress = RpcAddress("locahost", 12345)
69+
val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
70+
val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv)
71+
72+
val msg = new RequestMessage(senderAddress, receiver, "foo")
73+
assertRequestMessageEquals(
74+
msg,
75+
RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
76+
77+
val msg2 = new RequestMessage(null, receiver, "foo")
78+
assertRequestMessageEquals(
79+
msg2,
80+
RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
81+
82+
val msg3 = new RequestMessage(senderAddress, receiver, null)
83+
assertRequestMessageEquals(
84+
msg3,
85+
RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
86+
}
5687
}

0 commit comments

Comments
 (0)