Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId,
hasError, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,27 +660,37 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
sentMessageStatus.markDone()
}
} else {
val ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}
var ackMessage : Option[Message] = None
try {
ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}

if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
} catch {
case e: Exception => {
logError(s"Exception was thrown during processing message", e)
val m = Message.createBufferMessage(bufferMessage.id)
m.hasError = true
ackMessage = Some(m)
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}

sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
}
case _ => throw new Exception("Unknown type message received")
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/network/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var startTime = -1L
var finishTime = -1L
var isSecurityNeg = false
var hasError = false

def size: Int

Expand Down Expand Up @@ -87,6 +88,7 @@ private[spark] object Message {
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.hasError = header.hasError
newMessage.senderAddress = header.address
newMessage
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val hasError: Boolean,
val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
Expand All @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
putInt(securityNeg).
putInt(ip.size).
put(ip).
Expand All @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(


private[spark] object MessageChunkHeader {
val HEADER_SIZE = 44
val HEADER_SIZE = 45

def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
Expand All @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val hasError = buffer.get() != 0
val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
new InetSocketAddress(ip, port))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,25 @@ object BlockFetcherIterator {
future.onSuccess {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
if (bufferMessage.hasError) {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
} else {
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message", e)
None
case e: Exception => {
logError("Exception handling buffer message", e)
val errorMessage = Message.createBufferMessage(msg.id)
errorMessage.hasError = true
Some(errorMessage)
}
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
None
val errorMessage = Message.createBufferMessage(msg.id)
errorMessage.hasError = true
Some(errorMessage)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,31 @@ class ConnectionManagerSuite extends FunSuite {
managerServer.stop()
}

test("Ack error message") {
val conf = new SparkConf
conf.set("spark.authenticate", "false")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
val managerServer = new ConnectionManager(0, conf, securityManager)
managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
throw new Exception
})

val size = 10 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val bufferMessage = Message.createBufferMessage(buffer)

val future = manager.sendMessageReliably(managerServer.id, bufferMessage)

val message = Await.result(future, 1 second)
assert(message.isDefined)
assert(message.get.hasError)

manager.stop()
managerServer.stop()

}

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.storage

import org.scalatest.{FunSuite, Matchers}

import org.mockito.Mockito.{mock, when}
import org.mockito.Matchers.any

import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark._
import org.apache.spark.storage.BlockFetcherIterator._
import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
Message}

class BlockFetcherIteratorSuite extends FunSuite with Matchers {

test("block fetch from remote fails using BasicBlockFetcherIterator") {
val blockManager = mock(classOf[BlockManager])
val connManager = mock(classOf[ConnectionManager])
when(blockManager.connectionManager).thenReturn(connManager)

val f = future {
val message = Message.createBufferMessage(0)
message.hasError = true
val someMessage = Some(message)
someMessage
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
when(blockManager.futureExecContext).thenReturn(global)

when(blockManager.blockManagerId).thenReturn(
BlockManagerId("test-client", "test-client", 1, 0))
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)

val blId1 = ShuffleBlockId(0,0,0)
val blId2 = ShuffleBlockId(0,1,0)
val bmId = BlockManagerId("test-server", "test-server",1 , 0)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, Seq((blId1, 1), (blId2, 1)))
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)

iterator.initialize()
iterator.foreach{
case (_, r) => {
(!r.isDefined) should be(true)
}
}
}

test("block fetch from remote succeed using BasicBlockFetcherIterator") {
val blockManager = mock(classOf[BlockManager])
val connManager = mock(classOf[ConnectionManager])
when(blockManager.connectionManager).thenReturn(connManager)

val blId1 = ShuffleBlockId(0,0,0)
val blId2 = ShuffleBlockId(0,1,0)
val buf1 = ByteBuffer.allocate(4)
val buf2 = ByteBuffer.allocate(4)
buf1.putInt(1)
buf1.flip()
buf2.putInt(1)
buf2.flip()
val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
val blockMessageArray = new BlockMessageArray(
Seq(blockMessage1, blockMessage2))

val bufferMessage = blockMessageArray.toBufferMessage
val buffer = ByteBuffer.allocate(bufferMessage.size)
val arrayBuffer = new ArrayBuffer[ByteBuffer]
bufferMessage.buffers.foreach{ b =>
buffer.put(b)
}
buffer.flip()
arrayBuffer += buffer

val someMessage = Some(Message.createBufferMessage(arrayBuffer))

val f = future {
someMessage
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
when(blockManager.futureExecContext).thenReturn(global)

when(blockManager.blockManagerId).thenReturn(
BlockManagerId("test-client", "test-client", 1, 0))
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)

val bmId = BlockManagerId("test-server", "test-server",1 , 0)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, Seq((blId1, 1), (blId2, 1)))
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
iterator.initialize()
iterator.foreach{
case (_, r) => {
(r.isDefined) should be(true)
}
}
}
}
Loading