diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ExecutorLocation.java b/core/src/main/java/org/apache/spark/shuffle/api/ExecutorLocation.java
new file mode 100644
index 0000000000000..7c81004dd7874
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ExecutorLocation.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * A type of {@link Location} that is based on the executor.
+ *
+ * @since 3.2.0
+ */
+@Private
+public interface ExecutorLocation extends HostLocation {
+  String executorId();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/HostLocation.java b/core/src/main/java/org/apache/spark/shuffle/api/HostLocation.java
new file mode 100644
index 0000000000000..2258a5cc1560a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/HostLocation.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * A type of {@link Location} that is based on the host.
+ *
+ * @since 3.2.0
+ */
+@Private
+public interface HostLocation extends Location {
+  String host();
+
+  int port();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/Location.java b/core/src/main/java/org/apache/spark/shuffle/api/Location.java
new file mode 100644
index 0000000000000..ddbda4a9df517
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/Location.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.annotation.Private;
+
+import java.io.Externalizable;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+
+/**
+ * :: Private ::
+ * An interface for plugging in the location of shuffle files, in order to support storing
+ * shuffle data in different storage systems, e.g., BlockManager, HDFS, S3. A location is
+ * generated by {@link ShuffleMapOutputWriter} after writing a shuffle data file and is used by
+ * ShuffleMapOutputReader to read the shuffle data back.
+ *
+ * Since the location is returned by {@link ShuffleMapOutputWriter#commitAllPartitions()} on the
+ * executor and then sent to the driver, users must ensure the location is serializable by
+ *
+ * - implementing a 0-arg constructor
+ * - implementing {@link java.io.Externalizable#readExternal(ObjectInput)} for deserialization
+ * - implementing {@link java.io.Externalizable#writeExternal(ObjectOutput)} for serialization
+ *
+ * Since locations are used as keys in maps and compared with each other, users must ensure that
+ * invoking {@link java.lang.Object#equals(Object)} or {@link java.lang.Object#hashCode()} on
+ * {@link Location} instances distinguishes different locations.
+ *
+ * Spark has its own default implementation of {@link Location},
+ * {@link org.apache.spark.storage.BlockManagerId}, which is a subclass of
+ * {@link ExecutorLocation} since each {@link org.apache.spark.storage.BlockManager} must belong
+ * to a certain executor. In turn, {@link ExecutorLocation} is a subclass of
+ * {@link HostLocation} since each executor must belong to a certain host. Users should choose
+ * the appropriate location interface according to their own use cases.
+ *
+ * :: Caution ::
+ * For performance, Spark reuses the same location instance for locations that are equal.
+ * Thus, users must also guarantee that the implemented {@link Location} is IMMUTABLE.
+ *
+ * @since 3.2.0
+ */
+@Private
+public interface Location extends Externalizable {
+}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index cdec1982b4487..533eba8ce0d0b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -37,7 +37,8 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.shuffle.api.{ExecutorLocation, HostLocation, Location}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
 import org.apache.spark.util._
 
 /**
@@ -124,13 +125,13 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
 
   /**
    * Update the map output location (e.g. during migration).
   */
-  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
+  def updateMapOutput(mapId: Long, loc: Location): Unit = withWriteLock {
     try {
      val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
      mapStatusOpt match {
        case Some(mapStatus) =>
-          logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
-          mapStatus.updateLocation(bmAddress)
+          logInfo(s"Updating map output for $mapId to $loc")
+          mapStatus.updateLocation(loc)
           invalidateSerializedMapOutputStatusCache()
        case None =>
          logWarning(s"Asked to update map output ${mapId} for untracked map status.")
@@ -146,9 +147,9 @@
    * This is a no-op if there is no registered map output or if the registered output is from a
    * different block manager.
    */
-  def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
-    logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
-    if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
+  def removeMapOutput(mapIndex: Int, loc: Location): Unit = withWriteLock {
+    logDebug(s"Removing existing map output $mapIndex $loc")
+    if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == loc) {
       _numAvailableOutputs -= 1
       mapStatuses(mapIndex) = null
       invalidateSerializedMapOutputStatusCache()
@@ -161,7 +162,10 @@
    */
   def removeOutputsOnHost(host: String): Unit = withWriteLock {
     logDebug(s"Removing outputs for host ${host}")
-    removeOutputsByFilter(x => x.host == host)
+    removeOutputsByFilter { x =>
+      assert(x.isInstanceOf[HostLocation], s"Required HostLocation, but got $x")
+      x.asInstanceOf[HostLocation].host == host
+    }
   }
 
   /**
@@ -171,14 +175,17 @@
   def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
     logDebug(s"Removing outputs for execId ${execId}")
-    removeOutputsByFilter(x => x.executorId == execId)
+    removeOutputsByFilter { x =>
+      assert(x.isInstanceOf[ExecutorLocation], s"Required ExecutorLocation, but got $x")
+      x.asInstanceOf[ExecutorLocation].executorId == execId
+    }
   }
 
   /**
    * Removes all shuffle outputs which satisfies the filter. Note that this will also
    * remove outputs which are served by an external shuffle server (if one exists).
    */
-  def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
+  def removeOutputsByFilter(f: Location => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
       if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
         _numAvailableOutputs -= 1
@@ -344,7 +351,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
 
   // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      : Iterator[(Location, Seq[(BlockId, Long, Int)])] = {
     getMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1)
   }
 
@@ -365,7 +372,7 @@
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
-      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+      endPartition: Int): Iterator[(Location, Seq[(BlockId, Long, Int)])]
 
   /**
    * Deletes map output status information for the specified shuffle stage.
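
(Aside, not part of the patch: the Location contract documented in Location.java above is easiest to see in code. Below is a minimal, hypothetical sketch of a plugin Location keyed by a storage path; the package and class names are invented for illustration.)

    // Hypothetical sketch only -- not part of this diff.
    package org.example.shuffle

    import java.io.{ObjectInput, ObjectOutput}

    import org.apache.spark.shuffle.api.Location

    class PathLocation(private var path: String) extends Location {

      // The 0-arg constructor required for deserialization.
      def this() = this(null)

      override def writeExternal(out: ObjectOutput): Unit = out.writeUTF(path)

      // The only mutation happens here; afterwards the instance must be treated
      // as immutable, since Spark may reuse one instance for all equal locations.
      override def readExternal(in: ObjectInput): Unit = { path = in.readUTF() }

      // Value-based equality and hashCode, so instances behave correctly as map
      // keys and cache keys.
      override def equals(other: Any): Boolean = other match {
        case o: PathLocation => path == o.path
        case _ => false
      }

      override def hashCode(): Int = if (path == null) 0 else path.hashCode

      override def toString: String = s"PathLocation($path)"
    }
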
@@ -488,10 +495,10 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
+  def updateMapOutput(shuffleId: Int, mapId: Long, loc: Location): Unit = {
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
-        shuffleStatus.updateMapOutput(mapId, bmAddress)
+        shuffleStatus.updateMapOutput(mapId, loc)
       case None =>
         logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
     }
@@ -502,10 +509,10 @@
   }
 
   /** Unregister map output information of the given shuffle, mapper and block manager */
-  def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = {
+  def unregisterMapOutput(shuffleId: Int, mapIndex: Int, loc: Location): Unit = {
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
-        shuffleStatus.removeMapOutput(mapIndex, bmAddress)
+        shuffleStatus.removeMapOutput(mapIndex, loc)
         incrementEpoch()
       case None =>
         throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
@@ -643,10 +650,12 @@
      : Seq[String] = {
     if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
         dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
+      val locations = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
         dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+      if (locations.nonEmpty) {
+        locations.get
+          .filter(_.isInstanceOf[HostLocation])
+          .map(_.asInstanceOf[HostLocation].host)
       } else {
         Nil
       }
@@ -670,14 +679,14 @@
       reducerId: Int,
       numReducers: Int,
       fractionThreshold: Double)
-    : Option[Array[BlockManagerId]] = {
+    : Option[Array[Location]] = {
 
     val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
     if (shuffleStatus != null) {
       shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
-          val locs = new HashMap[BlockManagerId, Long]
+          val locs = new HashMap[Location, Long]
           var totalOutputSize = 0L
           var mapIdx = 0
           while (mapIdx < statuses.length) {
@@ -728,7 +737,7 @@
     if (startMapIndex < endMapIndex &&
       (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
       val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
-      statusesPicked.map(_.location.host).toSeq
+      statusesPicked
+        .filter(_.location.isInstanceOf[HostLocation])
+        .map(_.location.asInstanceOf[HostLocation].host).toSeq
     } else {
       Nil
     }
@@ -758,7 +769,7 @@
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
-      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      endPartition: Int): Iterator[(Location, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
@@ -810,7 +821,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
-      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      endPartition: Int): Iterator[(Location, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
     val statuses =
      getStatuses(shuffleId, conf)
     try {
@@ -989,9 +1000,9 @@ private[spark] object MapOutputTracker extends Logging {
       endPartition: Int,
       statuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      endMapIndex: Int): Iterator[(Location, Seq[(BlockId, Long, Int)])] = {
     assert (statuses != null)
-    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+    val splitsByAddress = new HashMap[Location, ListBuffer[(BlockId, Long, Int)]]
     val iter = statuses.iterator.zipWithIndex
     for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
       if (status == null) {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 4a2281a4e8785..d360e24d8842b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1217,6 +1217,24 @@ package object config {
       .stringConf
       .createWithDefault(classOf[LocalDiskShuffleDataIO].getName)
 
+  private[spark] val SHUFFLE_LOCATION_PLUGIN_CLASS =
+    ConfigBuilder("spark.shuffle.sort.location.plugin.class")
+      .doc("Fully qualified name of the class used to instantiate plugin location instances. " +
+        "If not specified, Spark uses its native location (i.e., BlockManagerId) by default.")
+      .version("3.2.0")
+      .stringConf
+      .createOptional
+
+  private[spark] val SHUFFLE_LOCATION_CACHE_SIZE =
+    ConfigBuilder("spark.shuffle.sort.location.cacheSize")
+      .doc("The cache size for location instances. A larger size gives Spark more chances to " +
+        "reuse a location instance for equal locations, but takes more memory.")
+      .version("3.2.0")
+      .intConf
+      // In the case of `BlockManagerId`, which takes about 48B each, the total memory cost
+      // should be below 1MB, which is feasible.
+      .createWithDefault(10000)
+
   private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
     ConfigBuilder("spark.shuffle.file.buffer")
       .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d711b432ae6df..4557e36da280a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1658,7 +1658,7 @@ private[spark] class DAGScheduler(
         val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
         shuffleStage.pendingPartitions -= task.partitionId
         val status = event.result.asInstanceOf[MapStatus]
-        val execId = status.location.executorId
+        val execId = event.taskInfo.executorId
         logDebug("ShuffleMapTask finished on " + execId)
         if (executorFailureEpoch.contains(execId) &&
             smt.epoch <= executorFailureEpoch(execId)) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 1239c32cee3ab..113621d56bf48 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -25,7 +25,8 @@ import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.config
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.scheduler.MapStatus.locationFactory
+import org.apache.spark.shuffle.api.Location
 import org.apache.spark.util.Utils
 
 /**
@@ -35,9 +36,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] sealed trait MapStatus {
   /** Location where this task output is. */
-  def location: BlockManagerId
+  def location: Location
 
-  def updateLocation(newLoc: BlockManagerId): Unit
+  def updateLocation(newLoc: Location): Unit
 
   /**
    * Estimated size for the reduce block, in bytes.
@@ -52,8 +53,8 @@
    * partitionId of the task or taskContext.taskAttemptId is used.
   */
  def mapId: Long
-}
+}
 
 private[spark] object MapStatus {
@@ -65,8 +66,10 @@
     .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
     .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
 
+  val locationFactory = new MapStatusLocationFactory(SparkEnv.get.conf)
+
   def apply(
-      loc: BlockManagerId,
+      loc: Location,
       uncompressedSizes: Array[Long],
       mapTaskId: Long): MapStatus = {
     if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
@@ -115,7 +118,7 @@
  * @param _mapTaskId unique task id for the task
  */
 private[spark] class CompressedMapStatus(
-    private[this] var loc: BlockManagerId,
+    private[this] var loc: Location,
     private[this] var compressedSizes: Array[Byte],
     private[this] var _mapTaskId: Long)
   extends MapStatus with Externalizable {
@@ -123,13 +126,13 @@
   // For deserialization only
   protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)
 
-  def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) = {
+  def this(loc: Location, uncompressedSizes: Array[Long], mapTaskId: Long) = {
     this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
   }
 
-  override def location: BlockManagerId = loc
+  override def location: Location = loc
 
-  override def updateLocation(newLoc: BlockManagerId): Unit = {
+  override def updateLocation(newLoc: Location): Unit = {
     loc = newLoc
   }
 
@@ -147,7 +150,7 @@
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    loc = BlockManagerId(in)
+    loc = locationFactory.load(in)
     val len = in.readInt()
     compressedSizes = new Array[Byte](len)
     in.readFully(compressedSizes)
@@ -168,7 +171,7 @@
  * @param _mapTaskId unique task id for the task
 */
 private[spark] class HighlyCompressedMapStatus private (
-    private[this] var loc: BlockManagerId,
+    private[this] var loc: Location,
     private[this] var numNonEmptyBlocks: Int,
     private[this] var emptyBlocks: RoaringBitmap,
     private[this] var avgSize: Long,
@@ -183,9 +186,9 @@
   protected def this() = this(null, -1, null, -1, null, -1)  // For deserialization only
 
-  override def location: BlockManagerId = loc
+  override def location: Location = loc
 
-  override def updateLocation(newLoc: BlockManagerId): Unit = {
+  override def updateLocation(newLoc: Location): Unit = {
     loc = newLoc
   }
 
@@ -216,7 +219,7 @@
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    loc = BlockManagerId(in)
+    loc = locationFactory.load(in)
     numNonEmptyBlocks = -1 // SPARK-32436 Scala 2.13 doesn't initialize this during deserialization
     emptyBlocks = new RoaringBitmap()
     emptyBlocks.deserialize(in)
@@ -235,7 +238,7 @@ private[spark] object HighlyCompressedMapStatus {
   def apply(
-      loc: BlockManagerId,
+      loc: Location,
       uncompressedSizes: Array[Long],
       mapTaskId: Long): HighlyCompressedMapStatus = {
     // We must keep track of which blocks are empty so that we don't report a zero-sized
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatusLocationFactory.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatusLocationFactory.scala
new file mode 100644
index 0000000000000..f2ba8357a4e37
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatusLocationFactory.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.ObjectInput
+
+import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.config
+import org.apache.spark.shuffle.api.Location
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+/**
+ * Factory for creating the [[Location]] of a [[MapStatus]]. It creates [[BlockManagerId]]
+ * by default if no custom location class is configured.
+ */
+private[spark] class MapStatusLocationFactory(conf: SparkConf) {
+  private val locationConstructor = {
+    val locationBaseClass = classOf[Location]
+    conf.get(config.SHUFFLE_LOCATION_PLUGIN_CLASS).map { className =>
+      val clazz = Utils.classForName(className)
+      require(locationBaseClass.isAssignableFrom(clazz),
+        s"$className is not a subclass of ${locationBaseClass.getName}.")
+      try {
+        clazz.getConstructor()
+      } catch {
+        case _: NoSuchMethodException =>
+          throw new SparkException(s"$className did not have a zero-argument constructor.")
+      }
+    }.orNull
+  }
+
+  // The cache reuses the same location instance for equal locations, which helps
+  // reduce the number of objects in the JVM.
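+  // The loader below returns its key unchanged, so locationCache.get(loc)
+  // effectively interns locations: the first instance cached for a given
+  // value is handed back to every caller that looks up an equal one.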
+  val locationCache: LoadingCache[Location, Location] = CacheBuilder.newBuilder()
+    .maximumSize(conf.get(config.SHUFFLE_LOCATION_CACHE_SIZE))
+    .build(
+      new CacheLoader[Location, Location]() {
+        override def load(loc: Location): Location = loc
+      }
+    )
+
+  def load(in: ObjectInput): Location = Utils.tryOrIOException {
+    locationCache.get {
+      Option(locationConstructor).map { ctr =>
+        val loc = ctr.newInstance().asInstanceOf[Location]
+        loc.readExternal(in)
+        loc
+      }.getOrElse(BlockManagerId(in))
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index bc2a0fbc36d5b..ee16da987227f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -21,7 +21,8 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.api.Location
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -30,7 +31,7 @@ import org.apache.spark.util.collection.ExternalSorter
 */
 private[spark] class BlockStoreShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
-    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+    blocksByAddress: Iterator[(Location, Seq[(BlockId, Long, Int)])],
     context: TaskContext,
     readMetrics: ShuffleReadMetricsReporter,
     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index c6a4457d8f910..34f21229cbd84 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.storage
 
-import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-
-import com.google.common.cache.{CacheBuilder, CacheLoader}
+import java.io.{IOException, ObjectInput, ObjectOutput}
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.api.ExecutorLocation
 import org.apache.spark.util.Utils
 
 /**
@@ -40,7 +40,7 @@ class BlockManagerId private (
     private var host_ : String,
     private var port_ : Int,
     private var topologyInfo_ : Option[String])
-  extends Externalizable {
+  extends ExecutorLocation {
 
   private def this() = this(null, null, 0, None)  // For deserialization only
 
@@ -129,21 +129,11 @@ private[spark] object BlockManagerId {
   def apply(in: ObjectInput): BlockManagerId = {
     val obj = new BlockManagerId()
     obj.readExternal(in)
-    getCachedBlockManagerId(obj)
+    obj
   }
 
-  /**
-   * The max cache size is hardcoded to 10000, since the size of a BlockManagerId
-   * object is about 48B, the total memory cost should be below 1MB which is feasible.
-   */
-  val blockManagerIdCache = CacheBuilder.newBuilder()
-    .maximumSize(10000)
-    .build(new CacheLoader[BlockManagerId, BlockManagerId]() {
-      override def load(id: BlockManagerId) = id
-    })
-
   def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
-    blockManagerIdCache.get(id)
+    MapStatus.locationFactory.locationCache.get(id).asInstanceOf[BlockManagerId]
   }
 
   private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index fa4e46590aa5e..6fe680c37446a 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -34,6 +34,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.util.TransportConf
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
+import org.apache.spark.shuffle.api.Location
 import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
 
 /**
@@ -71,7 +72,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
-    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+    blocksByAddress: Iterator[(Location, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
@@ -135,7 +136,7 @@ final class ShuffleBlockFetcherIterator(
   * Queue of fetch requests which could not be issued the first time they were dequeued. These
   * requests are tried again when the fetch constraints are satisfied.
   */
-  private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()
+  private[this] val deferredFetchRequests = new HashMap[Location, Queue[FetchRequest]]()
 
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
@@ -144,7 +145,7 @@
   private[this] var reqsInFlight = 0
 
   /** Current number of blocks in flight per host:port */
-  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+  private[this] val numBlocksInFlightPerAddress = new HashMap[Location, Int]()
 
   /**
    * The blocks that can't be decompressed successfully, it is used to guarantee that we retry
@@ -235,8 +236,11 @@
   }
 
   private[this] def sendRequest(req: FetchRequest): Unit = {
+    val location = req.location
+    assert(location.isInstanceOf[BlockManagerId], s"Required BlockManagerId, but got $location")
+    val address = location.asInstanceOf[BlockManagerId]
     logDebug("Sending request for %d blocks (%s) from %s".format(
-      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+      req.blocks.size, Utils.bytesToString(req.size), address.hostPort))
     bytesInFlight += req.size
     reqsInFlight += 1
 
@@ -246,7 +250,6 @@
     }.toMap
     val remainingBlocks = new HashSet[String]() ++= infoMap.keys
     val blockIds = req.blocks.map(_.blockId.toString)
-    val address = req.address
 
     val blockFetchingListener = new BlockFetchingListener {
       override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
@@ -267,7 +270,7 @@
       }
 
       override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
-        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+        logError(s"Failed to get block(s) from ${address.host}:${address.port}", e)
         results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
       }
     }
@@ -297,7 +300,11 @@
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      // TODO (SPARK-35186): Define the local, host-local or remote fetch
+      // for different locations in a consistent way as some location
+      // implementations may not have executor ids.
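+      // Until then, the branches below downcast to BlockManagerId, so a plugin
+      // location that is not a BlockManagerId would fail the cast when it is
+      // classified here.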
+      if (Seq(blockManager.blockManagerId.executorId, fallback)
+          .contains(address.asInstanceOf[BlockManagerId].executorId)) {
         checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
-        address.host == blockManager.blockManagerId.host) {
+        address.asInstanceOf[BlockManagerId].host == blockManager.blockManagerId.host) {
         checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
-        hostLocalBlocksByExecutor += address -> blocksForAddress
+        hostLocalBlocksByExecutor += address.asInstanceOf[BlockManagerId] -> blocksForAddress
         hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
@@ -335,7 +342,7 @@
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: Location): FetchRequest = {
     logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " +
       s"with ${blocks.size} blocks")
     FetchRequest(address, blocks)
@@ -343,7 +350,7 @@
 
   private def createFetchRequests(
       curBlocks: Seq[FetchBlockInfo],
-      address: BlockManagerId,
+      address: Location,
       isLast: Boolean,
       collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
     val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
@@ -367,7 +374,7 @@
   }
 
   private def collectFetchRequests(
-      address: BlockManagerId,
+      address: Location,
       blockInfos: Seq[(BlockId, Long, Int)],
       collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
     val iterator = blockInfos.iterator
@@ -445,18 +452,18 @@
       blockId: BlockId,
       mapIndex: Int,
       localDirs: Array[String],
-      blockManagerId: BlockManagerId): Boolean = {
+      loc: Location): Boolean = {
     try {
       val buf = blockManager.getHostLocalShuffleData(blockId, localDirs)
       buf.retain()
-      results.put(SuccessFetchResult(blockId, mapIndex, blockManagerId, buf.size(), buf,
+      results.put(SuccessFetchResult(blockId, mapIndex, loc, buf.size(), buf,
         isNetworkReqDone = false))
       true
     } catch {
       case e: Exception =>
         // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e) - results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + results.put(FailureFetchResult(blockId, mapIndex, loc, e)) false } } @@ -470,7 +477,7 @@ final class ShuffleBlockFetcherIterator( val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) => - cachedDirsByExec.contains(hostLocalBmId.executorId) + cachedDirsByExec.contains(hostLocalBmId.asInstanceOf[BlockManagerId].executorId) } (hasCache.toMap, noCache.toMap) } @@ -728,7 +735,7 @@ final class ShuffleBlockFetcherIterator( // Process any regular fetch requests if possible. while (isRemoteBlockFetchable(fetchRequests)) { val request = fetchRequests.dequeue() - val remoteAddress = request.address + val remoteAddress = request.location if (isRemoteAddressMaxedOut(remoteAddress, request)) { logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) @@ -739,7 +746,7 @@ final class ShuffleBlockFetcherIterator( } } - def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + def send(remoteAddress: Location, request: FetchRequest): Unit = { sendRequest(request) numBlocksInFlightPerAddress(remoteAddress) = numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size @@ -754,7 +761,7 @@ final class ShuffleBlockFetcherIterator( // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a // given remote address. - def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + def isRemoteAddressMaxedOut(remoteAddress: Location, request: FetchRequest): Boolean = { numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > maxBlocksInFlightPerAddress } @@ -763,13 +770,15 @@ final class ShuffleBlockFetcherIterator( private[storage] def throwFetchFailedException( blockId: BlockId, mapIndex: Int, - address: BlockManagerId, + address: Location, e: Throwable) = { + assert(address.isInstanceOf[BlockManagerId], s"Require BlockManagerId, but got $address") + val bmId = address.asInstanceOf[BlockManagerId] blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e) + throw new FetchFailedException(bmId, shufId, mapId, mapIndex, reduceId, e) case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => - throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e) + throw new FetchFailedException(bmId, shuffleId, mapId, mapIndex, startReduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) @@ -787,7 +796,7 @@ private class BufferReleasingInputStream( private val iterator: ShuffleBlockFetcherIterator, private val blockId: BlockId, private val mapIndex: Int, - private val address: BlockManagerId, + private val address: Location, private val detectCorruption: Boolean) extends InputStream { private[this] var closed = false @@ -967,10 +976,10 @@ object ShuffleBlockFetcherIterator { /** * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. + * @param location the location to fetch from. 
   * @param blocks Sequence of the information for blocks to fetch from the same address.
   */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
+  case class FetchRequest(location: Location, blocks: Seq[FetchBlockInfo]) {
     val size = blocks.map(_.size).sum
   }
 
   /**
@@ -979,7 +988,7 @@
   */
   private[storage] sealed trait FetchResult {
     val blockId: BlockId
-    val address: BlockManagerId
+    val address: Location
   }
 
   /**
@@ -995,7 +1004,7 @@
   private[storage] case class SuccessFetchResult(
       blockId: BlockId,
       mapIndex: Int,
-      address: BlockManagerId,
+      address: Location,
       size: Long,
       buf: ManagedBuffer,
       isNetworkReqDone: Boolean) extends FetchResult {
@@ -1013,7 +1022,7 @@
   private[storage] case class FailureFetchResult(
       blockId: BlockId,
       mapIndex: Int,
-      address: BlockManagerId,
+      address: Location,
       e: Throwable) extends FetchResult
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4c74e4fbb3728..fe02b6774f7ad 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -538,14 +538,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
     //  val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get
     assert(initialMapStatus1.count(_ != null) === 3)
-    assert(initialMapStatus1.map{_.location.executorId}.toSet ===
+    assert(initialMapStatus1.map{_.location.asInstanceOf[BlockManagerId].executorId}.toSet ===
       Set("hostA-exec1", "hostA-exec2", "hostB-exec"))
     assert(initialMapStatus1.map{_.mapId}.toSet === Set(5, 6, 7))
 
     val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
     //  val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get
     assert(initialMapStatus2.count(_ != null) === 3)
-    assert(initialMapStatus2.map{_.location.executorId}.toSet ===
+    assert(initialMapStatus2.map{_.location.asInstanceOf[BlockManagerId].executorId}.toSet ===
       Set("hostA-exec1", "hostA-exec2", "hostB-exec"))
     assert(initialMapStatus2.map{_.mapId}.toSet === Set(8, 9, 10))
 
@@ -561,13 +561,13 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     val mapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
     assert(mapStatus1.count(_ != null) === 1)
-    assert(mapStatus1(2).location.executorId === "hostB-exec")
-    assert(mapStatus1(2).location.host === "hostB")
+    assert(mapStatus1(2).location.asInstanceOf[BlockManagerId].executorId === "hostB-exec")
+    assert(mapStatus1(2).location.asInstanceOf[BlockManagerId].host === "hostB")
 
     val mapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
     assert(mapStatus2.count(_ != null) === 1)
-    assert(mapStatus2(2).location.executorId === "hostB-exec")
-    assert(mapStatus2(2).location.host === "hostB")
+    assert(mapStatus2(2).location.asInstanceOf[BlockManagerId].executorId === "hostB-exec")
+    assert(mapStatus2(2).location.asInstanceOf[BlockManagerId].host === "hostB")
   }
 
   test("SPARK-32003: All shuffle files for executor should be cleaned up on fetch failure") {
@@ -591,8 +591,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     // The MapOutputTracker has all the shuffle files
     val mapStatuses =
      mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses
     assert(mapStatuses.count(_ != null) === 3)
-    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostA-exec") === 2)
-    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1)
+    assert(mapStatuses.count(s => s != null &&
+      s.location.asInstanceOf[BlockManagerId].executorId == "hostA-exec") === 2)
+    assert(mapStatuses.count(s => s != null &&
+      s.location.asInstanceOf[BlockManagerId].executorId == "hostB-exec") === 1)
 
     // Now a fetch failure from the lost executor occurs
     complete(taskSets(1), Seq(
@@ -605,8 +607,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     // Shuffle files for hostA-exec should be lost
     assert(mapStatuses.count(_ != null) === 1)
-    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostA-exec") === 0)
-    assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1)
+    assert(mapStatuses.count(s => s != null &&
+      s.location.asInstanceOf[BlockManagerId].executorId == "hostA-exec") === 0)
+    assert(mapStatuses.count(s => s != null &&
+      s.location.asInstanceOf[BlockManagerId].executorId == "hostB-exec") === 1)
 
     // Additional fetch failure from the executor does not result in further call to
     // mapOutputTracker.removeOutputsOnExecutor
@@ -843,7 +847,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     // have the 2nd attempt pass
     complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length))))
     // we can see both result blocks now
-    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0)
+      .map(_._1.asInstanceOf[BlockManagerId].host).toSet ===
       HashSet("hostA", "hostB"))
     completeAndCheckAnswer(taskSets(3), Seq((Success, 43)), Map(0 -> 42, 1 -> 43))
     assertDataStructuresEmpty()
@@ -1228,7 +1233,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     submit(reduceRdd, Array(0, 1))
     completeShuffleMapStageSuccessfully(0, 0, reduceRdd.partitions.length)
     // The MapOutputTracker should know about both map output locations.
-    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0)
+      .map(_._1.asInstanceOf[BlockManagerId].host).toSet ===
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
@@ -1349,9 +1355,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     completeShuffleMapStageSuccessfully(0, 0, 2)
 
     // The MapOutputTracker should know about both map output locations.
-    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0)
+      .map(_._1.asInstanceOf[BlockManagerId].host).toSet ===
       HashSet("hostA", "hostB"))
-    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet ===
+    assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1)
+      .map(_._1.asInstanceOf[BlockManagerId].host).toSet ===
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
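
(Aside, not part of the patch: a usage sketch for the two configs introduced above. The plugin class name is the hypothetical PathLocation example from earlier; only the config keys come from this patch.)

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Fully qualified name of the plugin Location class; MapStatusLocationFactory
      // requires it to expose a zero-argument constructor.
      .set("spark.shuffle.sort.location.plugin.class", "org.example.shuffle.PathLocation")
      // Optional: how many deserialized Location instances are cached for reuse.
      .set("spark.shuffle.sort.location.cacheSize", "10000")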