From e2b8c4a2595c3196821ffed582ceea487d0d65d4 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 18 Jul 2014 18:01:35 -0700 Subject: [PATCH 1/7] Modify to propagete error using ConnectionManager --- .../apache/spark/network/BufferMessage.scala | 6 ++-- .../spark/network/ConnectionManager.scala | 1 - .../org/apache/spark/network/Message.scala | 2 ++ .../spark/network/MessageChunkHeader.scala | 19 ++++++++++-- .../spark/storage/BlockFetcherIterator.scala | 29 ++++++++++++------- .../spark/storage/BlockManagerWorker.scala | 8 +++-- 6 files changed, 46 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 04df2f3b0d69..09541e721fe8 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -88,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 8a1cdb812962..9c5e5816356c 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -674,7 +674,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id } } - sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 7caccfdbb44f..04ea50f62918 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var startTime = -1L var finishTime = -1L var isSecurityNeg = false + var hasError = false def size: Int @@ -87,6 +88,7 @@ private[spark] object Message { case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } + newMessage.hasError = header.hasError newMessage.senderAddress = header.address newMessage } diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index ead663ede7a1..a9d0be082464 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val hasError: Boolean, val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { @@ -41,6 +42,13 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + put{ + if (hasError) { + 1.asInstanceOf[Byte] + } else { + 0.asInstanceOf[Byte] + } + }. putInt(securityNeg). putInt(ip.size). put(ip). @@ -56,7 +64,7 @@ private[spark] class MessageChunkHeader( private[spark] object MessageChunkHeader { - val HEADER_SIZE = 44 + val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -67,13 +75,20 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val hasError = { + if (buffer.get() == 0) { + false + } else { + true + } + } val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 408a79708805..a928ca2c07ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -122,18 +122,25 @@ object BlockFetcherIterator { future.onSuccess { case Some(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) + if (bufferMessage.hasError) { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } else { + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += networkSize + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } case None => { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index c7766a3a6567..419995eb9c47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -44,8 +44,12 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message", e) - None + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } } } case otherMessage: Any => { From 4117b8fb0da568da1c6c41a06bab8004ef4dc422 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 20 Jul 2014 02:00:22 -0700 Subject: [PATCH 2/7] Modified ConnectionManager to be alble to handle error during processing message --- .../spark/network/ConnectionManager.scala | 45 ++++++++++++------- .../spark/network/MessageChunkHeader.scala | 16 +------ .../spark/storage/BlockManagerWorker.scala | 4 +- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9c5e5816356c..ba78d470ec16 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -657,26 +657,37 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, sentMessageStatus.markDone() } } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } + var ackMessage : Option[Message] = None + try { + ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + + ackMessage.get.getClass) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } } + } catch { + case e: Exception => { + logError(s"Exception was thrown during processing message", e) + val m = Message.createBufferMessage(bufferMessage.id) + m.hasError = true + ackMessage = Some(m) + } + } finally { + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) } - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) } } case _ => throw new Exception("Unknown type message received") diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index a9d0be082464..f3ecca5f992e 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -42,13 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). - put{ - if (hasError) { - 1.asInstanceOf[Byte] - } else { - 0.asInstanceOf[Byte] - } - }. + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). putInt(securityNeg). putInt(ip.size). put(ip). @@ -75,13 +69,7 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() - val hasError = { - if (buffer.get() == 0) { - false - } else { - true - } - } + val hasError = buffer.get() != 0 val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 419995eb9c47..9c459cd3a343 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -54,7 +54,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - None + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) } } } From 12d3de84a4d9d6d47df92059bc41965723bf02f7 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 20 Jul 2014 11:02:24 -0700 Subject: [PATCH 3/7] Added BlockFetcherIteratorSuite.scala --- .../storage/BlockFetcherIteratorSuite.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala new file mode 100644 index 000000000000..4cc81572693a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.{FunSuite, Matchers} + +import org.mockito.Mockito.{mock, when} +import org.mockito.Matchers.any + +import java.nio.ByteBuffer + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark._ +import org.apache.spark.storage.BlockFetcherIterator._ +import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, + Message} + +class BlockFetcherIteratorSuite extends FunSuite with Matchers { + + test("block fetch from remote fails using BasicBlockFetcherIterator") { + val conf = new SparkConf + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + val message = Message.createBufferMessage(0) + message.hasError = true + val someMessage = Some(message) + + val f = future { + someMessage + } + when(blockManager.connectionManager).thenReturn(connManager) + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val dummyBlId1 = ShuffleBlockId(0,0,0) + val dummyBlId2 = ShuffleBlockId(0,1,0) + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((dummyBlId1, 1))), + (bmId, Seq((dummyBlId2, 1))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (!r.isDefined) should be(true) + } + } + } + +} From 281589c5018b0cadbebc6f8cfac8a831dd40edeb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 22 Jul 2014 22:55:35 -0700 Subject: [PATCH 4/7] Add a test case to BlockFetcherIteratorSuite.scala for fetching block from remote from successfully --- .../storage/BlockFetcherIteratorSuite.scala | 73 ++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 4cc81572693a..6adb9f51436d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -24,8 +24,10 @@ import org.mockito.Matchers.any import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global + import org.apache.spark._ import org.apache.spark.storage.BlockFetcherIterator._ import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, @@ -34,30 +36,29 @@ import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, class BlockFetcherIteratorSuite extends FunSuite with Matchers { test("block fetch from remote fails using BasicBlockFetcherIterator") { - val conf = new SparkConf val blockManager = mock(classOf[BlockManager]) val connManager = mock(classOf[ConnectionManager]) - val message = Message.createBufferMessage(0) - message.hasError = true - val someMessage = Some(message) + when(blockManager.connectionManager).thenReturn(connManager) val f = future { + val message = Message.createBufferMessage(0) + message.hasError = true + val someMessage = Some(message) someMessage } - when(blockManager.connectionManager).thenReturn(connManager) when(connManager.sendMessageReliably(any(), any())).thenReturn(f) when(blockManager.futureExecContext).thenReturn(global) + when(blockManager.blockManagerId).thenReturn( BlockManagerId("test-client", "test-client", 1, 0)) when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - val dummyBlId1 = ShuffleBlockId(0,0,0) - val dummyBlId2 = ShuffleBlockId(0,1,0) + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) val bmId = BlockManagerId("test-server", "test-server",1 , 0) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((dummyBlId1, 1))), - (bmId, Seq((dummyBlId2, 1))) + (bmId, Seq((blId1, 1), (blId2, 1))) ) val iterator = new BasicBlockFetcherIterator(blockManager, @@ -71,4 +72,58 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { } } + test("block fetch from remote succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) + val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) + val blockMessageArray = new BlockMessageArray( + Seq(blockMessage1, blockMessage2)) + + val bufferMessage = blockMessageArray.toBufferMessage + val buffer = ByteBuffer.allocate(bufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + bufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip + arrayBuffer += buffer + + val someMessage = Some(Message.createBufferMessage(arrayBuffer)) + + val f = future { + someMessage + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1), (blId2, 1))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (r.isDefined) should be(true) + } + } + } } From 22d7ebde6d706e4336154af5bec87f36d1501b20 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 23 Jul 2014 17:45:01 -0700 Subject: [PATCH 5/7] Add test cases to BlockManagerSuite for SPARK-2583 --- .../storage/BlockFetcherIteratorSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 112 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 6adb9f51436d..db37efbf70fe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -96,7 +96,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { bufferMessage.buffers.foreach{ b => buffer.put(b) } - buffer.flip + buffer.flip() arrayBuffer += buffer val someMessage = Some(Message.createBufferMessage(arrayBuffer)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 23cb6905bfde..f7f1ded1d4a6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,7 +24,10 @@ import akka.actor._ import org.apache.spark.SparkConf import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.Mockito.{doAnswer, mock, spy, when} +import org.mockito.stubbing.Answer +import org.mockito.Matchers.any import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,8 +36,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.{BufferMessage, ConnectionManagerId, Message} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.storage.{BlockMessage, BlockMessageArray} +import org.apache.spark.storage.BlockMessage._ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import scala.collection.mutable.ArrayBuffer @@ -1015,4 +1021,108 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + + test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + val reqBufferMessage = reqBlockMessages.toBufferMessage + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + throw new Exception + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + // Test when exception was thrown during processing block messages + var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + + assert(ackMessage.isDefined, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When Exception was thown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should have error") + + val notBufferMessage = mock(classOf[Message]) + + // Test when not BufferMessage was received + ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should have error") + } + + test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + + val tmpBufferMessage = reqBlockMessages.toBufferMessage + val buffer = ByteBuffer.allocate(tmpBufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + tmpBufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + val reqBufferMessage = Message.createBufferMessage(arrayBuffer) + + // setup ack block messages + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) + val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( + reqBlockMessage1)) { + return Some(ackBlockMessage1) + } else { + return Some(ackBlockMessage2) + } + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should be defined") + assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should not have error") + } + } From 326a17f05e078f81486a9c64361b40eb991d5e01 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 24 Jul 2014 00:21:42 -0700 Subject: [PATCH 6/7] Add test cases to ConnectionManagerSuite.scala for SPARK-2583 --- .../network/ConnectionManagerSuite.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 415ad8c432c1..162d0a016c13 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -223,7 +223,31 @@ class ConnectionManagerSuite extends FunSuite { managerServer.stop() } + test("Ack error message") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + val managerServer = new ConnectionManager(0, conf, securityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + throw new Exception + }) + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + val message = Await.result(future, 1 second) + assert(message.isDefined) + assert(message.get.hasError) + + manager.stop() + managerServer.stop() + + } } From ee91bb793666b260c9018e532dcf32c7115b3194 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 30 Jul 2014 04:57:42 +0900 Subject: [PATCH 7/7] Modified BufferMessage.scala to keep the spark code style --- .../main/scala/org/apache/spark/network/BufferMessage.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 09541e721fe8..af35f1fc3e45 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, + hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) }