91 changes: 88 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -33,7 +33,7 @@ import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
@@ -337,6 +337,21 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
* Called from executors to get the server URIs and output sizes for each shuffle block that
* needs to be read from a given range of map output partitions (startPartition is included but
* endPartition is excluded from the range), restricted to the single map output identified by
* mapId.
*
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size, map index)
* tuples describing the shuffle blocks that are stored at that block manager.
*/
def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
* Deletes map output status information for the specified shuffle stage.
*/
@@ -668,6 +683,31 @@ private[spark] class MapOutputTrackerMaster(
None
}

/**
* Return the locations where the map task identified by `mapId` ran. Each location includes
* both a host and an executor id on that host.
*
* @param dep shuffle dependency object
* @param mapId the map id
* @return a sequence of locations where the map task ran
*/
def getMapLocation(dep: ShuffleDependency[_, _, _], mapId: Int): Seq[String] = {
val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
if (shuffleStatus != null) {
shuffleStatus.withMapStatuses { statuses =>
if (mapId >= 0 && mapId < statuses.length) {
Seq(ExecutorCacheTaskLocation(statuses(mapId).location.host,
statuses(mapId).location.executorId).toString)
} else {
Nil
}
}
} else {
Nil
}
}

def incrementEpoch(): Unit = {
epochLock.synchronized {
epoch += 1
@@ -701,6 +741,29 @@
}
}

override def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
mapId: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
s"partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
MapOutputTracker.convertMapStatuses(
shuffleId,
startPartition,
endPartition,
statuses,
Some(mapId))
}
case None =>
Iterator.empty
}
}

override def stop(): Unit = {
mapOutputRequests.offer(PoisonPill)
threadpool.shutdown()
@@ -746,6 +809,25 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
}
}

override def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
s"partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition,
statuses, Some(mapId))
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
mapStatuses.clear()
throw e
}
}

/**
* Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
* on this array when reading it, because on the driver, we may be changing it in place.
@@ -888,10 +970,12 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
statuses: Array[MapStatus],
mapId: Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
for ((status, mapIndex) <- statuses.iterator.zipWithIndex) {
val iter = statuses.iterator.zipWithIndex
for ((status, mapIndex) <- mapId.map(id => iter.filter(_._2 == id)).getOrElse(iter)) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
@@ -906,6 +990,7 @@
}
}
}

splitsByAddress.iterator
}
}
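
To make the new `mapId` filter in `convertMapStatuses` concrete, here is a small, self-contained sketch (not part of the patch) of the same Option-based selection: with `mapId` set, only the status at that map index is kept; with `None`, all statuses pass through. `FakeStatus` is a hypothetical stand-in for `MapStatus`.

// Standalone illustration of the filtering idiom used in convertMapStatuses above.
object MapIdFilterSketch {
  // Hypothetical stand-in for MapStatus; only the zipped index matters here.
  final case class FakeStatus(sizeBytes: Long)

  def select(statuses: Array[FakeStatus], mapId: Option[Int]): Iterator[(FakeStatus, Int)] = {
    val iter = statuses.iterator.zipWithIndex
    // Same pattern as the patch: narrow to one map index when mapId is provided.
    mapId.map(id => iter.filter(_._2 == id)).getOrElse(iter)
  }

  def main(args: Array[String]): Unit = {
    val statuses = Array(FakeStatus(10L), FakeStatus(20L), FakeStatus(30L))
    println(select(statuses, None).toList)    // all three (status, index) pairs
    println(select(statuses, Some(1)).toList) // only (FakeStatus(20),1)
  }
}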
core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -36,18 +36,33 @@ private[spark] class BlockStoreShuffleReader[K, C](
readMetrics: ShuffleReadMetricsReporter,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
mapId: Option[Int] = None)
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blocksByAddress = mapId match {
case (Some(mapId)) => mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId,
startPartition,
endPartition,
mapId)
case (None) => mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId,
startPartition,
endPartition)
case (_) => throw new IllegalArgumentException(
[Review comment (Contributor)] nit: prefer writing the cases without parentheses:
case Some(..) =>
case None =>
"mapId should be both set or unset")
}

val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
blocksByAddress,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -54,6 +54,19 @@ private[spark] trait ShuffleManager {
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive),
* reading only the output of the map task identified by mapId.
* Called on executors by reduce tasks.
*/
def getMapReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter,
mapId: Int): ShuffleReader[K, C]

/**
* Remove a shuffle's metadata from the ShuffleManager.
* @return true if the metadata removed successfully, otherwise false.
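For orientation only, a hedged sketch of how an executor-side caller might use the new `getMapReader` contract to read one mapper's entire output across all of its reduce partitions; it mirrors what `LocalShuffledRowRDD.compute` does later in this patch. `SingleMapOutputReaderSketch` and `readOneMapOutput` are hypothetical names, not code from the PR, and the sketch sits in a Spark-internal package because these APIs are private[spark].

// Sketch only: compiles against a build that includes the getMapReader API added in this PR.
package org.apache.spark.shuffle

import org.apache.spark.{ShuffleDependency, SparkEnv, TaskContext}

private[spark] object SingleMapOutputReaderSketch {
  // Read every reduce partition produced by one map task of the given shuffle dependency.
  def readOneMapOutput[K, C](
      dep: ShuffleDependency[K, _, C],
      mapId: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): Iterator[Product2[K, C]] = {
    val numReducers = dep.partitioner.numPartitions
    SparkEnv.get.shuffleManager
      .getMapReader[K, C](dep.shuffleHandle, 0, numReducers, context, metrics, mapId)
      .read()
  }
}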
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -127,6 +127,27 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
startPartition, endPartition, context, metrics)
}

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive),
* reading only the output of the map task identified by mapId.
* Called on executors by reduce tasks.
*/
override def getMapReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter,
mapId: Int): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
startPartition,
endPartition,
context,
metrics,
mapId = Some(mapId))
}

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -394,6 +394,14 @@ object SQLConf {
"must be a positive integer.")
.createOptional

val OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED =
buildConf("spark.sql.adaptive.optimizedLocalShuffleReader.enabled")
.doc("When true and adaptive execution is enabled, this enables the optimization of" +
" converting the shuffle reader to local shuffle reader for the shuffle exchange" +
" of the broadcast hash join in probe side.")
.booleanConf
.createWithDefault(true)

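A hedged usage note, not part of the patch: since the new flag defaults to true, the optimization only takes effect once adaptive execution itself is enabled. A minimal sketch of toggling both from user code, assuming a plain `SparkSession` named `spark`:

// Turn on adaptive execution so the local shuffle reader optimization can apply.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.optimizedLocalShuffleReader.enabled", "true")
// Or opt out of just this optimization while keeping adaptive execution on:
// spark.conf.set("spark.sql.adaptive.optimizedLocalShuffleReader.enabled", "false")
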
val SUBEXPRESSION_ELIMINATION_ENABLED =
buildConf("spark.sql.subexpressionElimination.enabled")
.internal()
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, LocalShuffleReaderExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.SQLMetricInfo
import org.apache.spark.sql.internal.SQLConf
@@ -56,6 +56,7 @@ private[execution] object SparkPlanInfo {
case ReusedSubqueryExec(child) => child :: Nil
case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
case stage: QueryStageExec => stage.plan :: Nil
case localReader: LocalShuffleReaderExec => localReader.child :: Nil
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -84,6 +84,7 @@ case class AdaptiveSparkPlanExec(
// plan should reach a final status of query stages (i.e., no more addition or removal of
// Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
OptimizeLocalShuffleReader(conf),
ensureRequirements
)

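For context, a hedged sketch of the general Catalyst pattern for applying an ordered rule sequence such as `queryStagePreparationRules` to a physical plan. `RuleApplicationSketch` and `applyRules` are hypothetical names; the actual application code in `AdaptiveSparkPlanExec` may differ in details.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

object RuleApplicationSketch {
  // Fold each preparation rule over the plan in order; later rules see the output of earlier
  // rules, so the ordering of the sequence matters.
  def applyRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan =
    rules.foldLeft(plan) { case (current, rule) => rule(current) }
}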
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
@@ -125,6 +125,7 @@ trait AdaptiveSparkPlanHelper {
private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match {
case a: AdaptiveSparkPlanExec => Seq(a.executedPlan)
case s: QueryStageExec => Seq(s.plan)
case l: LocalShuffleReaderExec => Seq(l.child)
case _ => p.children
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala (new file)
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}


/**
* The [[Partition]] used by [[LocalShuffledRowRDD]]. Each partition corresponds to a single
* pre-shuffle (map) partition, identified by `preShufflePartitionIndex`, and covers all of that
* map output's post-shuffle partitions.
*/
private final class LocalShuffleRowRDDPartition(
[Review comment (Contributor)] ShuffleRow -> ShuffledRow, i.e. rename to LocalShuffledRowRDDPartition.
val preShufflePartitionIndex: Int) extends Partition {
override val index: Int = preShufflePartitionIndex
}

/**
* This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. It is used
* in Spark SQL adaptive execution when a shuffle join is converted to a broadcast join at runtime
* because the map output of one input table is small enough for broadcast. This RDD represents
* the data of the other input table of the join, which is read from the shuffle. Each partition
* of the RDD reads the entire output of a single mapper locally, so no data is transferred over
* the network.
*
* This RDD takes a [[ShuffleDependency]] (`dependency`).
*
* The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle
* (i.e. map output). Elements of this RDD are (partitionId, Row) pairs.
* Partition ids should be in the range [0, numPartitions - 1].
* `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions (i.e. the number
* of partitions of the map output). The post-shuffle partition number is the same as the parent
* RDD's partition number.
*/
class LocalShuffledRowRDD(
var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
metrics: Map[String, SQLMetric])
extends RDD[InternalRow](dependency.rdd.context, Nil) {

private[this] val numReducers = dependency.partitioner.numPartitions
private[this] val numMappers = dependency.rdd.partitions.length

override def getDependencies: Seq[Dependency[_]] = List(dependency)

override def getPartitions: Array[Partition] = {
Array.tabulate[Partition](numMappers) { i =>
new LocalShuffleRowRDDPartition(i)
}
}

override def getPreferredLocations(partition: Partition): Seq[String] = {
val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
tracker.getMapLocation(dependency, partition.index)
}

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val localRowPartition = split.asInstanceOf[LocalShuffleRowRDDPartition]
val mapId = localRowPartition.index
val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
// `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)

val reader = SparkEnv.get.shuffleManager.getMapReader(
dependency.shuffleHandle,
0,
numReducers,
context,
sqlMetricsReporter,
mapId)
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
}

override def clearDependencies(): Unit = {
super.clearDependencies()
dependency = null
}
}

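As a closing illustration (not part of the patch): under Spark's usual shuffle_<shuffleId>_<mapId>_<reduceId> block naming, partition `mapId` of this RDD covers every reduce slice of a single map output, which is what keeps the read local to the node that produced it. A small sketch of that enumeration, with hypothetical helper names:

// Sketch: the shuffle block ids one LocalShuffledRowRDD partition would cover,
// assuming the standard ShuffleBlockId naming scheme.
object LocalPartitionBlocksSketch {
  def blocksFor(shuffleId: Int, mapId: Int, numReducers: Int): Seq[String] =
    (0 until numReducers).map(reduceId => s"shuffle_${shuffleId}_${mapId}_$reduceId")

  def main(args: Array[String]): Unit = {
    // Partition for map task 2 of a 4-reducer shuffle:
    blocksFor(shuffleId = 0, mapId = 2, numReducers = 4).foreach(println)
    // prints shuffle_0_2_0, shuffle_0_2_1, shuffle_0_2_2, shuffle_0_2_3
  }
}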