Skip to content

Commit 60c2edf

Browse files
committed
Beefed up server/client integration tests.
1 parent 38d88d5 commit 60c2edf

File tree

4 files changed

+54
-68
lines changed

4 files changed

+54
-68
lines changed

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
3636
override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
3737
val totalLen = in.readInt()
3838
val blockIdLen = in.readInt()
39-
val blockIdBytes = new Array[Byte](blockIdLen)
39+
val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
4040
in.readBytes(blockIdBytes)
4141
val blockId = new String(blockIdBytes)
42-
val blockLen = math.abs(totalLen) - blockIdLen - 4
42+
val blockLen = totalLen - math.abs(blockIdLen) - 4
4343

4444
def server = ctx.channel.remoteAddress.toString
4545

46-
// totalLen is negative when it is an error message.
47-
if (totalLen < 0) {
46+
// blockIdLen is negative when it is an error message.
47+
if (blockIdLen < 0) {
4848
val errorMessageBytes = new Array[Byte](blockLen)
4949
in.readBytes(errorMessageBytes)
5050
val errorMsg = new String(errorMessageBytes)

core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
3333
msg.error match {
3434
case Some(errorMsg) =>
3535
val errorBytes = errorMsg.getBytes
36-
out.writeInt(-(4 + blockId.length + errorBytes.size))
37-
out.writeInt(blockId.length)
36+
out.writeInt(4 + blockId.length + errorBytes.size)
37+
out.writeInt(-blockId.length)
3838
out.writeBytes(blockId)
3939
out.writeBytes(errorBytes)
4040
case None =>

core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ import org.apache.spark.util.Utils
5454
* frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
5555
*
5656
* frame-length should not include the length of itself.
57-
* If frame-length is negative, then this is an error message rather than block-data. The real
57+
* If block-id-length is negative, then this is an error message rather than block-data. The real
5858
* length is the absolute value of the frame-length.
5959
*
6060
*/

core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala

Lines changed: 47 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
7575
clientFactory.stop()
7676
}
7777

78+
/** A ByteBuf for buffer_block */
7879
lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
7980

81+
/** A ByteBuf for file_block */
8082
lazy val fileBlockReference = {
8183
val bytes = new Array[Byte](testFile.length.toInt)
8284
val fp = new RandomAccessFile(testFile, "r")
@@ -85,84 +87,68 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
8587
Unpooled.wrappedBuffer(bytes, 10, testFile.length.toInt - 25)
8688
}
8789

88-
test("fetch a ByteBuffer block") {
90+
def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
91+
{
8992
val client = clientFactory.createClient(server.hostName, server.port)
9093
val sem = new Semaphore(0)
91-
var receivedBlockId: String = null
92-
var receivedBuffer = null.asInstanceOf[ReferenceCountedBuffer]
94+
val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
95+
val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
96+
val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
9397

9498
client.fetchBlocks(
95-
Seq(bufferBlockId),
99+
blockIds,
96100
(blockId, buf) => {
97-
receivedBlockId = blockId
101+
receivedBlockIds.add(blockId)
98102
buf.retain()
99-
receivedBuffer = buf
103+
receivedBuffers.add(buf)
100104
sem.release()
101105
},
102-
(blockId, errorMsg) => sem.release()
106+
(blockId, errorMsg) => {
107+
errorBlockIds.add(blockId)
108+
sem.release()
109+
}
103110
)
104-
105-
// This should block until the blocks are fetched
106-
sem.acquire()
107-
108-
assert(receivedBlockId === bufferBlockId)
109-
assert(receivedBuffer.underlying == byteBufferBlockReference)
110-
receivedBuffer.release()
111+
sem.acquire(blockIds.size)
111112
client.close()
113+
(receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
112114
}
113115

114-
test("fetch a FileSegment block via zero-copy send") {
115-
val client = clientFactory.createClient(server.hostName, server.port)
116-
val sem = new Semaphore(0)
117-
var receivedBlockId: String = null
118-
var receivedBuffer = null.asInstanceOf[ReferenceCountedBuffer]
119-
120-
client.fetchBlocks(
121-
Seq(fileBlockId),
122-
(blockId, buf) => {
123-
receivedBlockId = blockId
124-
buf.retain()
125-
receivedBuffer = buf
126-
sem.release()
127-
},
128-
(blockId, errorMsg) => sem.release()
129-
)
116+
test("fetch a ByteBuffer block") {
117+
val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
118+
assert(blockIds === Set(bufferBlockId))
119+
assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
120+
assert(failBlockIds.isEmpty)
121+
buffers.foreach(_.release())
122+
}
130123

131-
// This should block until the blocks are fetched
132-
sem.acquire()
124+
test("fetch a FileSegment block via zero-copy send") {
125+
val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
126+
assert(blockIds === Set(fileBlockId))
127+
assert(buffers.map(_.underlying) === Set(fileBlockReference))
128+
assert(failBlockIds.isEmpty)
129+
buffers.foreach(_.release())
130+
}
133131

134-
assert(receivedBlockId === fileBlockId)
135-
assert(receivedBuffer.underlying == fileBlockReference)
136-
receivedBuffer.release()
137-
client.close()
132+
test("fetch a non-existent block") {
133+
val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
134+
assert(blockIds.isEmpty)
135+
assert(buffers.isEmpty)
136+
assert(failBlockIds === Set("random-block"))
138137
}
139138

140139
test("fetch both ByteBuffer block and FileSegment block") {
141-
val client = clientFactory.createClient(server.hostName, server.port)
142-
val sem = new Semaphore(0)
143-
val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
144-
val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
145-
146-
client.fetchBlocks(
147-
Seq(bufferBlockId, fileBlockId),
148-
(blockId, buf) => {
149-
receivedBlockIds.add(blockId)
150-
buf.retain()
151-
receivedBuffers.add(buf)
152-
sem.release()
153-
},
154-
(blockId, errorMsg) => sem.release()
155-
)
156-
157-
sem.acquire(2)
158-
assert(receivedBlockIds.contains(bufferBlockId))
159-
assert(receivedBlockIds.contains(fileBlockId))
160-
161-
val byteBufferReference = byteBufferBlockReference
162-
val fileReference = fileBlockReference
140+
val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
141+
assert(blockIds === Set(bufferBlockId, fileBlockId))
142+
assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
143+
assert(failBlockIds.isEmpty)
144+
buffers.foreach(_.release())
145+
}
163146

164-
assert(receivedBuffers.map(_.underlying) === Set(byteBufferReference, fileReference))
165-
receivedBuffers.foreach(_.release())
166-
client.close()
147+
test("fetch both ByteBuffer block and a non-existent block") {
148+
val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
149+
assert(blockIds === Set(bufferBlockId))
150+
assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
151+
assert(failBlockIds === Set("random-block"))
152+
buffers.foreach(_.release())
167153
}
168154
}

0 commit comments

Comments
 (0)