diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java new file mode 100644 index 0000000000000..a312831cb6282 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.api.java.Optional; + +import java.util.Objects; + +/** + * :: Experimental :: + * An object describing a shuffle block and the length and location metadata associated with it. + * @since 3.0.0 + */ +public class ShuffleBlockInfo { + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; + private final Optional<ShuffleLocation> shuffleLocation; + + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, + Optional<ShuffleLocation> shuffleLocation) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + this.shuffleLocation = shuffleLocation; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getMapId() { + return mapId; + } + + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public Optional<ShuffleLocation> getShuffleLocation() { + return shuffleLocation; + } + + @Override + public boolean equals(Object other) { + return other instanceof ShuffleBlockInfo + && shuffleId == ((ShuffleBlockInfo) other).shuffleId + && mapId == ((ShuffleBlockInfo) other).mapId + && reduceId == ((ShuffleBlockInfo) other).reduceId + && length == ((ShuffleBlockInfo) other).length + && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation); + } + + @Override + public int hashCode() { + return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation); + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 4fc20bad9938b..8baa3bf6f859a 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -30,4 +30,6 @@ public interface ShuffleExecutorComponents { void initializeExecutor(String appId, String execId); ShuffleWriteSupport writes(); + + ShuffleReadSupport reads(); } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java index 87eb497098e0c..d06c11b3c01ee 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -21,5
+21,4 @@ * Marker interface representing a location of a shuffle block. Implementations of shuffle readers * and writers are expected to cast this down to an implementation-specific representation. */ -public interface ShuffleLocation { -} +public interface ShuffleLocation {} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java new file mode 100644 index 0000000000000..9cd8fde09064b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.IOException; +import java.io.InputStream; + +/** + * :: Experimental :: + * An interface for reading shuffle records. + * @since 3.0.0 + */ +@Experimental +public interface ShuffleReadSupport { + /** + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. 
+ */ + Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index f7ec202ef4b9d..91a5d7f7945ee 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -17,11 +17,15 @@ package org.apache.spark.shuffle.sort.io; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleReadSupport; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; import org.apache.spark.storage.BlockManager; public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { @@ -29,6 +33,8 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponen private final SparkConf sparkConf; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; + private MapOutputTracker mapOutputTracker; + private SerializerManager serializerManager; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { this.sparkConf = sparkConf; @@ -37,15 +43,30 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { @Override public void initializeExecutor(String appId, String execId) { blockManager = SparkEnv.get().blockManager(); + mapOutputTracker = SparkEnv.get().mapOutputTracker(); + serializerManager = SparkEnv.get().serializerManager(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); } @Override public ShuffleWriteSupport writes() { + checkInitialized(); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); + } + + @Override + public ShuffleReadSupport reads() { + checkInitialized(); + return new DefaultShuffleReadSupport(blockManager, + mapOutputTracker, + serializerManager, + sparkConf); + } + + private void checkInitialized() { if (blockResolver == null) { throw new IllegalStateException( - "Executor components must be initialized before getting writers."); + "Executor components must be initialized before getting readers or writers."); } - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 74975019e7480..ebddf5ff6f6e0 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -283,7 +283,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } @@ -297,7 +297,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager.
*/ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -647,7 +647,7 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -684,7 +684,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -873,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -885,9 +885,14 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) - splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += + if (status.mapShuffleLocations == null) { + splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) + } else { + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ea79c7310349d..df30fd5c7f679 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] + private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics = + Predef.identity[TempShuffleReadMetrics] /** * Time taken on the executor to deserialize this task. 
@@ -187,11 +189,17 @@ class TaskMetrics private[spark] () extends Serializable { * be lost. */ private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized { - val readMetrics = new TempShuffleReadMetrics - tempShuffleReadMetrics += readMetrics + val tempShuffleMetrics = new TempShuffleReadMetrics + val readMetrics = _decorFunc(tempShuffleMetrics) + tempShuffleReadMetrics += tempShuffleMetrics readMetrics } + private[spark] def decorateTempShuffleReadMetrics( + decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized { + _decorFunc = decorFunc + } + /** * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. * This is expected to be called on executor heartbeat and at the end of a task. diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index d6f63e71f113c..530c3694ad1ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,11 +17,18 @@ package org.apache.spark.shuffle +import java.io.InputStream + +import scala.collection.JavaConverters._ + import org.apache.spark._ +import org.apache.spark.api.java.Optional +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -35,40 +42,68 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, + shuffleReadSupport: ShuffleReadSupport, serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + sparkConf: SparkConf = SparkEnv.get.conf) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency + private val compressionCodec = CompressionCodec.createCodec(sparkConf) + + private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS) + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) - .map { - case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) => - (loc.getBlockManagerId, blocks) - case _ => - throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" + - " locations yet.") - }, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - 
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - readMetrics).toCompletionIterator + val streamsIterator = + shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + override def iterator: Iterator[ShuffleBlockInfo] = { + mapOutputTracker + .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .flatMap { shuffleLocationInfo => + shuffleLocationInfo._2.map { blockInfo => + val block = blockInfo._1.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo( + block.shuffleId, + block.mapId, + block.reduceId, + blockInfo._2, + Optional.ofNullable(shuffleLocationInfo._1.orNull)) + } + } + } + }.asJava).iterator() - val serializerInstance = dep.serializer.newInstance() + val retryingWrappedStreams = new Iterator[InputStream] { + override def hasNext: Boolean = streamsIterator.hasNext - // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + override def next(): InputStream = { + var returnStream: InputStream = null + while (streamsIterator.hasNext && returnStream == null) { + if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { + // The default implementation checks for corrupt streams, so it will already have + // decompressed/decrypted the bytes + returnStream = streamsIterator.next() + } else { + val nextStream = streamsIterator.next() + returnStream = if (compressShuffle) { + compressionCodec.compressedInputStream( + serializerManager.wrapForEncryption(nextStream)) + } else { + serializerManager.wrapForEncryption(nextStream) + } + } + } + if (returnStream == null) { + throw new IllegalStateException("Expected shuffle reader iterator to return a stream") + } + returnStream + } + } + + val serializerInstance = dep.serializer.newInstance() + val recordIter = retryingWrappedStreams.flatMap { wrappedStream => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala new file mode 100644 index 0000000000000..9b9b8508e88aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.io + +import java.io.InputStream + +import scala.collection.JavaConverters._ + +import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} +import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} + +class DefaultShuffleReadSupport( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + serializerManager: SerializerManager, + conf: SparkConf) extends ShuffleReadSupport { + + private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) + private val maxBlocksInFlightPerAddress = + conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) + private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) + + override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + java.lang.Iterable[InputStream] = { + + val iterableToReturn = if (blockMetadata.asScala.isEmpty) { + Iterable.empty + } else { + val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) + .foldLeft(Int.MaxValue, 0) { + case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) + } + val shuffleId = blockMetadata.asScala.head.getShuffleId + new ShuffleBlockFetcherIterable( + TaskContext.get(), + blockManager, + serializerManager, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), + minReduceId, + maxReduceId, + shuffleId, + mapOutputTracker + ) + } + iterableToReturn.asJava + } +} + +private class ShuffleBlockFetcherIterable( + context: TaskContext, + blockManager: BlockManager, + serializerManager: SerializerManager, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorruption: Boolean, + shuffleMetrics: ShuffleReadMetricsReporter, + minReduceId: Int, + maxReduceId: Int, + shuffleId: Int, + mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] { + + override def iterator: Iterator[InputStream] = { + new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) + .map { shuffleLocationInfo => + val defaultShuffleLocation = shuffleLocationInfo._1 + .get.asInstanceOf[DefaultMapShuffleLocations] + (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) + }, + serializerManager.wrapStream, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorruption, + shuffleMetrics).toCompletionIterator + } + +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 849050556c569..38495ae523d86 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -124,7 +124,11 @@ 
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, metrics) + startPartition, + endPartition, + context, + metrics, + shuffleExecutorComponents.reads()) } /** Get a writer for a given partition. Called on executors by map tasks. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 3966980a11ed0..287ffdd6e10e6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -74,7 +74,7 @@ final class ShuffleBlockFetcherIterator( maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { + extends Iterator[InputStream] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -397,7 +397,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): (BlockId, InputStream) = { + override def next(): InputStream = { if (!hasNext) { throw new NoSuchElementException() } @@ -495,7 +495,6 @@ final class ShuffleBlockFetcherIterator( in.close() } } - case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } @@ -508,11 +507,17 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] - (currentResult.blockId, new BufferReleasingInputStream(input, this)) + val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] + new BufferReleasingInputStream(input, this) + } + + // for testing only + def getCurrentBlock(): ShuffleBlockId = { + currentResult.blockId.asInstanceOf[ShuffleBlockId] } - def toCompletionIterator: Iterator[(BlockId, InputStream)] = { - CompletionIterator[(BlockId, InputStream), this.type](this, + def toCompletionIterator: Iterator[InputStream] = { + CompletionIterator[InputStream, this.type](this, onCompleteCallback.onComplete(context)) } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 0a77c4f6d5838..8fcbc845d1a7b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -71,9 +71,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -155,7 +155,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === Seq( - 
(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -324,12 +324,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByShuffleLocation(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4) + .map(x => (x._1.get, x._2)).toSeq === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), - Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), - Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 83026c002f1b2..1d2713151f505 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -409,12 +409,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val taskContext = new TaskContextImpl( 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 14b93957734e4..21b4e56c9e801 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -703,7 +703,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -731,7 +731,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // we can see both result blocks now assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -774,7 +774,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } else { 
assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) } } } @@ -1069,7 +1069,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1201,11 +1201,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 1) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1397,7 +1397,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1803,7 +1803,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. 
assert(taskSets.size === 3) @@ -2066,7 +2066,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2112,7 +2112,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2276,7 +2276,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2292,7 +2292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -2302,7 +2302,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostD"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2341,7 +2341,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. 
@@ -2367,7 +2367,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2923,6 +2923,10 @@ object DAGSchedulerSuite { def makeShuffleLocation(host: String): MapShuffleLocations = { DefaultMapShuffleLocations.get(makeBlockManagerId(host)) } + + def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = { + Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + } } object FailThisAttempt { diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index b3073addb7ccc..6468914bf3185 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -21,13 +21,19 @@ import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.internal.config +import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.BlockId /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -79,11 +85,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() - val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + val compressionCodec = CompressionCodec.createCodec(testConf) + val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream) + val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream) (0 until keyValuePairsPerMap).foreach { i => serializationStream.writeKey(i) serializationStream.writeValue(2*i) } + compressedOutputStream.close() // Setup the mocked BlockManager to return RecordingManagedBuffers. val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) @@ -102,19 +111,20 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. 
- val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByShuffleLocation( - shuffleId, reduceId, reduceId + 1)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq( - (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) - .toIterator + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) } + val blocksToRetrieve = Seq( + (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes)) + val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + blocksToRetrieve.iterator + } + }) // Create a mocked shuffle handle to pass into HashShuffleReader. val shuffleHandle = { @@ -128,19 +138,23 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val serializerManager = new SerializerManager( serializer, new SparkConf() - .set(config.SHUFFLE_COMPRESS, false) + .set(config.SHUFFLE_COMPRESS, true) .set(config.SHUFFLE_SPILL_COMPRESS, false)) val taskContext = TaskContext.empty() + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + + val shuffleReadSupport = + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, + shuffleReadSupport, serializerManager, - blockManager, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) @@ -151,5 +165,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext assert(buffer.callsToRetain === 1) assert(buffer.callsToRelease === 1) } + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index b39e37c1e3842..4dc1251a4ca84 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -34,10 +34,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network.BlockTransferService import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, 
ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -199,21 +199,28 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((DefaultMapShuffleLocations.get(dataBlockId), shuffleBlockIdsAndSizes)).toIterator + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator } when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) + val readSupport = new DefaultShuffleReadSupport( + blockManager, + mapOutputTracker, + serializerManager, + defaultConf) + new BlockStoreShuffleReader[String, String]( shuffleHandle, 0, 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), + readSupport, serializerManager, - blockManager, mapOutputTracker ) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 98fe9663b6211..a4b6920be04c0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,14 +21,13 @@ import java.io.{File, InputStream, IOException} import java.util.UUID import java.util.concurrent.Semaphore -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future - import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ @@ -125,7 +124,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, inputStream) = iterator.next() + val inputStream = iterator.next() + val blockId = iterator.getCurrentBlock() // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) @@ -200,11 +200,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.close() // close() first block's input stream + iterator.next().close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2 + val subIter = iterator.next() // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -402,7 +402,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) @@ -422,6 +423,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } sem.acquire() + intercept[FetchFailedException] { iterator.next() } } @@ -463,8 +465,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. - assert(Set(iterator.next()._1, iterator.next()._1) === - Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + iterator.next() + val blockId1 = iterator.getCurrentBlock() + iterator.next() + val blockId2 = iterator.getCurrentBlock() + assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) } test("retry corrupt blocks (disabled)") { @@ -522,11 +527,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) - val (id2, _) = iterator.next() + iterator.next() + val id2 = iterator.getCurrentBlock() assert(id2 === ShuffleBlockId(0, 1, 0)) - val (id3, _) = iterator.next() + iterator.next() + val id3 = iterator.getCurrentBlock() assert(id3 === ShuffleBlockId(0, 2, 0)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 079ff25fcb67e..22cfbf506c645 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -156,10 +156,11 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. 
- val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + context.taskMetrics().decorateTempShuffleReadMetrics( + tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics)) + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -168,7 +169,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - sqlMetricsReporter) + tempMetrics) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) }
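A minimal sketch of a third-party ShuffleReadSupport built against the interface added above. The class name, package, shared-filesystem layout, and file-naming scheme are hypothetical and shown only for illustration; the only API assumed is what this patch defines (ShuffleReadSupport, ShuffleBlockInfo, and ShuffleExecutorComponents.reads()).

// Hypothetical package and class name, for illustration only.
package org.apache.spark.examples.shuffle;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.shuffle.ShuffleBlockInfo;
import org.apache.spark.api.shuffle.ShuffleReadSupport;

public class SharedFsShuffleReadSupport implements ShuffleReadSupport {
  private final Path rootDir;

  public SharedFsShuffleReadSupport(String rootDir) {
    this.rootDir = Paths.get(rootDir);
  }

  @Override
  public Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
      throws IOException {
    // Open one stream per requested block. The one-file-per-(shuffle, map, reduce) layout is an
    // assumption about how a matching write support persisted the data, not something this patch
    // defines. Because this is not DefaultShuffleReadSupport, BlockStoreShuffleReader will wrap
    // each returned stream with decryption/decompression before deserializing records, so the
    // files must hold the bytes exactly as the write side emitted them.
    List<InputStream> streams = new ArrayList<>();
    for (ShuffleBlockInfo block : blockMetadata) {
      Path blockFile = rootDir.resolve(String.format(
          "shuffle_%d_%d_%d.data", block.getShuffleId(), block.getMapId(), block.getReduceId()));
      streams.add(new FileInputStream(blockFile.toFile()));
    }
    return streams;
  }
}

A matching ShuffleExecutorComponents implementation would return an instance of this class from reads(), mirroring how DefaultShuffleExecutorComponents constructs DefaultShuffleReadSupport in this patch.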