Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@ package object config {
.intConf
.createWithDefault(3)

private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS =
ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress")
.doc("This configuration limits the number of remote blocks being fetched per reduce task" +
" from a given host port. When a large number of blocks are being requested from a given" +
" address in a single fetch or simultaneously, this could crash the serving executor or" +
" Node Manager. This is especially useful to reduce the load on the Node Manager when" +
" external shuffle is enabled. You can mitigate the issue by setting it to a lower value.")
.intConf
.checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.")
.createWithDefault(Int.MaxValue)

private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
.doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -52,6 +52,8 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
* for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
Expand All @@ -64,6 +66,7 @@ final class ShuffleBlockFetcherIterator(
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {
Expand Down Expand Up @@ -110,12 +113,21 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val fetchRequests = new Queue[FetchRequest]

/**
* Queue of fetch requests which could not be issued the first time they were dequeued. These
* requests are tried again when the fetch constraints are satisfied.
*/
private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()

/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L

/** Current number of requests in flight */
private[this] var reqsInFlight = 0

/** Current number of blocks in flight per host:port */
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()

/**
* The blocks that can't be decompressed successfully, it is used to guarantee that we retry
* at most once for those corrupted blocks.
Expand Down Expand Up @@ -248,7 +260,8 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
+ ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)

// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
Expand Down Expand Up @@ -277,11 +290,13 @@ final class ShuffleBlockFetcherIterator(
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= targetRequestSize) {
if (curRequestSize >= targetRequestSize ||
curBlocks.size >= maxBlocksInFlightPerAddress) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
logDebug(s"Creating fetch request of $curRequestSize at $address "
+ s"with ${curBlocks.size} blocks")
curBlocks = new ArrayBuffer[(BlockId, Long)]
logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
Expand Down Expand Up @@ -375,6 +390,7 @@ final class ShuffleBlockFetcherIterator(
result match {
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
shuffleMetrics.incRemoteBytesRead(buf.size)
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
Expand Down Expand Up @@ -443,12 +459,57 @@ final class ShuffleBlockFetcherIterator(
}

private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 ||
(reqsInFlight + 1 <= maxReqsInFlight &&
bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
sendRequest(fetchRequests.dequeue())
// Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
// immediately, defer the request until the next time it can be processed.

// Process any outstanding deferred fetch requests if possible.
if (deferredFetchRequests.nonEmpty) {
for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
while (isRemoteBlockFetchable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
logDebug(s"Processing deferred fetch request for $remoteAddress with "
+ s"${request.blocks.length} blocks")
send(remoteAddress, request)
if (defReqQueue.isEmpty) {
deferredFetchRequests -= remoteAddress
}
}
}
}

// Process any regular fetch requests if possible.
while (isRemoteBlockFetchable(fetchRequests)) {
val request = fetchRequests.dequeue()
val remoteAddress = request.address
if (isRemoteAddressMaxedOut(remoteAddress, request)) {
logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
deferredFetchRequests(remoteAddress) = defReqQueue
} else {
send(remoteAddress, request)
}
}

def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
sendRequest(request)
numBlocksInFlightPerAddress(remoteAddress) =
numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
}

def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
fetchReqQueue.nonEmpty &&
(bytesInFlight == 0 ||
(reqsInFlight + 1 <= maxReqsInFlight &&
bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
}

// Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
// given remote address.
def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
maxBlocksInFlightPerAddress
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Orderin
this
}

def poll(): A = {
underlying.poll()
}

override def +=(elem1: A, elem2: A, elems: A*): this.type = {
this += elem1 += elem2 ++= elems
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true)

// 3 local blocks fetched in initialization
Expand Down Expand Up @@ -187,6 +188,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true)

verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
Expand Down Expand Up @@ -254,6 +256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true)

// Continue only after the mock calls onBlockFetchFailure
Expand Down Expand Up @@ -319,6 +322,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true)

// Continue only after the mock calls onBlockFetchFailure
Expand Down Expand Up @@ -400,6 +404,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
false)

// Continue only after the mock calls onBlockFetchFailure
Expand Down Expand Up @@ -457,6 +462,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
maxBytesInFlight = Int.MaxValue,
maxReqsInFlight = Int.MaxValue,
maxBlocksInFlightPerAddress = Int.MaxValue,
maxReqSizeShuffleToMem = 200,
detectCorrupt = true)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import org.apache.spark.SparkFunSuite

class BoundedPriorityQueueSuite extends SparkFunSuite {
test("BoundedPriorityQueue poll test") {
val pq = new BoundedPriorityQueue[Double](4)

pq += 0.1
pq += 1.5
pq += 1.0
pq += 0.3
pq += 0.01

assert(pq.isEmpty == false)
assert(pq.poll() == 0.1)
assert(pq.poll() == 0.3)
assert(pq.poll() == 1.0)
assert(pq.poll() == 1.5)
assert(pq.isEmpty == true)

val pq2 = new BoundedPriorityQueue[(Int, Double)](4)(Ordering.by(_._2))
pq2 += 1 -> 0.5
pq2 += 5 -> 0.1
pq2 += 3 -> 0.3
pq2 += 4 -> 0.2
pq2 += 1 -> 0.4

assert(pq2.poll()._2 == 0.2)
assert(pq2.poll()._2 == 0.3)
assert(pq2.poll()._2 == 0.4)
assert(pq2.poll()._2 == 0.5)
}
}
9 changes: 9 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,15 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
<td><code>spark.reducer.maxBlocksInFlightPerAddress</code></td>
<td>Int.MaxValue</td>
<td>
This configuration limits the number of remote blocks being fetched per reduce task from a
given host port. When a large number of blocks are being requested from a given address in a
single fetch or simultaneously, this could crash the serving executor or Node Manager. This
is especially useful to reduce the load on the Node Manager when external shuffle is enabled.
You can mitigate this issue by setting it to a lower value.
</td>
<td><code>spark.reducer.maxReqSizeShuffleToMem</code></td>
<td>Long.MaxValue</td>
<td>
Expand Down
11 changes: 11 additions & 0 deletions docs/running-on-mesos.md
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,17 @@ See the [configuration page](configuration.html) for information on Spark config
Fetcher Cache</a>
</td>
</tr>
<tr>
<td><code>spark.mesos.driver.failoverTimeout</code></td>
<td><code>0.0</code></td>
<td>
The amount of time (in seconds) that the master will wait for the
driver to reconnect, after being temporarily disconnected, before
it tears down the driver framework by killing all its
executors. The default value is zero, meaning no timeout: if the
driver disconnects, the master immediately tears down the framework.
</td>
</tr>
</table>

# Troubleshooting and Debugging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,16 @@ package object config {

private [spark] val DRIVER_LABELS =
ConfigBuilder("spark.mesos.driver.labels")
.doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" +
.doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " +
"pairs should be separated by a colon, and commas used to list more than one." +
"Ex. key:value,key2:value2")
.stringConf
.createOptional

private [spark] val DRIVER_FAILOVER_TIMEOUT =
ConfigBuilder("spark.mesos.driver.failoverTimeout")
.doc("Amount of time in seconds that the master will wait to hear from the driver, " +
"during a temporary disconnection, before tearing down all the executors.")
.doubleConf
.createWithDefault(0.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.SchedulerDriver

import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState}
import org.apache.spark.deploy.mesos.config._
import org.apache.spark.internal.config
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
Expand Down Expand Up @@ -177,7 +178,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
sc.conf,
sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)),
None,
None,
Some(sc.conf.get(DRIVER_FAILOVER_TIMEOUT)),
sc.conf.getOption("spark.mesos.driver.frameworkId")
)

Expand Down
Loading