From 864d1cd46235bbec59ee3773d8f2c4ca64c8f312 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 20 Mar 2019 16:49:52 -0700 Subject: [PATCH 01/56] initial API --- .../spark/api/shuffle/BlockMetadata.java | 84 +++++++++++++++++++ .../spark/api/shuffle/ShuffleReadSupport.java | 30 +++++++ 2 files changed, 114 insertions(+) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java b/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java new file mode 100644 index 0000000000000..2633a87b6208f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.storage.BlockManagerId; + +import java.util.Optional; + +public final class BlockMetadata { + private final String appId; + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; + private final Optional shuffleLocation; + + private BlockMetadata( + String appId, + int shuffleId, + int mapId, + int reduceId, + long length, + Optional shuffleLocation) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + this.shuffleLocation = shuffleLocation; + } + + public static BlockMetadata create(String appId, int shuffleId, int mapId, int reduceId, long length) { + return new BlockMetadata(appId, shuffleId, mapId, reduceId, length, Optional.empty()); + } + + public static BlockMetadata create( + String appId, + int shuffleId, + int mapId, + int reduceId, + long length, + BlockManagerId blockManagerId) { + return new BlockMetadata(appId, shuffleId, mapId, reduceId, length, Optional.of(blockManagerId)); + } + + public String getAppId() { + return appId; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getMapId() { + return mapId; + } + + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public Optional getShuffleLocation() { + return shuffleLocation; + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java new file mode 100644 index 0000000000000..7844c427fa564 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; +import java.io.InputStream; + +/** + * :: Experimental :: + * An interface for reading shuffle records + */ +public interface ShuffleReadSupport { + Iterable getPartitionReaders( + Iterable blockMetadata) throws IOException; +} From c88751c646d168f2599a08c70b593142528e340d Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 21 Mar 2019 12:36:09 -0700 Subject: [PATCH 02/56] wip --- .../spark/api/shuffle/BlockMetadata.java | 14 ++---- .../spark/api/shuffle/ShuffleReadSupport.java | 3 +- .../shuffle/BlockStoreShuffleReader.scala | 4 +- .../io/DefaultShuffleReadSupport.scala | 49 +++++++++++++++++++ .../storage/ShuffleBlockFetcherIterator.scala | 10 ++-- 5 files changed, 59 insertions(+), 21 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala diff --git a/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java b/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java index 2633a87b6208f..c14fdf0475e3b 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java @@ -22,7 +22,6 @@ import java.util.Optional; public final class BlockMetadata { - private final String appId; private final int shuffleId; private final int mapId; private final int reduceId; @@ -30,13 +29,11 @@ public final class BlockMetadata { private final Optional shuffleLocation; private BlockMetadata( - String appId, int shuffleId, int mapId, int reduceId, long length, Optional shuffleLocation) { - this.appId = appId; this.shuffleId = shuffleId; this.mapId = mapId; this.reduceId = reduceId; @@ -44,22 +41,17 @@ private BlockMetadata( this.shuffleLocation = shuffleLocation; } - public static BlockMetadata create(String appId, int shuffleId, int mapId, int reduceId, long length) { - return new BlockMetadata(appId, shuffleId, mapId, reduceId, length, Optional.empty()); + public static BlockMetadata create(int shuffleId, int mapId, int reduceId, long length) { + return new BlockMetadata(shuffleId, mapId, reduceId, length, Optional.empty()); } public static BlockMetadata create( - String appId, int shuffleId, int mapId, int reduceId, long length, BlockManagerId blockManagerId) { - return new BlockMetadata(appId, shuffleId, mapId, reduceId, length, Optional.of(blockManagerId)); - } - - public String getAppId() { - return appId; + return new BlockMetadata(shuffleId, mapId, reduceId, length, Optional.of(blockManagerId)); } public int getShuffleId() { diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 7844c427fa564..82a1f46a23fe0 100644 --- 
a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -25,6 +25,5 @@ * An interface for reading shuffle records */ public interface ShuffleReadSupport { - Iterable getPartitionReaders( - Iterable blockMetadata) throws IOException; + Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c5eefc7c5c049..b06cb8bd2614b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -44,7 +44,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val wrappedStreams = new ShuffleBlockFetcherIterator( - context, blockManager.shuffleClient, blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), @@ -54,8 +53,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - readMetrics).toCompletionIterator + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala new file mode 100644 index 0000000000000..5a96c4763b353 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.io + +import java.io.InputStream +import java.lang + +import org.apache.spark.{MapOutputTracker, SparkEnv} +import org.apache.spark.api.shuffle.{BlockMetadata, ShuffleReadSupport} +import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} + +class DefaultShuffleReadSupport( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + serializerManager: SerializerManager) extends ShuffleReadSupport { + + val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) + val maxBlocksInFlightPerAddress = + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) + val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) + + override def getPartitionReaders( + blockMetadata: lang.Iterable[BlockMetadata]): lang.Iterable[InputStream] = { + val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId() + ) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 3966980a11ed0..b89435d88885c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -44,7 +44,6 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * - * @param context [[TaskContext]], used for metrics update * @param shuffleClient [[ShuffleClient]] for fetching remote blocks * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. @@ -59,11 +58,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. - * @param shuffleMetrics used to report shuffle metrics. */ private[spark] final class ShuffleBlockFetcherIterator( - context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], @@ -72,8 +69,7 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean, - shuffleMetrics: ShuffleReadMetricsReporter) + detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -162,6 +158,10 @@ final class ShuffleBlockFetcherIterator( private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this) + private[this] val context = TaskContext.get() + + private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + initialize() // Decrements the buffer reference count. 
From 9af216fc9a5155a9256433aacdea66998e2e616a Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 22 Mar 2019 15:30:58 -0700 Subject: [PATCH 03/56] wip --- ...tadata.java => ShuffleLocationBlocks.java} | 75 ++++++++++--------- .../spark/api/shuffle/ShuffleReadSupport.java | 2 +- .../io/DefaultShuffleReadSupport.scala | 14 +++- .../storage/ShuffleBlockFetcherIterator.scala | 29 ++++--- 4 files changed, 68 insertions(+), 52 deletions(-) rename core/src/main/java/org/apache/spark/api/shuffle/{BlockMetadata.java => ShuffleLocationBlocks.java} (54%) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java similarity index 54% rename from core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java rename to core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java index c14fdf0475e3b..9fd40a0a3a5ac 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/BlockMetadata.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java @@ -21,53 +21,54 @@ import java.util.Optional; -public final class BlockMetadata { - private final int shuffleId; - private final int mapId; - private final int reduceId; - private final long length; +public final class ShuffleLocationBlocks { + private final ShuffleBlockInfo[] shuffleBlocks; private final Optional shuffleLocation; - private BlockMetadata( - int shuffleId, - int mapId, - int reduceId, - long length, - Optional shuffleLocation) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.reduceId = reduceId; - this.length = length; - this.shuffleLocation = shuffleLocation; - } + private final class ShuffleBlockInfo { + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; - public static BlockMetadata create(int shuffleId, int mapId, int reduceId, long length) { - return new BlockMetadata(shuffleId, mapId, reduceId, length, Optional.empty()); - } + ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + } - public static BlockMetadata create( - int shuffleId, - int mapId, - int reduceId, - long length, - BlockManagerId blockManagerId) { - return new BlockMetadata(shuffleId, mapId, reduceId, length, Optional.of(blockManagerId)); - } + public int getShuffleId() { + return shuffleId; + } - public int getShuffleId() { - return shuffleId; - } + public int getMapId() { + return mapId; + } - public int getMapId() { - return mapId; + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public String getBlockId() { + return String.format("shuffle_%d_%d_%d", shuffleId, mapId, reduceId); + } } - public int getReduceId() { - return reduceId; + private ShuffleLocationBlocks( + ShuffleBlockInfo[] shuffleBlocks, + Optional shuffleLocation) { + this.shuffleBlocks = shuffleBlocks; + this.shuffleLocation = shuffleLocation; } - public long getLength() { - return length; + + public ShuffleBlockInfo[] getShuffleBlocks() { + return shuffleBlocks; } public Optional getShuffleLocation() { diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 82a1f46a23fe0..1fe9a6cd19d32 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ 
b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -25,5 +25,5 @@ * An interface for reading shuffle records */ public interface ShuffleReadSupport { - Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; + Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 5a96c4763b353..f717082b9014a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -20,8 +20,10 @@ package org.apache.spark.shuffle.io import java.io.InputStream import java.lang +import scala.collection.JavaConverters + import org.apache.spark.{MapOutputTracker, SparkEnv} -import org.apache.spark.api.shuffle.{BlockMetadata, ShuffleReadSupport} +import org.apache.spark.api.shuffle.{ShuffleLocationBlocks, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} @@ -39,11 +41,17 @@ class DefaultShuffleReadSupport( val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders( - blockMetadata: lang.Iterable[BlockMetadata]): lang.Iterable[InputStream] = { + blockMetadata: lang.Iterable[ShuffleLocationBlocks]): lang.Iterable[InputStream] = { val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId() + JavaConverters.iterableAsScalaIterable(blockMetadata).iterator, + serializerManager.wrapStream, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt ) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b89435d88885c..47cfd091a737b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,22 +17,23 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue -import javax.annotation.concurrent.GuardedBy +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.api.shuffle.ShuffleLocationBlocks import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} import org.apache.spark.util.io.ChunkedByteBufferOutputStream +import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} +import org.apache.spark.{SparkException, TaskContext} /** * An iterator that fetches multiple blocks. 
For local blocks, it fetches from the local block @@ -44,6 +45,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * + * @param context [[TaskContext]], used for metrics update * @param shuffleClient [[ShuffleClient]] for fetching remote blocks * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. @@ -58,18 +60,21 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param shuffleMetrics used to report shuffle metrics. */ private[spark] final class ShuffleBlockFetcherIterator( + context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[ShuffleLocationBlocks], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) + detectCorrupt: Boolean, + shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -158,10 +163,6 @@ final class ShuffleBlockFetcherIterator( private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this) - private[this] val context = TaskContext.get() - - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - initialize() // Decrements the buffer reference count. 
@@ -285,7 +286,13 @@ final class ShuffleBlockFetcherIterator( var localBlockBytes = 0L var remoteBlockBytes = 0L - for ((address, blockInfos) <- blocksByAddress) { + for (shuffleLocationBlocks <- blocksByAddress) { + assert(shuffleLocationBlocks.getShuffleLocation.isPresent, + "expected shuffleLocationBlock to contain a valid shuffleLocation") + val address = shuffleLocationBlocks.getShuffleLocation.get() + val blockInfos = shuffleLocationBlocks.getShuffleBlocks + .map(block => + (ShuffleBlockId(block.getShuffleId, block.getMapId, block.getReduceId), block.getLength)) if (address.executorId == blockManager.blockManagerId.executorId) { blockInfos.find(_._2 <= 0) match { case Some((blockId, size)) if size < 0 => From a35b8261c31d613353020733c1ecc6f91d94fd16 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 25 Mar 2019 17:09:17 -0700 Subject: [PATCH 04/56] initial implementation of reader --- .../api/shuffle/ShuffleLocationBlocks.java | 10 ++-- .../shuffle/BlockStoreShuffleReader.scala | 43 ++++++++------- .../io/DefaultShuffleReadSupport.scala | 52 ++++++++++++++++--- .../shuffle/sort/SortShuffleManager.scala | 6 ++- .../storage/ShuffleBlockFetcherIterator.scala | 19 +++---- .../BlockStoreShuffleReaderSuite.scala | 8 ++- 6 files changed, 91 insertions(+), 47 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java index 9fd40a0a3a5ac..7d18255b46072 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java @@ -25,13 +25,13 @@ public final class ShuffleLocationBlocks { private final ShuffleBlockInfo[] shuffleBlocks; private final Optional shuffleLocation; - private final class ShuffleBlockInfo { + public static final class ShuffleBlockInfo { private final int shuffleId; private final int mapId; private final int reduceId; private final long length; - ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { this.shuffleId = shuffleId; this.mapId = mapId; this.reduceId = reduceId; @@ -59,9 +59,9 @@ public String getBlockId() { } } - private ShuffleLocationBlocks( - ShuffleBlockInfo[] shuffleBlocks, - Optional shuffleLocation) { + public ShuffleLocationBlocks( + Optional shuffleLocation, + ShuffleBlockInfo[] shuffleBlocks) { this.shuffleBlocks = shuffleBlocks; this.shuffleLocation = shuffleLocation; } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b06cb8bd2614b..39e94020e5daf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,10 +17,15 @@ package org.apache.spark.shuffle +import java.util.Optional + +import scala.collection.JavaConverters._ + import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.api.shuffle.{ShuffleLocationBlocks, ShuffleReadSupport} +import org.apache.spark.api.shuffle.ShuffleLocationBlocks.ShuffleBlockInfo +import org.apache.spark.internal.Logging +import org.apache.spark.storage.ShuffleBlockId import 
org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -34,31 +39,33 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, + shuffleReadSupport: ShuffleReadSupport, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency /** Read the combined key-values for this reduce task */ + val blocksIterator = + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) + .map(blockManagerIdInfo => { + val shuffleBlockInfo = blockManagerIdInfo._2.map( + blockInfo => { + val block = blockInfo._1.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2) + } + ) + new ShuffleLocationBlocks(Optional.of(blockManagerIdInfo._1), shuffleBlockInfo.toArray) + }) override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)).toCompletionIterator + val wrappedStreams = + shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava).asScala + val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + val recordIter = wrappedStreams.flatMap { case wrappedStream => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. 
@@ -70,7 +77,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( recordIter.map { record => readMetrics.incRecordsRead(1) record - }, + }.toIterator, context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index f717082b9014a..5b9271f20d011 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -20,17 +20,17 @@ package org.apache.spark.shuffle.io import java.io.InputStream import java.lang -import scala.collection.JavaConverters +import scala.collection.JavaConverters._ -import org.apache.spark.{MapOutputTracker, SparkEnv} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.shuffle.{ShuffleLocationBlocks, ShuffleReadSupport} +import org.apache.spark.api.shuffle.ShuffleLocationBlocks.ShuffleBlockInfo import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} class DefaultShuffleReadSupport( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker, serializerManager: SerializerManager) extends ShuffleReadSupport { val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 @@ -40,18 +40,54 @@ class DefaultShuffleReadSupport( val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) + override def getPartitionReaders( blockMetadata: lang.Iterable[ShuffleLocationBlocks]): lang.Iterable[InputStream] = { - val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + val blockMetadataAsScala = blockMetadata.asScala.map(shuffleLocationBlocks => { + val blockInfos = shuffleLocationBlocks.getShuffleBlocks + .map(blockInfo => { + (ShuffleBlockId(blockInfo.getShuffleId, blockInfo.getMapId, blockInfo.getReduceId), + blockInfo.getLength) + }).toSeq + (shuffleLocationBlocks.getShuffleLocation.get(), blockInfos) + }) + + val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator( + TaskContext.get(), blockManager.shuffleClient, blockManager, - JavaConverters.iterableAsScalaIterable(blockMetadata).iterator, + blockMetadataAsScala.iterator, serializerManager.wrapStream, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, - detectCorrupt - ) + detectCorrupt, + shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics() + ).toCompletionIterator + + new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava + } + + private class ShuffleBlockInputStreamIterator( + blockFetchIterator: Iterator[(BlockId, InputStream)]) + extends Iterator[InputStream] { + override def hasNext: Boolean = blockFetchIterator.hasNext + + override def next(): InputStream = { + blockFetchIterator.next()._2 + } + } + + private[spark] object DefaultShuffleReadSupport { + def toShuffleBlockInfo(blockId: BlockId, length: Long): ShuffleBlockInfo = { + assert(blockId.isInstanceOf[ShuffleBlockId]) + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo( 
+ shuffleBlockId.shuffleId, + shuffleBlockId.mapId, + shuffleBlockId.reduceId, + length) + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e8a3ccd..d115d0b95cd52 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -116,9 +117,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + // TODO: remove this from here once ShuffleExecutorComponents is introduced + val readSupport = new DefaultShuffleReadSupport( + blockManager = SparkEnv.get.blockManager, serializerManager = SparkEnv.get.serializerManager) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, metrics) + startPartition, endPartition, context, metrics, readSupport) } /** Get a writer for a given partition. Called on executors by map tasks. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 47cfd091a737b..3966980a11ed0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,22 @@ package org.apache.spark.storage -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue - import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import org.apache.spark.api.shuffle.ShuffleLocationBlocks +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.util.io.ChunkedByteBufferOutputStream import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. 
For local blocks, it fetches from the local block @@ -67,7 +66,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Iterator[ShuffleLocationBlocks], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -286,13 +285,7 @@ final class ShuffleBlockFetcherIterator( var localBlockBytes = 0L var remoteBlockBytes = 0L - for (shuffleLocationBlocks <- blocksByAddress) { - assert(shuffleLocationBlocks.getShuffleLocation.isPresent, - "expected shuffleLocationBlock to contain a valid shuffleLocation") - val address = shuffleLocationBlocks.getShuffleLocation.get() - val blockInfos = shuffleLocationBlocks.getShuffleBlocks - .map(block => - (ShuffleBlockId(block.getShuffleId, block.getMapId, block.getReduceId), block.getLength)) + for ((address, blockInfos) <- blocksByAddress) { if (address.executorId == blockManager.blockManagerId.executorId) { blockInfos.find(_._2 <= 0) match { case Some((blockId, size)) if size < 0 => diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a790..afad84953a7c5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -128,15 +129,18 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set(config.SHUFFLE_SPILL_COMPRESS, false)) val taskContext = TaskContext.empty() + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + + val shuffleReadSupport = + new DefaultShuffleReadSupport(blockManager, serializerManager) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, - serializerManager, - blockManager, + shuffleReadSupport, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) From 14c47ae408207adf62eae71d2f4751915f5239dc Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 26 Mar 2019 11:38:46 -0700 Subject: [PATCH 05/56] fix based on comments --- .../spark/api/shuffle/ShuffleBlockInfo.java | 53 +++++++++++++ .../api/shuffle/ShuffleLocationBlocks.java | 77 ------------------- .../spark/api/shuffle/ShuffleReadSupport.java | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 9 +-- .../io/DefaultShuffleReadSupport.scala | 27 +++---- .../shuffle/sort/SortShuffleManager.scala | 4 +- .../BlockStoreShuffleReaderSuite.scala | 29 ++++--- 7 files changed, 88 insertions(+), 113 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java new file mode 100644 index 0000000000000..39f16d4afe6a2 --- 
/dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +public final class ShuffleBlockInfo { + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; + + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getMapId() { + return mapId; + } + + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public String getBlockName() { + return String.format("shuffle_%d_%d_%d", shuffleId, mapId, reduceId); + } + +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java deleted file mode 100644 index 7d18255b46072..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocationBlocks.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.storage.BlockManagerId; - -import java.util.Optional; - -public final class ShuffleLocationBlocks { - private final ShuffleBlockInfo[] shuffleBlocks; - private final Optional shuffleLocation; - - public static final class ShuffleBlockInfo { - private final int shuffleId; - private final int mapId; - private final int reduceId; - private final long length; - - public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.reduceId = reduceId; - this.length = length; - } - - public int getShuffleId() { - return shuffleId; - } - - public int getMapId() { - return mapId; - } - - public int getReduceId() { - return reduceId; - } - - public long getLength() { - return length; - } - - public String getBlockId() { - return String.format("shuffle_%d_%d_%d", shuffleId, mapId, reduceId); - } - } - - public ShuffleLocationBlocks( - Optional shuffleLocation, - ShuffleBlockInfo[] shuffleBlocks) { - this.shuffleBlocks = shuffleBlocks; - this.shuffleLocation = shuffleLocation; - } - - - public ShuffleBlockInfo[] getShuffleBlocks() { - return shuffleBlocks; - } - - public Optional getShuffleLocation() { - return shuffleLocation; - } -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 1fe9a6cd19d32..c6816f4193303 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -25,5 +25,5 @@ * An interface for reading shuffle records */ public interface ShuffleReadSupport { - Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; + Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 39e94020e5daf..373fe964b654d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -22,8 +22,7 @@ import java.util.Optional import scala.collection.JavaConverters._ import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleLocationBlocks, ShuffleReadSupport} -import org.apache.spark.api.shuffle.ShuffleLocationBlocks.ShuffleBlockInfo +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.Logging import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.CompletionIterator @@ -48,20 +47,18 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ val blocksIterator = mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) - .map(blockManagerIdInfo => { - val shuffleBlockInfo = blockManagerIdInfo._2.map( + .flatMap(blockManagerIdInfo => { + blockManagerIdInfo._2.map( blockInfo => { val block = blockInfo._1.asInstanceOf[ShuffleBlockId] new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2) } ) - new ShuffleLocationBlocks(Optional.of(blockManagerIdInfo._1), shuffleBlockInfo.toArray) }) override def read(): Iterator[Product2[K, C]] = { val wrappedStreams = shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava).asScala - val serializerInstance 
= dep.serializer.newInstance() // Create a key/value iterator for each stream diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 5b9271f20d011..7753985340ece 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -22,16 +22,16 @@ import java.lang import scala.collection.JavaConverters._ -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleLocationBlocks, ShuffleReadSupport} -import org.apache.spark.api.shuffle.ShuffleLocationBlocks.ShuffleBlockInfo +import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId} class DefaultShuffleReadSupport( blockManager: BlockManager, - serializerManager: SerializerManager) extends ShuffleReadSupport { + serializerManager: SerializerManager, + mapOutputTracker: MapOutputTracker) extends ShuffleReadSupport { val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) @@ -40,23 +40,18 @@ class DefaultShuffleReadSupport( val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) - override def getPartitionReaders( - blockMetadata: lang.Iterable[ShuffleLocationBlocks]): lang.Iterable[InputStream] = { - val blockMetadataAsScala = blockMetadata.asScala.map(shuffleLocationBlocks => { - val blockInfos = shuffleLocationBlocks.getShuffleBlocks - .map(blockInfo => { - (ShuffleBlockId(blockInfo.getShuffleId, blockInfo.getMapId, blockInfo.getReduceId), - blockInfo.getLength) - }).toSeq - (shuffleLocationBlocks.getShuffleLocation.get(), blockInfos) - }) + blockMetadata: lang.Iterable[ShuffleBlockInfo]): lang.Iterable[InputStream] = { + + val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min + val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max + val shuffleId = blockMetadata.asScala.head.getShuffleId val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator( TaskContext.get(), blockManager.shuffleClient, blockManager, - blockMetadataAsScala.iterator, + mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1), serializerManager.wrapStream, maxBytesInFlight, maxReqsInFlight, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d115d0b95cd52..e893d512de4ae 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -119,7 +119,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { // TODO: remove this from here once ShuffleExecutorComponents is introduced val readSupport = new 
DefaultShuffleReadSupport( - blockManager = SparkEnv.get.blockManager, serializerManager = SparkEnv.get.serializerManager) + blockManager = SparkEnv.get.blockManager, + serializerManager = SparkEnv.get.serializerManager, + mapOutputTracker = SparkEnv.get.mapOutputTracker) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, metrics, readSupport) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index afad84953a7c5..39a6e0ea6d26b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,14 +20,16 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{doReturn, mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.{Answer, Stubber} import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.io.DefaultShuffleReadSupport -import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockId} /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -102,16 +104,19 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. - val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) } + val blocksToRetrieve = Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)) + .thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + blocksToRetrieve.iterator + } + }) // Create a mocked shuffle handle to pass into HashShuffleReader. 
val shuffleHandle = { @@ -133,7 +138,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, serializerManager) + new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, From 5bb4c32f6a8d7ecefa2e1c5e618b5e276f72da14 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 26 Mar 2019 11:50:03 -0700 Subject: [PATCH 06/56] fix java lang import and delete unneeded class --- .../java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java | 4 ---- .../apache/spark/shuffle/io/DefaultShuffleReadSupport.scala | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index 39f16d4afe6a2..5851b64cdaf66 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -46,8 +46,4 @@ public long getLength() { return length; } - public String getBlockName() { - return String.format("shuffle_%d_%d_%d", shuffleId, mapId, reduceId); - } - } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 7753985340ece..a3cb82271c31b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -18,7 +18,6 @@ package org.apache.spark.shuffle.io import java.io.InputStream -import java.lang import scala.collection.JavaConverters._ @@ -41,7 +40,7 @@ class DefaultShuffleReadSupport( val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders( - blockMetadata: lang.Iterable[ShuffleBlockInfo]): lang.Iterable[InputStream] = { + blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max From 584e6c8ad4b98877fb3ec86f101217f41f10c2d2 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 27 Mar 2019 11:51:07 -0700 Subject: [PATCH 07/56] address initial comments --- .../apache/spark/api/shuffle/ShuffleBlockInfo.java | 2 +- .../spark/shuffle/io/DefaultShuffleReadSupport.scala | 12 +++++++----- .../spark/shuffle/BlockStoreShuffleReaderSuite.scala | 3 ++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index 5851b64cdaf66..10635ea6a3b79 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -17,7 +17,7 @@ package org.apache.spark.api.shuffle; -public final class ShuffleBlockInfo { +public class ShuffleBlockInfo { private final int shuffleId; private final int mapId; private final int reduceId; diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index a3cb82271c31b..2805acc9543c6 100644 
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -32,12 +32,14 @@ class DefaultShuffleReadSupport( serializerManager: SerializerManager, mapOutputTracker: MapOutputTracker) extends ShuffleReadSupport { - val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 - val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) - val maxBlocksInFlightPerAddress = + private val maxBytesInFlight = + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + private val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) + private val maxBlocksInFlightPerAddress = SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) - val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) - val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) + private val maxReqSizeShuffleToMem = + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders( blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 39a6e0ea6d26b..aa4be00ed1e4b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -113,7 +113,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)) .thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] { def answer(invocationOnMock: InvocationOnMock): - Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { blocksToRetrieve.iterator } }) @@ -156,5 +156,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext assert(buffer.callsToRetain === 1) assert(buffer.callsToRelease === 1) } + TaskContext.unset() } } From 0292fe236a4f7ea552f2c7ada4b7e53f1028af2f Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 27 Mar 2019 13:21:25 -0700 Subject: [PATCH 08/56] fix unit tests --- .../io/DefaultShuffleReadSupport.scala | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 2805acc9543c6..cc4f2185cdcae 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -44,25 +44,29 @@ class DefaultShuffleReadSupport( override def getPartitionReaders( blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { - val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min - val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max - val shuffleId = blockMetadata.asScala.head.getShuffleId + if (blockMetadata.asScala.isEmpty) { + new ShuffleBlockInputStreamIterator(Iterator.empty).toIterable.asJava + } else { 
+      val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min
+      val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max
+      val shuffleId = blockMetadata.asScala.head.getShuffleId

-    val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
-      TaskContext.get(),
-      blockManager.shuffleClient,
-      blockManager,
-      mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
-      serializerManager.wrapStream,
-      maxBytesInFlight,
-      maxReqsInFlight,
-      maxBlocksInFlightPerAddress,
-      maxReqSizeShuffleToMem,
-      detectCorrupt,
-      shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
-    ).toCompletionIterator
+      val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
+        TaskContext.get(),
+        blockManager.shuffleClient,
+        blockManager,
+        mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
+        serializerManager.wrapStream,
+        maxBytesInFlight,
+        maxReqsInFlight,
+        maxBlocksInFlightPerAddress,
+        maxReqSizeShuffleToMem,
+        detectCorrupt,
+        shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
+      ).toCompletionIterator

-    new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava
+      new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava
+    }
   }

   private class ShuffleBlockInputStreamIterator(

From 71c2cc7a9cc036e9367ed54383f322972fc6c18a Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 27 Mar 2019 13:24:03 -0700
Subject: [PATCH 09/56] java checkstyle

---
 .../java/org/apache/spark/api/shuffle/ShuffleReadSupport.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index c6816f4193303..0eaef4a2fb64d 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -25,5 +25,6 @@
  * An interface for reading shuffle records
  */
 public interface ShuffleReadSupport {
-  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata) throws IOException;
+  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+      throws IOException;
 }

From 43c377ca9677aa111894e7440b2dd4a2a3b74385 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 27 Mar 2019 15:49:06 -0700
Subject: [PATCH 10/56] fix tests

---
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 8b1084a8edc76..140c005130f8a 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -405,12 +405,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val taskContext = new TaskContextImpl(
       1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
+    TaskContext.setTaskContext(taskContext)
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

     manager.unregisterShuffle(0)
+    TaskContext.unset()
   }
 }

From 9fc6a60465ec8c7938b852b3dbcb111bdc37e4a4 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 27 Mar 2019 20:56:38 -0700
Subject: [PATCH 11/56] address some comments

---
 .../spark/api/shuffle/ShuffleReadSupport.java |  4 ++
 .../io/DefaultShuffleReadSupport.scala        | 39 ++++++++-----------
 .../shuffle/sort/SortShuffleManager.scala     |  3 +-
 .../BlockStoreShuffleReaderSuite.scala        |  2 +-
 4 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index 0eaef4a2fb64d..d0bec5c440b8c 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -17,13 +17,17 @@

 package org.apache.spark.api.shuffle;

+import org.apache.spark.annotation.Experimental;
+
 import java.io.IOException;
 import java.io.InputStream;

 /**
  * :: Experimental ::
  * An interface for reading shuffle records
+ * @since 3.0.0
  */
+@Experimental
 public interface ShuffleReadSupport {
   Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
       throws IOException;
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index cc4f2185cdcae..a0f511456dcae 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -21,7 +21,7 @@ import java.io.InputStream

 import scala.collection.JavaConverters._

-import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
 import org.apache.spark.internal.config
 import org.apache.spark.serializer.SerializerManager
@@ -30,25 +30,28 @@ import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherItera
 class DefaultShuffleReadSupport(
     blockManager: BlockManager,
     serializerManager: SerializerManager,
-    mapOutputTracker: MapOutputTracker) extends ShuffleReadSupport {
+    mapOutputTracker: MapOutputTracker,
+    conf: SparkConf) extends ShuffleReadSupport {

-  private val maxBytesInFlight =
-    SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
-  private val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT)
+  private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
+  private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT)
   private val maxBlocksInFlightPerAddress =
-    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
-  private val maxReqSizeShuffleToMem =
-    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
-  private val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)
+    conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+  private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+  private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT)

   override def getPartitionReaders(
       blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {

     if (blockMetadata.asScala.isEmpty) {
-      new ShuffleBlockInputStreamIterator(Iterator.empty).toIterable.asJava
+      Iterable.empty.asJava
     } else {
+      val minMaxReduceIds = blockMetadata.asScala.map(block => block.getReduceId)
+        .foldLeft(0, Int.MaxValue) {
+          case ((min, max), elem) => (math.min(min, elem), math.max(max, elem))
+        }
+      val minReduceId = minMaxReduceIds._1
+      val maxReduceId = minMaxReduceIds._2
       val shuffleId = blockMetadata.asScala.head.getShuffleId

       val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
@@ -65,17 +68,7 @@ class DefaultShuffleReadSupport(
         shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
       ).toCompletionIterator

-      new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava
-    }
-  }
-
-  private class ShuffleBlockInputStreamIterator(
-      blockFetchIterator: Iterator[(BlockId, InputStream)])
-    extends Iterator[InputStream] {
-    override def hasNext: Boolean = blockFetchIterator.hasNext
-
-    override def next(): InputStream = {
-      blockFetchIterator.next()._2
+      shuffleBlockFetchIterator.map(_._2).toIterable.asJava
     }
   }

diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index e893d512de4ae..ee69d0d177e3b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -121,7 +121,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     val readSupport = new DefaultShuffleReadSupport(
       blockManager = SparkEnv.get.blockManager,
       serializerManager = SparkEnv.get.serializerManager,
-      mapOutputTracker = SparkEnv.get.mapOutputTracker)
+      mapOutputTracker = SparkEnv.get.mapOutputTracker,
+      conf = SparkEnv.get.conf)
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
       startPartition, endPartition, context, metrics, readSupport)
   }

   /** Get a writer for a given partition. Called on executors by map tasks. */
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index aa4be00ed1e4b..e1a91afae8927 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()

     val shuffleReadSupport =
-      new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker)
+      new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker, testConf)
     val shuffleReader = new BlockStoreShuffleReader(
       shuffleHandle,
       reduceId,

From 45172a5ba75ad05e95992d47d4117a4356c41812 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 27 Mar 2019 21:46:33 -0700
Subject: [PATCH 12/56] blah

---
 .../org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index a0f511456dcae..bae3cbf340c3d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -47,7 +47,7 @@ class DefaultShuffleReadSupport(
       Iterable.empty.asJava
     } else {
       val minMaxReduceIds = blockMetadata.asScala.map(block => block.getReduceId)
-        .foldLeft(0, Int.MaxValue) {
+        .foldLeft(Int.MaxValue, 0) {
         case ((min, max), elem) => (math.min(min, elem), math.max(max, elem))
       }
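The one-line change above is load-bearing: seeding the fold with (0, Int.MaxValue) pins the running minimum at 0 and the running maximum at Int.MaxValue for any non-negative reduce ids, so the fetcher would always scan the widest possible partition range. A minimal sketch of the corrected fold over hypothetical reduce ids:

  // With the (Int.MaxValue, 0) seed both bounds converge to the real extremes.
  val reduceIds = Seq(3, 7, 5)
  val (minReduceId, maxReduceId) = reduceIds.foldLeft((Int.MaxValue, 0)) {
    case ((min, max), elem) => (math.min(min, elem), math.max(max, elem))
  }
  // minReduceId == 3, maxReduceId == 7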
From 4e5652bd1beb79777ca52f1fca3ec3732a0d63ae Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 27 Mar 2019 21:57:32 -0700
Subject: [PATCH 13/56] address more comments

---
 .../shuffle/BlockStoreShuffleReader.scala     |  3 +-
 .../io/DefaultShuffleReadSupport.scala        | 57 ++++++++++++-------
 2 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 373fe964b654d..1b584b1f740a3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -57,7 +57,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
   override def read(): Iterator[Product2[K, C]] = {
     val wrappedStreams =
-      shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava).asScala
+      shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava)
+        .iterator().asScala

     val serializerInstance = dep.serializer.newInstance()

diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index bae3cbf340c3d..2ae667a7c8554 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -25,7 +25,8 @@ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
 import org.apache.spark.internal.config
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}

 class DefaultShuffleReadSupport(
     blockManager: BlockManager,
@@ -54,33 +55,51 @@ class DefaultShuffleReadSupport(
       val maxReduceId = minMaxReduceIds._2
       val shuffleId = blockMetadata.asScala.head.getShuffleId

-      val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
+      new ShuffleBlockFetcherIterable(
         TaskContext.get(),
-        blockManager.shuffleClient,
         blockManager,
-        mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
         serializerManager.wrapStream,
         maxBytesInFlight,
         maxReqsInFlight,
         maxBlocksInFlightPerAddress,
         maxReqSizeShuffleToMem,
         detectCorrupt,
-        shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
-      ).toCompletionIterator
-
-      shuffleBlockFetchIterator.map(_._2).toIterable.asJava
+        shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(),
+        minReduceId,
+        maxReduceId,
+        shuffleId,
+        mapOutputTracker
+      ).asJava
     }
   }
+}

-  private[spark] object DefaultShuffleReadSupport {
-    def toShuffleBlockInfo(blockId: BlockId, length: Long): ShuffleBlockInfo = {
-      assert(blockId.isInstanceOf[ShuffleBlockId])
-      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
-      new ShuffleBlockInfo(
-        shuffleBlockId.shuffleId,
-        shuffleBlockId.mapId,
-        shuffleBlockId.reduceId,
-        length)
-    }
-  }
+private class ShuffleBlockFetcherIterable(
+    context: TaskContext,
+    blockManager: BlockManager,
+    streamWrapper: (BlockId, InputStream) => InputStream,
+    maxBytesInFlight: Long,
+    maxReqsInFlight: Int,
+    maxBlocksInFlightPerAddress: Int,
+    maxReqSizeShuffleToMem: Long,
+    detectCorrupt: Boolean,
+    shuffleMetrics: ShuffleReadMetricsReporter,
+    minReduceId: Int,
+    maxReduceId: Int,
+    shuffleId: Int,
+    mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] {
+
+  override def iterator: Iterator[InputStream] =
+    new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
+      streamWrapper,
+      maxBytesInFlight,
+      maxReqsInFlight,
+      maxBlocksInFlightPerAddress,
+      maxReqSizeShuffleToMem,
+      detectCorrupt,
+      shuffleMetrics).toCompletionIterator.map(_._2)
 }

From a35d8fee55240ea1e62221d5dadcffa62da4e72c Mon Sep 17 00:00:00 2001
From: mcheah
Date: Mon, 1 Apr 2019 12:28:03 -0700
Subject: [PATCH 14/56] Use decorators to customize how the read metrics
 reporter is instantiated.

Important for SQL which always wants a SQL metrics reporter instead.
---
 .../apache/spark/executor/TaskMetrics.scala   |  9 +++++-
 .../shuffle/BlockStoreShuffleReader.scala     | 31 +++++++++----------
 .../spark/sql/execution/ShuffledRowRDD.scala  |  7 +++--
 3 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ea79c7310349d..e99bd673f195a 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable {
   private val _diskBytesSpilled = new LongAccumulator
   private val _peakExecutionMemory = new LongAccumulator
   private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
+  private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics =
+    Predef.identity[TempShuffleReadMetrics]

   /**
    * Time taken on the executor to deserialize this task.
@@ -187,11 +189,16 @@ class TaskMetrics private[spark] () extends Serializable {
    * be lost.
    */
   private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized {
-    val readMetrics = new TempShuffleReadMetrics
+    val readMetrics = _decorFunc(new TempShuffleReadMetrics)
     tempShuffleReadMetrics += readMetrics
     readMetrics
   }

+  private[spark] def decorateTempShuffleReadMetrics(
+      decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized {
+    _decorFunc = decorFunc
+  }
+
   /**
    * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`.
    * This is expected to be called on executor heartbeat and at the end of a task.
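Taken together, the two hunks above let a task install a wrapping function once and have every temp reporter created afterwards pass through it. A minimal sketch of the intended call pattern, mirroring the ShuffledRowRDD change below (the `context` and `metrics` values are assumed to be in scope as they are there):

  // Install the decorator once per task, before any shuffle reader is built.
  context.taskMetrics().decorateTempShuffleReadMetrics(
    tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics))
  // From here on, createTempShuffleReadMetrics() hands back the decorated
  // reporter, so the default read path records SQL metrics transparently.
  val reporter = context.taskMetrics().createTempShuffleReadMetrics()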
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 1b584b1f740a3..d3edc68f6967c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -17,8 +17,6 @@

 package org.apache.spark.shuffle

-import java.util.Optional
-
 import scala.collection.JavaConverters._

 import org.apache.spark._
@@ -44,21 +42,22 @@ private[spark] class BlockStoreShuffleReader[K, C](

   private val dep = handle.dependency

-  /** Read the combined key-values for this reduce task */
-  val blocksIterator =
-    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
-      .flatMap(blockManagerIdInfo => {
-        blockManagerIdInfo._2.map(
-          blockInfo => {
-            val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
-            new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2)
-          }
-        )
-      })
   override def read(): Iterator[Product2[K, C]] = {
     val wrappedStreams =
-      shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava)
-        .iterator().asScala
+      shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] {
+        override def iterator: Iterator[ShuffleBlockInfo] = {
+          /** Read the combined key-values for this reduce task */
+          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
+            .flatMap(blockManagerIdInfo => {
+              blockManagerIdInfo._2.map(
+                blockInfo => {
+                  val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
+                  new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2)
+                }
+              )
+            })
+        }
+      }.asJava).iterator().asScala

     val serializerInstance = dep.serializer.newInstance()

@@ -75,7 +74,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       recordIter.map { record =>
         readMetrics.incRecordsRead(1)
         record
-      }.toIterator,
+      },
       context.taskMetrics().mergeShuffleReadMetrics())

     // An interruptible iterator must be used here in order to support task cancellation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 079ff25fcb67e..22cfbf506c645 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -156,10 +156,11 @@ class ShuffledRowRDD(

   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
-    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
     // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
     // as well as the `tempMetrics` for basic shuffle metrics.
-    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
+    context.taskMetrics().decorateTempShuffleReadMetrics(
+      tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics))
+    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
     // The range of pre-shuffle partitions that we are fetching at here is
     // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1].
     val reader =
@@ -168,7 +169,7 @@ class ShuffledRowRDD(
         shuffledRowPartition.startPreShufflePartitionIndex,
         shuffledRowPartition.endPreShufflePartitionIndex,
         context,
-        sqlMetricsReporter)
+        tempMetrics)
     reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
   }

From 1a09ebe09057d457367cc4998890820b5ef20285 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Mon, 1 Apr 2019 17:56:31 -0700
Subject: [PATCH 15/56] blah

---
 .../main/scala/org/apache/spark/executor/TaskMetrics.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index e99bd673f195a..df30fd5c7f679 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -189,8 +189,9 @@ class TaskMetrics private[spark] () extends Serializable {
    * be lost.
    */
   private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized {
-    val readMetrics = _decorFunc(new TempShuffleReadMetrics)
-    tempShuffleReadMetrics += readMetrics
+    val tempShuffleMetrics = new TempShuffleReadMetrics
+    val readMetrics = _decorFunc(tempShuffleMetrics)
+    tempShuffleReadMetrics += tempShuffleMetrics
     readMetrics
   }

From c149d242c25614cc0b613e36f9041261652bbe2c Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Tue, 2 Apr 2019 13:19:10 -0700
Subject: [PATCH 16/56] initial tests

---
 .../shuffle/ShuffleBlockInputStreamId.java    | 25 +++++++++++++++
 .../spark/api/shuffle/ShuffleReadSupport.java |  3 ++-
 .../shuffle/BlockStoreShuffleReader.scala     | 10 ++++++--
 .../io/DefaultShuffleReadSupport.scala        | 21 ++++++++++------
 .../storage/ShuffleBlockFetcherIterator.scala |  5 +++-
 .../BlockStoreShuffleReaderSuite.scala        |  8 ++++--
 6 files changed, 59 insertions(+), 13 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java
new file mode 100644
index 0000000000000..1eea4d1d38817
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java
@@ -0,0 +1,25 @@
+package org.apache.spark.api.shuffle;
+
+public class ShuffleBlockInputStreamId {
+  private final int shuffleId;
+  private final int mapId;
+  private final int reduceId;
+
+  public ShuffleBlockInputStreamId(int shuffleId, int mapId, int reduceId) {
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.reduceId = reduceId;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public int getMapId() {
+    return mapId;
+  }
+
+  public int getReduceId() {
+    return reduceId;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index d0bec5c440b8c..7789e1d0c6a70 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -18,6 +18,7 @@
 package org.apache.spark.api.shuffle;

 import org.apache.spark.annotation.Experimental;
+import scala.Tuple2;

 import java.io.IOException;
 import java.io.InputStream;
@@ -29,6 +30,6 @@
  */
 @Experimental
 public interface ShuffleReadSupport {
-  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+  Iterable<Tuple2<ShuffleBlockInputStreamId, InputStream>> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
       throws IOException;
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index d3edc68f6967c..d310c2d6907f9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -61,12 +61,18 @@ private[spark] class BlockStoreShuffleReader[K, C](

     val serializerInstance = dep.serializer.newInstance()
+    val serializerManager = SparkEnv.get.serializerManager

     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { case wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (shuffleStreamId, wrappedStream) =>
+      val blockId = ShuffleBlockId(
+        shuffleStreamId.getShuffleId,
+        shuffleStreamId.getMapId,
+        shuffleStreamId.getReduceId)
+      val decompressedStream = serializerManager.wrapStream(blockId, wrappedStream)
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
-      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+      serializerInstance.deserializeStream(decompressedStream).asKeyValueIterator
     }

diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index 2ae667a7c8554..91a551f59670d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -18,15 +18,16 @@
 package org.apache.spark.shuffle.io

 import java.io.InputStream
+import java.lang

 import scala.collection.JavaConverters._

 import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
-import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleBlockInputStreamId, ShuffleReadSupport}
 import org.apache.spark.internal.config
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleReadMetricsReporter
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId}

 class DefaultShuffleReadSupport(
     blockManager: BlockManager,
@@ -41,8 +42,8 @@ class DefaultShuffleReadSupport(
   private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
   private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT)

-  override def getPartitionReaders(
-      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {
+  override def getPartitionReaders(blockMetadata: lang.Iterable[ShuffleBlockInfo]):
+      lang.Iterable[(ShuffleBlockInputStreamId, InputStream)] = {

     if (blockMetadata.asScala.isEmpty) {
       Iterable.empty.asJava
@@ -87,10 +88,10 @@ private class ShuffleBlockFetcherIterable(
     minReduceId: Int,
     maxReduceId: Int,
     shuffleId: Int,
-    mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] {
+    mapOutputTracker: MapOutputTracker) extends Iterable[(ShuffleBlockInputStreamId, InputStream)] {

-  override def iterator: Iterator[InputStream] =
+  override def iterator: Iterator[(ShuffleBlockInputStreamId, InputStream)] =
     new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
@@ -101,5 +102,11 @@ private class ShuffleBlockFetcherIterable(
       maxBlocksInFlightPerAddress,
       maxReqSizeShuffleToMem,
       detectCorrupt,
-      shuffleMetrics).toCompletionIterator.map(_._2)
+      shuffleMetrics)
+      .toCompletionIterator
+      .map(stream => {
+        val blockId = stream._1.asInstanceOf[ShuffleBlockId]
+        (new ShuffleBlockInputStreamId(blockId.shuffleId, blockId.mapId, blockId.reduceId),
+          stream._2)
+      })
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 3966980a11ed0..1c3656e152529 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -406,6 +406,7 @@ final class ShuffleBlockFetcherIterator(

     var result: FetchResult = null
     var input: InputStream = null
+    var resultInputStream: InputStream = null
     // Take the next fetched result and try to decompress it to detect data corruption,
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
@@ -463,6 +464,7 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
+        resultInputStream = buf.createInputStream()
         var isStreamCopied: Boolean = false
         try {
           input = streamWrapper(blockId, in)
@@ -508,7 +510,8 @@ final class ShuffleBlockFetcherIterator(
       throw new NoSuchElementException()
     }
     currentResult = result.asInstanceOf[SuccessFetchResult]
-    (currentResult.blockId, new BufferReleasingInputStream(input, this))
+    input.close()
+    (currentResult.blockId, new BufferReleasingInputStream(resultInputStream, this))
   }

   def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index e1a91afae8927..a64fe586ab97d 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -26,6 +26,7 @@ import org.mockito.stubbing.{Answer, Stubber}

 import org.apache.spark._
 import org.apache.spark.internal.config
+import org.apache.spark.io.{CompressionCodec, CompressionCodec$}
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
@@ -81,11 +82,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
     // from each mappers (all mappers return the same shuffle data).
     val byteOutputStream = new ByteArrayOutputStream()
-    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    val compressionCodec = CompressionCodec.createCodec(testConf)
+    val compressionOutputStream = compressionCodec.compressedOutputStream(byteOutputStream)
+    val serializationStream = serializer.newInstance().serializeStream(compressionOutputStream)
     (0 until keyValuePairsPerMap).foreach { i =>
       serializationStream.writeKey(i)
       serializationStream.writeValue(2*i)
     }
+    compressionOutputStream.close()

     // Setup the mocked BlockManager to return RecordingManagedBuffers.
     val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
@@ -130,7 +134,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     val serializerManager = new SerializerManager(
       serializer,
       new SparkConf()
-        .set(config.SHUFFLE_COMPRESS, false)
+        .set(config.SHUFFLE_COMPRESS, true)
         .set(config.SHUFFLE_SPILL_COMPRESS, false))

     val taskContext = TaskContext.empty()

From 672d4737e3716e8e9494d46f34b79fbdfa213b20 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Tue, 2 Apr 2019 18:26:25 -0700
Subject: [PATCH 17/56] Revert "initial tests"

This reverts commit c149d242c25614cc0b613e36f9041261652bbe2c.
---
 .../shuffle/ShuffleBlockInputStreamId.java    | 25 -------------------
 .../spark/api/shuffle/ShuffleReadSupport.java |  3 +--
 .../shuffle/BlockStoreShuffleReader.scala     | 10 ++------
 .../io/DefaultShuffleReadSupport.scala        | 21 ++++++----------
 .../storage/ShuffleBlockFetcherIterator.scala |  5 +---
 .../BlockStoreShuffleReaderSuite.scala        |  8 ++----
 6 files changed, 13 insertions(+), 59 deletions(-)
 delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java
deleted file mode 100644
index 1eea4d1d38817..0000000000000
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInputStreamId.java
+++ /dev/null
@@ -1,25 +0,0 @@
-package org.apache.spark.api.shuffle;
-
-public class ShuffleBlockInputStreamId {
-  private final int shuffleId;
-  private final int mapId;
-  private final int reduceId;
-
-  public ShuffleBlockInputStreamId(int shuffleId, int mapId, int reduceId) {
-    this.shuffleId = shuffleId;
-    this.mapId = mapId;
-    this.reduceId = reduceId;
-  }
-
-  public int getShuffleId() {
-    return shuffleId;
-  }
-
-  public int getMapId() {
-    return mapId;
-  }
-
-  public int getReduceId() {
-    return reduceId;
-  }
-}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index 7789e1d0c6a70..d0bec5c440b8c 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -18,7 +18,6 @@
 package org.apache.spark.api.shuffle;

 import org.apache.spark.annotation.Experimental;
-import scala.Tuple2;

 import java.io.IOException;
 import java.io.InputStream;
@@ -30,6 +29,6 @@
  */
 @Experimental
 public interface ShuffleReadSupport {
-  Iterable<Tuple2<ShuffleBlockInputStreamId, InputStream>> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
       throws IOException;
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index d310c2d6907f9..d3edc68f6967c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -61,18 +61,12 @@ private[spark] class BlockStoreShuffleReader[K, C](

     val serializerInstance = dep.serializer.newInstance()
-    val serializerManager = SparkEnv.get.serializerManager

     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { case (shuffleStreamId, wrappedStream) =>
-      val blockId = ShuffleBlockId(
-        shuffleStreamId.getShuffleId,
-        shuffleStreamId.getMapId,
-        shuffleStreamId.getReduceId)
-      val decompressedStream = serializerManager.wrapStream(blockId, wrappedStream)
+    val recordIter = wrappedStreams.flatMap { case wrappedStream =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
-      serializerInstance.deserializeStream(decompressedStream).asKeyValueIterator
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
     }

diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index 91a551f59670d..2ae667a7c8554 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -18,16 +18,15 @@
 package org.apache.spark.shuffle.io

 import java.io.InputStream
-import java.lang

 import scala.collection.JavaConverters._

 import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
-import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleBlockInputStreamId, ShuffleReadSupport}
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
 import org.apache.spark.internal.config
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleReadMetricsReporter
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}

 class DefaultShuffleReadSupport(
     blockManager: BlockManager,
@@ -42,8 +41,8 @@ class DefaultShuffleReadSupport(
   private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
   private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT)

-  override def getPartitionReaders(blockMetadata: lang.Iterable[ShuffleBlockInfo]):
-      lang.Iterable[(ShuffleBlockInputStreamId, InputStream)] = {
+  override def getPartitionReaders(
+      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {

     if (blockMetadata.asScala.isEmpty) {
       Iterable.empty.asJava
@@ -88,10 +87,10 @@ private class ShuffleBlockFetcherIterable(
     minReduceId: Int,
     maxReduceId: Int,
     shuffleId: Int,
-    mapOutputTracker: MapOutputTracker) extends Iterable[(ShuffleBlockInputStreamId, InputStream)] {
+    mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] {

-  override def iterator: Iterator[(ShuffleBlockInputStreamId, InputStream)] =
+  override def iterator: Iterator[InputStream] =
     new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
@@ -102,11 +101,5 @@ private class ShuffleBlockFetcherIterable(
       maxBlocksInFlightPerAddress,
       maxReqSizeShuffleToMem,
       detectCorrupt,
-      shuffleMetrics)
-      .toCompletionIterator
-      .map(stream => {
-        val blockId = stream._1.asInstanceOf[ShuffleBlockId]
-        (new ShuffleBlockInputStreamId(blockId.shuffleId, blockId.mapId, blockId.reduceId),
-          stream._2)
-      })
+      shuffleMetrics).toCompletionIterator.map(_._2)
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 1c3656e152529..3966980a11ed0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -406,7 +406,6 @@ final class ShuffleBlockFetcherIterator(

     var result: FetchResult = null
     var input: InputStream = null
-    var resultInputStream: InputStream = null
     // Take the next fetched result and try to decompress it to detect data corruption,
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
@@ -464,7 +463,6 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
-        resultInputStream = buf.createInputStream()
         var isStreamCopied: Boolean = false
         try {
           input = streamWrapper(blockId, in)
@@ -510,8 +508,7 @@ final class ShuffleBlockFetcherIterator(
       throw new NoSuchElementException()
     }
     currentResult = result.asInstanceOf[SuccessFetchResult]
-    input.close()
-    (currentResult.blockId, new BufferReleasingInputStream(resultInputStream, this))
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }

   def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index a64fe586ab97d..e1a91afae8927 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -26,7 +26,6 @@ import org.mockito.stubbing.{Answer, Stubber}

 import org.apache.spark._
 import org.apache.spark.internal.config
-import org.apache.spark.io.{CompressionCodec, CompressionCodec$}
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
@@ -81,11 +82,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
     // from each mappers (all mappers return the same shuffle data).
     val byteOutputStream = new ByteArrayOutputStream()
-    val compressionCodec = CompressionCodec.createCodec(testConf)
-    val compressionOutputStream = compressionCodec.compressedOutputStream(byteOutputStream)
-    val serializationStream = serializer.newInstance().serializeStream(compressionOutputStream)
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
     (0 until keyValuePairsPerMap).foreach { i =>
       serializationStream.writeKey(i)
       serializationStream.writeValue(2*i)
     }
-    compressionOutputStream.close()

     // Setup the mocked BlockManager to return RecordingManagedBuffers.
     val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
@@ -134,7 +130,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     val serializerManager = new SerializerManager(
       serializer,
       new SparkConf()
-        .set(config.SHUFFLE_COMPRESS, true)
+        .set(config.SHUFFLE_COMPRESS, false)
         .set(config.SHUFFLE_SPILL_COMPRESS, false))

     val taskContext = TaskContext.empty()

From e0a32894abbaf26101d94d6760f685a0d7b66184 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Tue, 2 Apr 2019 21:09:01 -0700
Subject: [PATCH 18/56] initial impl

---
 .../spark/api/shuffle/ShuffleReadSupport.java | 20 ++++++++-
 .../shuffle/BlockStoreShuffleReader.scala     | 42 +++++++++++++++----
 .../io/DefaultShuffleReadSupport.scala        | 36 ++++++++++++----
 .../storage/ShuffleBlockFetcherIterator.scala | 31 ++++++++++----
 4 files changed, 106 insertions(+), 23 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index d0bec5c440b8c..3cd525a1db73f 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -18,9 +18,11 @@
 package org.apache.spark.api.shuffle;

 import org.apache.spark.annotation.Experimental;
+import scala.Tuple2;

 import java.io.IOException;
 import java.io.InputStream;
+import java.util.Iterator;

 /**
  * :: Experimental ::
@@ -29,6 +31,22 @@
  */
 @Experimental
 public interface ShuffleReadSupport {
-  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+  ShuffleReaderIterable getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
       throws IOException;
+
+
+  interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
+    @Override
+    ShuffleReaderIterator iterator();
+  }
+
+  interface ShuffleReaderIterator extends Iterator<Tuple2<ShuffleBlockInfo, InputStream>> {
+    default void retryLastBlock() {
+      throw new UnsupportedOperationException();
+    }
+
+    default void throwCurrentBlockFailedException(Exception e) throws Exception {
+      throw e;
+    }
+  }
 }
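The iterator contract above gives the caller a hook to ask the plugin to refetch the block it just handed out when the stream turns out to be corrupt; implementations that cannot retry keep the throwing defaults. A minimal Scala sketch of a consumer loop against this interface — `readSupport`, `blocks`, and `consume` are hypothetical stand-ins, and imports are elided:

  // Retry a suspect block once where the implementation supports it; a second
  // failure is surfaced through throwCurrentBlockFailedException by the caller.
  val readers = readSupport.getPartitionReaders(blocks).iterator()
  while (readers.hasNext) {
    val (blockInfo, in) = readers.next()
    try {
      consume(blockInfo, in)
    } catch {
      case e: IOException => readers.retryLastBlock()
    }
  }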
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index d3edc68f6967c..be33beea7cc71 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -17,12 +17,15 @@

 package org.apache.spark.shuffle

+import java.io.IOException
+
 import scala.collection.JavaConverters._
+import scala.collection.mutable

 import org.apache.spark._
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
-import org.apache.spark.internal.Logging
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter

@@ -40,6 +43,11 @@ private[spark] class BlockStoreShuffleReader[K, C](
     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C] with Logging {

+  private val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)
+  private val maxBytesInFlight =
+    SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
+  private val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private val dep = handle.dependency

   override def read(): Iterator[Product2[K, C]] = {
@@ -57,16 +65,34 @@ private[spark] class BlockStoreShuffleReader[K, C](
             )
           })
       }
-    }.asJava).iterator().asScala
+    }.asJava).iterator()

     val serializerInstance = dep.serializer.newInstance()
+    val serializerManager = SparkEnv.get.serializerManager

     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { case wrappedStream =>
-      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
-      // NextIterator. The NextIterator makes sure that close() is called on the
-      // underlying InputStream when all records have been read.
-      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    val recordIter = wrappedStreams.asScala.flatMap { case (blockInfo, wrappedStream) =>
+      val blockId = ShuffleBlockId(
+        blockInfo.getShuffleId,
+        blockInfo.getMapId,
+        blockInfo.getReduceId)
+      try {
+        val decryptedDecompressedStream = serializerManager.wrapStream(blockId, wrappedStream)
+        // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+        // NextIterator. The NextIterator makes sure that close() is called on the
+        // underlying InputStream when all records have been read.
+        serializerInstance.deserializeStream(decryptedDecompressedStream).asKeyValueIterator
+      } catch {
+        case e: IOException =>
+          if (detectCorrupt && blockInfo.getLength < maxBytesInFlight &&
+              !corruptedBlocks.contains(blockId)) {
+            logWarning(s"got a corrupted block $blockId, fetch again", e)
+            corruptedBlocks += blockId
+            wrappedStreams.retryLastBlock()
+          }
+          wrappedStreams.throwCurrentBlockFailedException(e)
+          throw new RuntimeException("Expected shuffle reader iterator to throw exception")
+      }
     }

     // Update the context task metrics for each record read.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index 2ae667a7c8554..fd021c7719c08 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -23,6 +23,7 @@ import scala.collection.JavaConverters._

 import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
+import org.apache.spark.api.shuffle.ShuffleReadSupport.{ShuffleReaderIterable, ShuffleReaderIterator}
 import org.apache.spark.internal.config
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleReadMetricsReporter
@@ -39,13 +40,21 @@ class DefaultShuffleReadSupport(
   private val maxBlocksInFlightPerAddress =
     conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
   private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+  // todo remove:
   private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT)

   override def getPartitionReaders(
-      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {
+      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): ShuffleReaderIterable = {

     if (blockMetadata.asScala.isEmpty) {
-      Iterable.empty.asJava
+      val emptyIterator = new ShuffleReaderIterator {
+        override def hasNext: Boolean = Iterator.empty.hasNext
+
+        override def next(): (ShuffleBlockInfo, InputStream) = Iterator.empty.next()
+      }
+      return new ShuffleReaderIterable {
+        override def iterator(): ShuffleReaderIterator = emptyIterator
+      }
     } else {
       val minMaxReduceIds = blockMetadata.asScala.map(block => block.getReduceId)
        .foldLeft(Int.MaxValue, 0) {
@@ -69,7 +78,7 @@ class DefaultShuffleReadSupport(
         maxReduceId,
         shuffleId,
         mapOutputTracker
-      ).asJava
+      )
     }
   }
 }
@@ -87,10 +96,10 @@ private class ShuffleBlockFetcherIterable(
     minReduceId: Int,
     maxReduceId: Int,
     shuffleId: Int,
-    mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] {
+    mapOutputTracker: MapOutputTracker) extends ShuffleReaderIterable {

-  override def iterator: Iterator[InputStream] =
-    new ShuffleBlockFetcherIterator(
+  override def iterator: ShuffleReaderIterator = {
+    val innerIterator = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
@@ -101,5 +110,18 @@ private class ShuffleBlockFetcherIterable(
       maxBlocksInFlightPerAddress,
       maxReqSizeShuffleToMem,
       detectCorrupt,
-      shuffleMetrics).toCompletionIterator.map(_._2)
+      shuffleMetrics)
+    val completionIterator = innerIterator.toCompletionIterator
+    new ShuffleReaderIterator {
+      override def hasNext: Boolean = innerIterator.hasNext
+
+      override def next(): (ShuffleBlockInfo, InputStream) = completionIterator.next()
+
+      override def retryLastBlock(): Unit = innerIterator.retryLast()
+
+      override def throwCurrentBlockFailedException(e: Exception): Unit =
+        innerIterator.throwCurrentBlockFailedException(e)
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 3966980a11ed0..846c7bc35a3ed 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,14 +17,15 @@

 package org.apache.spark.storage

-import java.io.{InputStream, IOException}
+import java.io.{IOException, InputStream}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
-import javax.annotation.concurrent.GuardedBy
+import javax.annotation.concurrent.GuardedBy

 import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}

+import org.apache.spark.api.shuffle.ShuffleBlockInfo
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
@@ -74,7 +75,7 @@ final class ShuffleBlockFetcherIterator(
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean,
     shuffleMetrics: ShuffleReadMetricsReporter)
-  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
+  extends Iterator[(ShuffleBlockInfo, InputStream)] with DownloadFileManager with Logging {

   import ShuffleBlockFetcherIterator._

@@ -397,7 +398,7 @@ final class ShuffleBlockFetcherIterator(
    *
    * Throws a FetchFailedException if the next block could not be fetched.
    */
-  override def next(): (BlockId, InputStream) = {
+  override def next(): (ShuffleBlockInfo, InputStream) = {
     if (!hasNext) {
       throw new NoSuchElementException()
     }
@@ -508,11 +509,18 @@ final class ShuffleBlockFetcherIterator(
       throw new NoSuchElementException()
     }
     currentResult = result.asInstanceOf[SuccessFetchResult]
-    (currentResult.blockId, new BufferReleasingInputStream(input, this))
+    val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId]
+    (new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size),
+      new BufferReleasingInputStream(input, this))
+  }
+
+  def retryLast(): Unit = {
+    fetchRequests += FetchRequest(currentResult.address,
+      Array((currentResult.blockId, currentResult.size)))
   }

-  def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
-    CompletionIterator[(BlockId, InputStream), this.type](this,
+  def toCompletionIterator: Iterator[(ShuffleBlockInfo, InputStream)] = {
+    CompletionIterator[(ShuffleBlockInfo, InputStream), this.type](this,
       onCompleteCallback.onComplete(context))
   }

@@ -580,6 +588,15 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  def throwCurrentBlockFailedException(e: Throwable): Unit = {
+    val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId]
+    throw new FetchFailedException(currentResult.address,
+      blockId.shuffleId,
+      blockId.mapId,
+      blockId.reduceId,
+      e)
+  }
 }

 /**

From 1e89b3f876cce917816d7084edf0c45f45176863 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Tue, 2 Apr 2019 21:28:08 -0700
Subject: [PATCH 19/56] get shuffle reader tests to pass

---
 .../shuffle/BlockStoreShuffleReader.scala     |  3 +-
 .../shuffle/sort/SortShuffleManager.scala     |  2 +-
 .../storage/ShuffleBlockFetcherIterator.scala | 63 ++++++++++---------
 .../BlockStoreShuffleReaderSuite.scala        |  9 ++-
 .../ShuffleBlockFetcherIteratorSuite.scala    |  6 +-
 5 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index be33beea7cc71..4f9122cc24700 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
 import org.apache.spark._
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
@@ -39,6 +40,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     endPartition: Int,
     context: TaskContext,
     readMetrics: ShuffleReadMetricsReporter,
+    serializerManager: SerializerManager,
     shuffleReadSupport: ShuffleReadSupport,
     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C] with Logging {
@@ -68,7 +70,6 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }.asJava).iterator()

     val serializerInstance = dep.serializer.newInstance()
-    val serializerManager = SparkEnv.get.serializerManager

     // Create a key/value iterator for each stream
     val recordIter = wrappedStreams.asScala.flatMap { case (blockInfo, wrappedStream) =>
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index ee69d0d177e3b..c202afd736238 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -125,7 +125,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       conf = SparkEnv.get.conf)
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-      startPartition, endPartition, context, metrics, readSupport)
+      startPartition, endPartition, context, metrics, SparkEnv.get.serializerManager, readSupport)
   }

   /** Get a writer for a given partition. Called on executors by map tasks. */
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 846c7bc35a3ed..09cb40232b67d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -465,37 +465,38 @@ final class ShuffleBlockFetcherIterator(
             throwFetchFailedException(blockId, address, e)
         }
         var isStreamCopied: Boolean = false
-        try {
-          input = streamWrapper(blockId, in)
-          // Only copy the stream if it's wrapped by compression or encryption, also the size of
-          // block is small (the decompressed block is smaller than maxBytesInFlight)
-          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
-            isStreamCopied = true
-            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
-            // Decompress the whole block at once to detect any corruption, which could increase
-            // the memory usage tne potential increase the chance of OOM.
-            // TODO: manage the memory used here, and spill it into disk in case of OOM.
-            Utils.copyStream(input, out, closeStreams = true)
-            input = out.toChunkedByteBuffer.toInputStream(dispose = true)
-          }
-        } catch {
-          case e: IOException =>
-            buf.release()
-            if (buf.isInstanceOf[FileSegmentManagedBuffer]
-              || corruptedBlocks.contains(blockId)) {
-              throwFetchFailedException(blockId, address, e)
-            } else {
-              logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-              corruptedBlocks += blockId
-              fetchRequests += FetchRequest(address, Array((blockId, size)))
-              result = null
-            }
-        } finally {
-          // TODO: release the buf here to free memory earlier
-          if (isStreamCopied) {
-            in.close()
-          }
-        }
+        input = in
+//        try {
+//          input = streamWrapper(blockId, in)
+//          // Only copy the stream if it's wrapped by compression or encryption, also the size of
+//          // block is small (the decompressed block is smaller than maxBytesInFlight)
+//          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+//            isStreamCopied = true
+//            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+//            // Decompress the whole block at once to detect any corruption, which could increase
+//            // the memory usage tne potential increase the chance of OOM.
+//            // TODO: manage the memory used here, and spill it into disk in case of OOM.
+// Utils.copyStream(input, out, closeStreams = true) +// input = out.toChunkedByteBuffer.toInputStream(dispose = true) +// } +// } catch { +// case e: IOException => +// buf.release() +// if (buf.isInstanceOf[FileSegmentManagedBuffer] +// || corruptedBlocks.contains(blockId)) { +// throwFetchFailedException(blockId, address, e) +// } else { +// logWarning(s"got an corrupted block $blockId from $address, fetch again", e) +// corruptedBlocks += blockId +// fetchRequests += FetchRequest(address, Array((blockId, size))) +// result = null +// } +// } finally { +// // TODO: release the buf here to free memory earlier +// if (isStreamCopied) { +// in.close() +// } +// } case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index e1a91afae8927..e0fc39758e0a3 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -26,6 +26,7 @@ import org.mockito.stubbing.{Answer, Stubber} import org.apache.spark._ import org.apache.spark.internal.config +import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.io.DefaultShuffleReadSupport @@ -81,11 +82,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() - val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + val compressionCodec = CompressionCodec.createCodec(testConf) + val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream) + val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream) (0 until keyValuePairsPerMap).foreach { i => serializationStream.writeKey(i) serializationStream.writeValue(2*i) } + compressedOutputStream.close() // Setup the mocked BlockManager to return RecordingManagedBuffers. 
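// (A RecordingManagedBuffer counts the calls made to retain() and release(), so the
// assertions further down can verify that every buffer is released exactly once after
// its wrapped input stream has been consumed and closed.)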
val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
@@ -130,7 +134,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     val serializerManager = new SerializerManager(
       serializer,
       new SparkConf()
-        .set(config.SHUFFLE_COMPRESS, false)
+        .set(config.SHUFFLE_COMPRESS, true)
         .set(config.SHUFFLE_SPILL_COMPRESS, false))
 
     val taskContext = TaskContext.empty()
@@ -145,6 +149,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
       reduceId + 1,
       taskContext,
       metrics,
+      serializerManager,
       shuffleReadSupport,
       mapOutputTracker)
 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 98fe9663b6211..3e075265442e7 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -125,9 +125,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, inputStream) = iterator.next()
+      val (shuffleBlockInfo, inputStream) = iterator.next()
 
       // Make sure we release buffers when a wrapped input stream is closed.
+      val blockId = ShuffleBlockId(
+        shuffleBlockInfo.getShuffleId,
+        shuffleBlockInfo.getMapId,
+        shuffleBlockInfo.getReduceId)
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
       // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
       val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]

From 495c7bd1f1cd90397c20535832a449b38fe2dc3c Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 3 Apr 2019 11:12:35 -0700
Subject: [PATCH 20/56] update

---
 .../spark/api/shuffle/ShuffleReadSupport.java |  6 +-
 .../shuffle/BlockStoreShuffleReader.scala     | 20 ++----
 .../io/DefaultShuffleReadSupport.scala        |  5 +-
 .../storage/ShuffleBlockFetcherIterator.scala | 68 +++++--------------
 4 files changed, 24 insertions(+), 75 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index 3cd525a1db73f..0797e79a02f44 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -41,12 +41,8 @@ interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
-    default void retryLastBlock() {
+    default void retryLastBlock(Throwable t) {
       throw new UnsupportedOperationException();
     }
-
-    default void throwCurrentBlockFailedException(Exception e) throws Exception {
-      throw e;
-    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 4f9122cc24700..270fd1ca66cec 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -20,13 +20,12 @@ package org.apache.spark.shuffle
 import java.io.IOException
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 import org.apache.spark._
 import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
-import org.apache.spark.internal.{config, Logging}
+import
org.apache.spark.internal.Logging import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -45,11 +44,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { - private val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT) - private val maxBytesInFlight = - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 - private val corruptedBlocks = mutable.HashSet[BlockId]() - private val dep = handle.dependency override def read(): Iterator[Product2[K, C]] = { @@ -85,14 +79,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( serializerInstance.deserializeStream(decryptedDecompressedStream).asKeyValueIterator } catch { case e: IOException => - if (detectCorrupt && blockInfo.getLength < maxBytesInFlight && - !corruptedBlocks.contains(blockId)) { - logWarning(s"got an corrupted block $blockId, fetch again", e) - corruptedBlocks += blockId - wrappedStreams.retryLastBlock() - } - wrappedStreams.throwCurrentBlockFailedException(e) - throw new RuntimeException("Expected shuffle reader iterator to throw exception") + wrappedStreams.retryLastBlock(e) + None } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index fd021c7719c08..8b5097c681a4d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -117,10 +117,7 @@ private class ShuffleBlockFetcherIterable( override def next(): (ShuffleBlockInfo, InputStream) = completionIterator.next() - override def retryLastBlock(): Unit = innerIterator.retryLast() - - override def throwCurrentBlockFailedException(e: Exception): Unit = - innerIterator.throwCurrentBlockFailedException(e) + override def retryLastBlock(t: Throwable): Unit = innerIterator.retryLast(t) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 09cb40232b67d..ed6d1cc293090 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,21 @@ package org.apache.spark.storage -import java.io.{IOException, InputStream} -import java.nio.ByteBuffer +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue - import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import org.apache.spark.api.shuffle.ShuffleBlockInfo import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.api.shuffle.ShuffleBlockInfo import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} -import 
org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -465,37 +462,7 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
-        var isStreamCopied: Boolean = false
         input = in
-//        try {
-//          input = streamWrapper(blockId, in)
-//          // Only copy the stream if it's wrapped by compression or encryption, also the size of
-//          // block is small (the decompressed block is smaller than maxBytesInFlight)
-//          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
-//            isStreamCopied = true
-//            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
-//            // Decompress the whole block at once to detect any corruption, which could increase
-//            // the memory usage tne potential increase the chance of OOM.
-//            // TODO: manage the memory used here, and spill it into disk in case of OOM.
-//            Utils.copyStream(input, out, closeStreams = true)
-//            input = out.toChunkedByteBuffer.toInputStream(dispose = true)
-//          }
-//        } catch {
-//          case e: IOException =>
-//            buf.release()
-//            if (buf.isInstanceOf[FileSegmentManagedBuffer]
-//              || corruptedBlocks.contains(blockId)) {
-//              throwFetchFailedException(blockId, address, e)
-//            } else {
-//              logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-//              corruptedBlocks += blockId
-//              fetchRequests += FetchRequest(address, Array((blockId, size)))
-//              result = null
-//            }
-//        } finally {
-//          // TODO: release the buf here to free memory earlier
-//          if (isStreamCopied) {
-//            in.close()
-//          }
-//        }
-
       case FailureFetchResult(blockId, address, e) =>
         throwFetchFailedException(blockId, address, e)
     }
@@ -515,9 +480,21 @@ final class ShuffleBlockFetcherIterator(
       new BufferReleasingInputStream(input, this))
   }
 
-  def retryLast(): Unit = {
-    fetchRequests += FetchRequest(currentResult.address,
-      Array((currentResult.blockId, currentResult.size)))
+  def retryLast(t: Throwable): Unit = {
+    val blockId = currentResult.blockId
+    if (detectCorrupt && currentResult.size < maxBytesInFlight) {
+      if (corruptedBlocks.contains(blockId)) {
+        throwFetchFailedException(blockId, currentResult.address, t)
+      }
+      else {
+        logWarning(s"got an corrupted block $blockId from ${currentResult.address}, fetch again", t)
+        corruptedBlocks += blockId
+        fetchRequests += FetchRequest(currentResult.address,
+          Array((currentResult.blockId, currentResult.size)))
+      }
+    } else {
+      throwFetchFailedException(blockId, currentResult.address, t)
+    }
   }
 
   def toCompletionIterator: Iterator[(ShuffleBlockInfo, InputStream)] = {
@@ -589,15 +566,6 @@ final class ShuffleBlockFetcherIterator(
         "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
-
-  def throwCurrentBlockFailedException(e: Throwable): Unit = {
-    val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId]
-    throw new FetchFailedException(currentResult.address,
-      blockId.shuffleId,
-      blockId.mapId,
-      blockId.reduceId,
-      e)
-  }
 }
 
 /**

From 897c0bf252910b9debeba0902c535c1c487d3e87 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 3 Apr 2019 15:48:03 -0700
Subject: [PATCH 21/56] tests

---
 .../spark/api/shuffle/ShuffleBlockInfo.java   | 15 ++++++
 .../storage/ShuffleBlockFetcherIterator.scala |  3 ++
 .../ShuffleBlockFetcherIteratorSuite.scala    | 48 +++++++++++++++----
 3 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index 10635ea6a3b79..f6b2f28bd908f 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -17,6 +17,8 @@ package org.apache.spark.api.shuffle; +import java.util.Objects; + public class ShuffleBlockInfo { private final int shuffleId; private final int mapId; @@ -46,4 +48,17 @@ public long getLength() { return length; } + @Override + public boolean equals(Object other) { + return other instanceof ShuffleBlockInfo + && shuffleId == ((ShuffleBlockInfo) other).shuffleId + && mapId == ((ShuffleBlockInfo) other).mapId + && reduceId == ((ShuffleBlockInfo) other).reduceId + && length == ((ShuffleBlockInfo) other).length; + } + + @Override + public int hashCode() { + return Objects.hash(shuffleId, mapId, reduceId, length); + } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ed6d1cc293090..a3d77b544ff29 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -491,6 +491,9 @@ final class ShuffleBlockFetcherIterator( corruptedBlocks += blockId fetchRequests += FetchRequest(currentResult.address, Array((currentResult.blockId, currentResult.size))) + // Send fetch requests up to maxBytesInFlight + numBlocksToFetch += 1 + fetchUpToMaxBytes() } } else { throwFetchFailedException(blockId, currentResult.address, t) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 3e075265442e7..6aa985a80ffc9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -31,6 +31,7 @@ import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.api.shuffle.ShuffleBlockInfo import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} @@ -407,7 +408,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // The first block should be returned without an exception val (id1, _) = iterator.next() - assert(id1 === ShuffleBlockId(0, 0, 0)) + assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -422,15 +423,38 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - // The next block is corrupt local block (the second one is corrupt and retried) - intercept[FetchFailedException] { iterator.next() } + val readNextStreamAndRetryIfCant = () => { + try { + val stream = iterator.next() + val readByte = Array[Byte](1) + stream._2.read(readByte, 0, 1) + readByte + } catch { + case e: IOException => + iterator.retryLast(e) + } + None + } + + // This should fail to read the bytes and call for a retry + val readByte = readNextStreamAndRetryIfCant() + assert(readByte === None) sem.acquire() - intercept[FetchFailedException] { iterator.next() } + + // The next call 
should fail - local blocks failure + intercept[FetchFailedException] { + iterator.next() + } + // The next call is the retry of the second block + intercept[FetchFailedException] { + readNextStreamAndRetryIfCant() + } } test("big blocks are not checked for corruption") { - val corruptBuffer = mockCorruptBuffer(10000L) + val size = 10000L + val corruptBuffer = mockCorruptBuffer(size) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -468,7 +492,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. assert(Set(iterator.next()._1, iterator.next()._1) === - Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + Set(new ShuffleBlockInfo(0, 0, 0, size), new ShuffleBlockInfo(0, 1, 0, size))) } test("retry corrupt blocks (disabled)") { @@ -527,11 +551,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // The first block should be returned without an exception val (id1, _) = iterator.next() - assert(id1 === ShuffleBlockId(0, 0, 0)) + assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) val (id2, _) = iterator.next() - assert(id2 === ShuffleBlockId(0, 1, 0)) + assert(id2 === new ShuffleBlockInfo(0, 1, 0, 1)) val (id3, _) = iterator.next() - assert(id3 === ShuffleBlockId(0, 2, 0)) + assert(id3 === new ShuffleBlockInfo(0, 2, 0, 1)) } test("Blocks should be shuffled to disk when size of the request is above the" + @@ -635,4 +659,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val e = intercept[FetchFailedException] { iterator.next() } assert(e.getMessage.contains("Received a zero-size buffer")) } + + def toShuffleBlockId(shuffleBlockInfo: ShuffleBlockInfo): ShuffleBlockId = { + ShuffleBlockId(shuffleBlockInfo.getShuffleId, + shuffleBlockInfo.getMapId, + shuffleBlockInfo.getReduceId) + } } From 741deedacf06ab18955229f4ace54e59bf797152 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 3 Apr 2019 16:56:32 -0700 Subject: [PATCH 22/56] style --- .../apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala | 4 ++-- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index e0fc39758e0a3..f2c0193572e05 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Mockito.{doReturn, mock, when} +import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.{Answer, Stubber} +import org.mockito.stubbing.{Answer} import org.apache.spark._ import org.apache.spark.internal.config diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6aa985a80ffc9..ce8a63b332558 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -442,7 +442,7 @@ 
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() - // The next call should fail - local blocks failure + // The next call should fail - local block failure intercept[FetchFailedException] { iterator.next() } From c7c52b016354aa61b30076f7205ba1670fec6e26 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 3 Apr 2019 17:07:52 -0700 Subject: [PATCH 23/56] hook up executor components --- .../shuffle/ShuffleExecutorComponents.java | 2 ++ .../io/DefaultShuffleExecutorComponents.java | 26 +++++++++++++++++-- .../io/DefaultShuffleReadSupport.scala | 1 - .../shuffle/sort/SortShuffleManager.scala | 13 +++++----- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 4fc20bad9938b..8baa3bf6f859a 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -30,4 +30,6 @@ public interface ShuffleExecutorComponents { void initializeExecutor(String appId, String execId); ShuffleWriteSupport writes(); + + ShuffleReadSupport reads(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 76e87a6740259..a321e27ed160f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -17,11 +17,15 @@ package org.apache.spark.shuffle.sort.io; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleReadSupport; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; import org.apache.spark.storage.BlockManager; public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { @@ -29,6 +33,8 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponen private final SparkConf sparkConf; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; + private SerializerManager serializerManager; + private MapOutputTracker mapOutputTracker; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { this.sparkConf = sparkConf; @@ -37,15 +43,31 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { @Override public void initializeExecutor(String appId, String execId) { blockManager = SparkEnv.get().blockManager(); + serializerManager = SparkEnv.get().serializerManager(); + mapOutputTracker = SparkEnv.get().mapOutputTracker(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); } @Override public ShuffleWriteSupport writes() { + checkInitialized(); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + } + + @Override + public ShuffleReadSupport reads() { + checkInitialized(); + return new DefaultShuffleReadSupport( + blockManager, + serializerManager, + mapOutputTracker, + sparkConf); + } + + private void checkInitialized() { if (blockResolver == null) { 
throw new IllegalStateException( - "Executor components must be initialized before getting writers."); + "Executor components must be initialized before getting writers."); } - return new DefaultShuffleWriteSupport(sparkConf, blockResolver); } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 8b5097c681a4d..1f1115f804c14 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -40,7 +40,6 @@ class DefaultShuffleReadSupport( private val maxBlocksInFlightPerAddress = conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) - // todo remove: private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index e8a397d1e4507..0e2792c2c3158 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -123,15 +123,14 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - // TODO: remove this from here once ShuffleExecutorComponents is introduced - val readSupport = new DefaultShuffleReadSupport( - blockManager = SparkEnv.get.blockManager, - serializerManager = SparkEnv.get.serializerManager, - mapOutputTracker = SparkEnv.get.mapOutputTracker, - conf = SparkEnv.get.conf) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, metrics, SparkEnv.get.serializerManager, readSupport) + startPartition, + endPartition, + context, + metrics, + SparkEnv.get.serializerManager, + shuffleExecutorComponents.reads()) } /** Get a writer for a given partition. Called on executors by map tasks. 
*/ From 897c0bf252910b9debeba0902c535c1c487d3e87 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 3 Apr 2019 18:09:07 -0700 Subject: [PATCH 24/56] fix compile --- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- .../shuffle/sort/BlockStoreShuffleReaderBenchmark.scala | 6 +++++- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0e2792c2c3158..3d10b07f8fabf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,9 +21,8 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.util.Utils /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 2690f1a515fcc..c38ecb04ba8d2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -38,6 +38,7 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -206,6 +207,9 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) + val readSupport = new DefaultShuffleReadSupport( + blockManager, serializerManager, mapOutputTracker, defaultConf) + new BlockStoreShuffleReader[String, String]( shuffleHandle, 0, @@ -213,7 +217,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), serializerManager, - blockManager, + readSupport, mapOutputTracker ) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ce8a63b332558..293b78824abc3 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -395,7 +395,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), + (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -483,7 +483,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, 
blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, 10000), + (_, in) => in, 2048, Int.MaxValue, Int.MaxValue, From 34eaaf6d880ad72d0156be4802b890a95b822a8e Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 3 Apr 2019 18:13:49 -0700 Subject: [PATCH 25/56] remove unnecessary fields --- .../sort/io/DefaultShuffleExecutorComponents.java | 3 --- .../spark/shuffle/io/DefaultShuffleReadSupport.scala | 4 ---- .../spark/storage/ShuffleBlockFetcherIterator.scala | 2 -- .../spark/shuffle/BlockStoreShuffleReaderSuite.scala | 2 +- .../shuffle/sort/BlockStoreShuffleReaderBenchmark.scala | 4 +--- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 9 --------- 6 files changed, 2 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index a321e27ed160f..40f24eb05a7aa 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -33,7 +33,6 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponen private final SparkConf sparkConf; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; - private SerializerManager serializerManager; private MapOutputTracker mapOutputTracker; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { @@ -43,7 +42,6 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { @Override public void initializeExecutor(String appId, String execId) { blockManager = SparkEnv.get().blockManager(); - serializerManager = SparkEnv.get().serializerManager(); mapOutputTracker = SparkEnv.get().mapOutputTracker(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); } @@ -59,7 +57,6 @@ public ShuffleReadSupport reads() { checkInitialized(); return new DefaultShuffleReadSupport( blockManager, - serializerManager, mapOutputTracker, sparkConf); } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 1f1115f804c14..c93a6bb60cffd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -31,7 +31,6 @@ import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherItera class DefaultShuffleReadSupport( blockManager: BlockManager, - serializerManager: SerializerManager, mapOutputTracker: MapOutputTracker, conf: SparkConf) extends ShuffleReadSupport { @@ -66,7 +65,6 @@ class DefaultShuffleReadSupport( new ShuffleBlockFetcherIterable( TaskContext.get(), blockManager, - serializerManager.wrapStream, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, @@ -85,7 +83,6 @@ class DefaultShuffleReadSupport( private class ShuffleBlockFetcherIterable( context: TaskContext, blockManager: BlockManager, - streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, @@ -103,7 +100,6 @@ private class ShuffleBlockFetcherIterable( blockManager.shuffleClient, blockManager, mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1), - streamWrapper, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, diff --git 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a3d77b544ff29..96cf46a76e443 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -51,7 +51,6 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * order to throttle the memory usage. Note that zero-sized blocks are * already excluded, which happened in * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. - * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point @@ -66,7 +65,6 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], - streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index f2c0193572e05..9ee592408a39c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -142,7 +142,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker, testConf) + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index c38ecb04ba8d2..5e40bdcbe7631 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -34,7 +34,6 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network.BlockTransferService import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} @@ -207,8 +206,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new DefaultShuffleReadSupport( - blockManager, serializerManager, mapOutputTracker, defaultConf) + val readSupport = new DefaultShuffleReadSupport(blockManager, mapOutputTracker, defaultConf) new 
BlockStoreShuffleReader[String, String]( shuffleHandle, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 293b78824abc3..fc9a5403324a9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -113,7 +113,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -196,7 +195,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -264,7 +262,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -324,7 +321,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -395,7 +391,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -483,7 +478,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, 2048, Int.MaxValue, Int.MaxValue, @@ -538,7 +532,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, @@ -600,7 +593,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - (_, in) => in, maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, @@ -647,7 +639,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress.toIterator, - (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, From 0548800ddd6dd89b4ea76910740cfb5df7198f09 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 3 Apr 2019 18:19:12 -0700 Subject: [PATCH 26/56] remove unused --- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index fc9a5403324a9..f151e062a8a4a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -650,10 +650,4 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val e = intercept[FetchFailedException] { iterator.next() } assert(e.getMessage.contains("Received a zero-size buffer")) } - - def toShuffleBlockId(shuffleBlockInfo: ShuffleBlockInfo): ShuffleBlockId = { - ShuffleBlockId(shuffleBlockInfo.getShuffleId, - shuffleBlockInfo.getMapId, - shuffleBlockInfo.getReduceId) - } } From 0637e70f8b8aa39b29d374485ba43fda4b48d3a6 Mon Sep 17 00:00:00 
2001 From: Yifei Huang Date: Wed, 3 Apr 2019 18:34:12 -0700 Subject: [PATCH 27/56] refactor retrying iterator --- .../shuffle/BlockStoreShuffleReader.scala | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 270fd1ca66cec..2a0b9db44d712 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import java.io.IOException +import java.io.{InputStream, IOException} import scala.collection.JavaConverters._ @@ -63,27 +63,40 @@ private[spark] class BlockStoreShuffleReader[K, C]( } }.asJava).iterator() - val serializerInstance = dep.serializer.newInstance() + val retryingWrappedStreams = new Iterator[InputStream] { + override def hasNext: Boolean = wrappedStreams.hasNext - // Create a key/value iterator for each stream - val recordIter = wrappedStreams.asScala.flatMap { case (blockInfo, wrappedStream) => - val blockId = ShuffleBlockId( - blockInfo.getShuffleId, - blockInfo.getMapId, - blockInfo.getReduceId) - try { - val decryptedDecompressedStream = serializerManager.wrapStream(blockId, wrappedStream) - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. - serializerInstance.deserializeStream(decryptedDecompressedStream).asKeyValueIterator - } catch { - case e: IOException => - wrappedStreams.retryLastBlock(e) - None + override def next(): InputStream = { + var returnStream: InputStream = null + while (wrappedStreams.hasNext && returnStream == null) { + val nextStream = wrappedStreams.next() + val blockInfo = nextStream._1 + val blockId = ShuffleBlockId( + blockInfo.getShuffleId, + blockInfo.getMapId, + blockInfo.getReduceId) + try { + returnStream = serializerManager.wrapStream(blockId, nextStream._2) + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + } catch { + case e: IOException => + wrappedStreams.retryLastBlock(e) + } + } + returnStream } } + val serializerInstance = dep.serializer.newInstance() + val recordIter = retryingWrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + // Update the context task metrics for each record read. 
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
       recordIter.map { record =>

From f069dc1e818f1403f03127807f8e1528df478616 Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Wed, 3 Apr 2019 18:46:16 -0700
Subject: [PATCH 28/56] remove unused import

---
 .../spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
index 40f24eb05a7aa..f3dc7cc13288c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -23,7 +23,6 @@
 import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
 import org.apache.spark.api.shuffle.ShuffleReadSupport;
 import org.apache.spark.api.shuffle.ShuffleWriteSupport;
-import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.io.DefaultShuffleReadSupport;
 import org.apache.spark.storage.BlockManager;

From 0bba6779c2e163c4212ce145640ae0de608bd37a Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Thu, 4 Apr 2019 20:29:23 -0700
Subject: [PATCH 29/56] fix some comments

---
 .../spark/api/shuffle/ShuffleReadSupport.java |  2 +-
 .../storage/ShuffleBlockFetcherIterator.scala |  8 ++--
 .../ShuffleBlockFetcherIteratorSuite.scala    | 38 +++++++++----------
 3 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index 0797e79a02f44..d596fc636bb1b 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -42,7 +42,7 @@ interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
     default void retryLastBlock(Throwable t) {
-      throw new UnsupportedOperationException();
+      throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t);
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 96cf46a76e443..fe766caefb1b5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -450,7 +450,7 @@ final class ShuffleBlockFetcherIterator(
           throwFetchFailedException(blockId, address, new IOException(msg))
         }
 
-        val in = try {
+        input = try {
           buf.createInputStream()
         } catch {
           // The exception could only be thrown by local shuffle block
@@ -460,7 +460,6 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
-        input = in
       case FailureFetchResult(blockId, address, e) =>
         throwFetchFailedException(blockId, address, e)
     }
@@ -483,9 +482,8 @@ final class ShuffleBlockFetcherIterator(
     if (detectCorrupt && currentResult.size < maxBytesInFlight) {
       if (corruptedBlocks.contains(blockId)) {
         throwFetchFailedException(blockId, currentResult.address, t)
-      }
-      else {
-        logWarning(s"got an corrupted block $blockId from ${currentResult.address}, fetch again", t)
+      } else {
+        logWarning(s"got a corrupted block $blockId from ${currentResult.address}, fetch again", t)
         corruptedBlocks +=
blockId
         fetchRequests += FetchRequest(currentResult.address,
           Array((currentResult.blockId, currentResult.size)))
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index f151e062a8a4a..5717ee209e766 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -35,7 +35,6 @@ import org.apache.spark.api.shuffle.ShuffleBlockInfo
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
-import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
 
@@ -418,33 +417,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
-    val readNextStreamAndRetryIfCant = () => {
-      try {
-        val stream = iterator.next()
-        val readByte = Array[Byte](1)
-        stream._2.read(readByte, 0, 1)
-        readByte
-      } catch {
-        case e: IOException =>
-          iterator.retryLast(e)
-      }
-      None
-    }
-
-    // This should fail to read the bytes and call for a retry
-    val readByte = readNextStreamAndRetryIfCant()
+    // This should fail to read the bytes and call for a retry
+    val readByte = readNextStreamAndRetryOnError(iterator)
     assert(readByte === None)
 
     sem.acquire()
 
-    // The next call should fail - local block failure
+    // The next call should fail and not call for a retry because the stream wasn't
+    // corrupt, but the fetch itself failed
    intercept[FetchFailedException] {
-      iterator.next()
+      readNextStreamAndRetryOnError(iterator)
    }
-    // The next call is the retry of the second block
+    // The next call is the retry of the second block, which fails
    intercept[FetchFailedException] {
-      readNextStreamAndRetryIfCant()
+      readNextStreamAndRetryOnError(iterator)
    }
+  }
+
+  def readNextStreamAndRetryOnError(iterator: ShuffleBlockFetcherIterator): Option[Array[Byte]] = {
+    try {
+      val stream = iterator.next()
+      val readByte = Array[Byte](1)
+      stream._2.read(readByte, 0, 1)
+      Some(readByte)
+    } catch {
+      case e: IOException =>
+        iterator.retryLast(e)
    }
+    None
  }

From a82a72554775aea86ddee4574858c816b913956d Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Thu, 4 Apr 2019 20:46:28 -0700
Subject: [PATCH 30/56] null check

---
 .../org/apache/spark/shuffle/BlockStoreShuffleReader.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 2a0b9db44d712..01b82de6ed30b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -85,6 +85,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
               wrappedStreams.retryLastBlock(e)
           }
         }
+        if (returnStream == null) {
+          throw new IllegalStateException("Expected shuffle reader iterator to return a stream")
+        }
         returnStream
       }
     }

From ac392a1c3c1c4c03104a8d987ca6ad99ab336f7b Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Thu, 4 Apr 2019 20:54:28 -0700
Subject: [PATCH 31/56] refactor interface

---
 .../spark/api/shuffle/ShuffleReadSupport.java    | 12 ------------
 .../spark/api/shuffle/ShuffleReaderIterable.java | 16 ++++++++++++++++
 .../shuffle/io/DefaultShuffleReadSupport.scala   |  7 +++----
 3 files changed, 19 insertions(+), 16 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
index d596fc636bb1b..05137bfbcd0ff 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -33,16 +33,4 @@ public interface ShuffleReadSupport {
   ShuffleReaderIterable getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
     throws IOException;
-
-
-  interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
-    @Override
-    ShuffleReaderIterator iterator();
-  }
-
-  interface ShuffleReaderIterator extends Iterator<Tuple2<ShuffleBlockInfo, InputStream>> {
-    default void retryLastBlock(Throwable t) {
-      throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t);
-    }
-  }
 }
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java
new file mode 100644
index 0000000000000..9c638ee8cf27c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java
@@ -0,0 +1,16 @@
+package org.apache.spark.api.shuffle;
+
+import scala.Tuple2;
+
+import java.io.InputStream;
+import java.util.Iterator;
+
+public interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
+  interface ShuffleReaderIterator extends Iterator<Tuple2<ShuffleBlockInfo, InputStream>> {
+    default void retryLastBlock(Throwable t) {
+      throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t);
+    }
+  }
+  @Override
+  ShuffleReaderIterator iterator();
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index c93a6bb60cffd..e486a8b369ead 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -22,12 +22,11 @@ import java.io.InputStream
 import scala.collection.JavaConverters._
 
 import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
-import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
-import org.apache.spark.api.shuffle.ShuffleReadSupport.{ShuffleReaderIterable, ShuffleReaderIterator}
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderIterable, ShuffleReadSupport}
+import org.apache.spark.api.shuffle.ShuffleReaderIterable.ShuffleReaderIterator
 import org.apache.spark.internal.config
-import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleReadMetricsReporter
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
 
 class DefaultShuffleReadSupport(
     blockManager: BlockManager,

From 53dd94bdc1c059fa64cf0a89b01c149639013efd Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Thu, 4 Apr 2019 21:07:43 -0700
Subject: [PATCH 32/56] refactor API

---
 .../api/shuffle/ShuffleReaderInputStream.java | 22 +++++++++++++
 .../api/shuffle/ShuffleReaderIterable.java    |  6 +++--
 .../shuffle/BlockStoreShuffleReader.scala     |  4 ++--
 .../io/DefaultShuffleReadSupport.scala        |  6 ++---
 .../storage/ShuffleBlockFetcherIterator.scala | 17 +++++++-------
 .../ShuffleBlockFetcherIteratorSuite.scala    | 20
+++++++++-------
 6 files changed, 51 insertions(+), 24 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java
new file mode 100644
index 0000000000000..a824d64da91c2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java
@@ -0,0 +1,22 @@
+package org.apache.spark.api.shuffle;
+
+import java.io.InputStream;
+
+public class ShuffleReaderInputStream {
+
+  private final ShuffleBlockInfo shuffleBlockInfo;
+  private final InputStream inputStream;
+
+  public ShuffleReaderInputStream(ShuffleBlockInfo shuffleBlockInfo, InputStream inputStream) {
+    this.shuffleBlockInfo = shuffleBlockInfo;
+    this.inputStream = inputStream;
+  }
+
+  public ShuffleBlockInfo getShuffleBlockInfo() {
+    return shuffleBlockInfo;
+  }
+
+  public InputStream getInputStream() {
+    return inputStream;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java
index 9c638ee8cf27c..fa9f0af384145 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java
@@ -5,12 +5,14 @@
 import java.io.InputStream;
 import java.util.Iterator;
 
-public interface ShuffleReaderIterable extends Iterable<Tuple2<ShuffleBlockInfo, InputStream>> {
-  interface ShuffleReaderIterator extends Iterator<Tuple2<ShuffleBlockInfo, InputStream>> {
+public interface ShuffleReaderIterable extends Iterable<ShuffleReaderInputStream> {
+
+  interface ShuffleReaderIterator extends Iterator<ShuffleReaderInputStream> {
     default void retryLastBlock(Throwable t) {
       throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t);
     }
   }
+
   @Override
   ShuffleReaderIterator iterator();
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 01b82de6ed30b..86a9a705b3092 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -70,13 +70,13 @@ private[spark] class BlockStoreShuffleReader[K, C](
         var returnStream: InputStream = null
         while (wrappedStreams.hasNext && returnStream == null) {
           val nextStream = wrappedStreams.next()
-          val blockInfo = nextStream._1
+          val blockInfo = nextStream.getShuffleBlockInfo
           val blockId = ShuffleBlockId(
             blockInfo.getShuffleId,
             blockInfo.getMapId,
             blockInfo.getReduceId)
           try {
-            returnStream = serializerManager.wrapStream(blockId, nextStream._2)
+            returnStream = serializerManager.wrapStream(blockId, nextStream.getInputStream)
             // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
             // NextIterator. The NextIterator makes sure that close() is called on the
             // underlying InputStream when all records have been read.
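An aside on the API shape after this commit: callers now receive a ShuffleReaderInputStream
that couples each block's metadata with its byte stream, instead of a
(ShuffleBlockInfo, InputStream) tuple. A minimal sketch of a consumer of the refactored
interfaces follows; the object name ReadSupportConsumerSketch, the method drainAll, and the
argument names are illustrative only and are not part of the patch:

    import java.io.{IOException, InputStream}
    import scala.collection.JavaConverters._
    import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}

    object ReadSupportConsumerSketch {
      // Drain every fetched block, handing a stream that fails mid-read back to
      // the iterator so the implementation can re-enqueue that block once.
      def drainAll(readSupport: ShuffleReadSupport, blocks: Seq[ShuffleBlockInfo]): Unit = {
        val iter = readSupport.getPartitionReaders(blocks.asJava).iterator()
        while (iter.hasNext) {
          val block = iter.next()
          val info = block.getShuffleBlockInfo      // shuffle/map/reduce ids plus length
          val in: InputStream = block.getInputStream
          try {
            var b = in.read()
            while (b != -1) { b = in.read() }       // a real reader deserializes records here
          } catch {
            case e: IOException => iter.retryLastBlock(e)
          } finally {
            in.close()
          }
        }
      }
    }

BlockStoreShuffleReader above is the real consumer of this contract: it wraps each stream
through the SerializerManager and hands an IOException from that wrapping back via
retryLastBlock.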
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index e486a8b369ead..8ddd15569850e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -22,7 +22,7 @@ import java.io.InputStream import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderIterable, ShuffleReadSupport} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream, ShuffleReaderIterable, ShuffleReadSupport} import org.apache.spark.api.shuffle.ShuffleReaderIterable.ShuffleReaderIterator import org.apache.spark.internal.config import org.apache.spark.shuffle.ShuffleReadMetricsReporter @@ -47,7 +47,7 @@ class DefaultShuffleReadSupport( val emptyIterator = new ShuffleReaderIterator { override def hasNext: Boolean = Iterator.empty.hasNext - override def next(): (ShuffleBlockInfo, InputStream) = Iterator.empty.next() + override def next(): ShuffleReaderInputStream = Iterator.empty.next() } return new ShuffleReaderIterable { override def iterator(): ShuffleReaderIterator = emptyIterator @@ -109,7 +109,7 @@ private class ShuffleBlockFetcherIterable( new ShuffleReaderIterator { override def hasNext: Boolean = innerIterator.hasNext - override def next(): (ShuffleBlockInfo, InputStream) = completionIterator.next() + override def next(): ShuffleReaderInputStream = completionIterator.next() override def retryLastBlock(t: Throwable): Unit = innerIterator.retryLast(t) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index fe766caefb1b5..098fce6a8ac8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} import java.util.concurrent.LinkedBlockingQueue -import javax.annotation.concurrent.GuardedBy +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.api.shuffle.ShuffleBlockInfo +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ @@ -71,7 +71,7 @@ final class ShuffleBlockFetcherIterator( maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[(ShuffleBlockInfo, InputStream)] with DownloadFileManager with Logging { + extends Iterator[ShuffleReaderInputStream] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -394,7 +394,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could not be fetched. 
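   * If consuming the returned stream fails with an IOException, the caller can hand the
   * error back through retryLast, which re-enqueues the block once before escalating to a
   * FetchFailedException.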
*/ - override def next(): (ShuffleBlockInfo, InputStream) = { + override def next(): ShuffleReaderInputStream = { if (!hasNext) { throw new NoSuchElementException() } @@ -473,7 +473,8 @@ final class ShuffleBlockFetcherIterator( } currentResult = result.asInstanceOf[SuccessFetchResult] val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] - (new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size), + new ShuffleReaderInputStream( + new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size), new BufferReleasingInputStream(input, this)) } @@ -496,8 +497,8 @@ final class ShuffleBlockFetcherIterator( } } - def toCompletionIterator: Iterator[(ShuffleBlockInfo, InputStream)] = { - CompletionIterator[(ShuffleBlockInfo, InputStream), this.type](this, + def toCompletionIterator: Iterator[ShuffleReaderInputStream] = { + CompletionIterator[ShuffleReaderInputStream, this.type](this, onCompleteCallback.onComplete(context)) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5717ee209e766..3299b497934ed 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -124,7 +124,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (shuffleBlockInfo, inputStream) = iterator.next() + val shuffleInputStream = iterator.next() + val shuffleBlockInfo = shuffleInputStream.getShuffleBlockInfo + val inputStream = shuffleInputStream.getInputStream // Make sure we release buffers when a wrapped input stream is closed. val blockId = ShuffleBlockId( @@ -202,11 +204,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.close() // close() first block's input stream + iterator.next().getInputStream.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2 + val subIter = iterator.next().getInputStream // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -401,7 +403,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + val id1 = iterator.next().getShuffleBlockInfo assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) @@ -438,7 +440,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT try { val stream = iterator.next() val readByte = Array[Byte](1) - stream._2.read(readByte, 0, 1) + stream.getInputStream.read(readByte, 0, 1) Some(readByte) } catch { case e: IOException => @@ -485,7 +487,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. 
- assert(Set(iterator.next()._1, iterator.next()._1) === + assert(Set(iterator.next().getShuffleBlockInfo, iterator.next().getShuffleBlockInfo) === Set(new ShuffleBlockInfo(0, 0, 0, size), new ShuffleBlockInfo(0, 1, 0, size))) } @@ -543,11 +545,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + val id1 = iterator.next().getShuffleBlockInfo assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) - val (id2, _) = iterator.next() + val id2 = iterator.next().getShuffleBlockInfo assert(id2 === new ShuffleBlockInfo(0, 1, 0, 1)) - val (id3, _) = iterator.next() + val id3 = iterator.next().getShuffleBlockInfo assert(id3 === new ShuffleBlockInfo(0, 2, 0, 1)) } From 4c0c79107e41f450f58795972e9b596783a23c34 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 4 Apr 2019 21:11:10 -0700 Subject: [PATCH 33/56] shuffle iterator style --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 098fce6a8ac8f..9e7622bcc4b06 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,10 +17,10 @@ package org.apache.spark.storage -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue - import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} From 84f79317e7c411de5c19d6d4bc142e0206237d84 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 4 Apr 2019 21:18:01 -0700 Subject: [PATCH 34/56] add some javadocs for interfaces --- .../spark/api/shuffle/ShuffleReadSupport.java | 6 +++++- .../spark/api/shuffle/ShuffleReaderInputStream.java | 8 ++++++++ .../spark/api/shuffle/ShuffleReaderIterable.java | 13 +++++++++++-- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 05137bfbcd0ff..165038fa48fb8 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -26,11 +26,15 @@ /** * :: Experimental :: - * An interface for reading shuffle records + * An interface for reading shuffle records. * @since 3.0.0 */ @Experimental public interface ShuffleReadSupport { + /** + * Returns an underlying {@link ShuffleReaderIterable} that will iterate through shuffle data, + * given an iterable for the shuffle blocks to fetch. 
+ */ ShuffleReaderIterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java index a824d64da91c2..b251595621ffc 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java @@ -1,7 +1,15 @@ package org.apache.spark.api.shuffle; +import org.apache.spark.annotation.Experimental; + import java.io.InputStream; +/** + * :: Experimental :: + * An interface for reading shuffle records. + * @since 3.0.0 + */ +@Experimental public class ShuffleReaderInputStream { private final ShuffleBlockInfo shuffleBlockInfo; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java index fa9f0af384145..a67fcee8174db 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java @@ -1,13 +1,22 @@ package org.apache.spark.api.shuffle; -import scala.Tuple2; +import org.apache.spark.annotation.Experimental; -import java.io.InputStream; import java.util.Iterator; +/** + * :: Experimental :: + * An interface for iterating through shuffle blocks to read. + * @since 3.0.0 + */ +@Experimental public interface ShuffleReaderIterable extends Iterable { interface ShuffleReaderIterator extends Iterator { + /** + * Instructs the shuffle iterator to fetch the last block again. This is useful + * if the block is determined to be corrupt after decryption or decompression. + */ default void retryLastBlock(Throwable t) { throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t); } From b59efb58b289125df4dd0e270d9a94056e816a18 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 4 Apr 2019 21:40:42 -0700 Subject: [PATCH 35/56] attach apache headers --- .../api/shuffle/ShuffleReaderInputStream.java | 17 +++++++++++++++++ .../api/shuffle/ShuffleReaderIterable.java | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java index b251595621ffc..80c97221e3ab1 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.api.shuffle; import org.apache.spark.annotation.Experimental; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java index a67fcee8174db..632359b7c9de5 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.api.shuffle; import org.apache.spark.annotation.Experimental; From aba8a94ecdee8d9798d8116e014b74b7b70a08a4 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 5 Apr 2019 11:06:20 -0700 Subject: [PATCH 36/56] remove unused imports --- .../java/org/apache/spark/api/shuffle/ShuffleReadSupport.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 165038fa48fb8..3e3d04a268320 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -18,11 +18,8 @@ package org.apache.spark.api.shuffle; import org.apache.spark.annotation.Experimental; -import scala.Tuple2; import java.io.IOException; -import java.io.InputStream; -import java.util.Iterator; /** * :: Experimental :: From 5ef59b696ba4bfa0167eefa8f152db99e933674d Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 5 Apr 2019 11:31:43 -0700 Subject: [PATCH 37/56] remove another import --- .../org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 8ddd15569850e..c81c6db3bce64 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -17,8 +17,6 @@ package org.apache.spark.shuffle.io -import java.io.InputStream - import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} From 49a19018ba0a7044ffab1645470a9038a37e0a4e Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 5 Apr 2019 16:12:25 -0700 Subject: [PATCH 38/56] fix reader --- .../spark/shuffle/BlockStoreShuffleReader.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 86a9a705b3092..de35caedd2ad9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,7 +17,8 @@ package org.apache.spark.shuffle -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} +import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -26,8 +27,9 @@ import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.Logging import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, Utils} import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by @@ -76,10 +78,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockInfo.getMapId, blockInfo.getReduceId) try { - returnStream = serializerManager.wrapStream(blockId, nextStream.getInputStream) - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. + val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + // Decompress the whole block at once to detect any corruption, which could increase + // the memory usage tne potential increase the chance of OOM. + // TODO: manage the memory used here, and spill it into disk in case of OOM. 
+ Utils.copyStream(in, out, closeStreams = true) + returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) } catch { case e: IOException => wrappedStreams.retryLastBlock(e) From 8c6c09c2b1b16db4b0eee17ec93afa317993c626 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 5 Apr 2019 16:14:53 -0700 Subject: [PATCH 39/56] fix imports --- .../org/apache/spark/shuffle/BlockStoreShuffleReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index de35caedd2ad9..f46b3c495b52f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import java.nio.ByteBuffer import scala.collection.JavaConverters._ From 6370b4198b6988a8aedd4de5d02daed013dde6c6 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 10 Apr 2019 12:26:28 -0700 Subject: [PATCH 40/56] add exception comment for retry API --- .../org/apache/spark/api/shuffle/ShuffleReaderIterable.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java index 632359b7c9de5..2307d22342ed5 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java @@ -33,8 +33,10 @@ interface ShuffleReaderIterator extends Iterator { /** * Instructs the shuffle iterator to fetch the last block again. This is useful * if the block is determined to be corrupt after decryption or decompression. + * + * @throws Exception if current block cannot be retried. */ - default void retryLastBlock(Throwable t) { + default void retryLastBlock(Throwable t) throws Exception { throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t); } } From c442b6384d7d031aeb496d5c8a88a6879a8cd60f Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 10 Apr 2019 14:48:40 -0700 Subject: [PATCH 41/56] address some comments --- .../spark/api/shuffle/ShuffleBlockInfo.java | 5 ++++ .../api/shuffle/ShuffleReaderInputStream.java | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 28 +++++++++---------- .../io/DefaultShuffleReadSupport.scala | 6 ++-- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index f6b2f28bd908f..f0d457c8d7cc1 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -19,6 +19,11 @@ import java.util.Objects; +/** + * :: Experimental :: + * An object defining the shuffle block and length metadata associated with the block. 
+ * @since 3.0.0 + */ public class ShuffleBlockInfo { private final int shuffleId; private final int mapId; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java index 80c97221e3ab1..0d203ea1ebdcc 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java @@ -23,7 +23,7 @@ /** * :: Experimental :: - * An interface for reading shuffle records. + * An object containing the shuffle block's input stream and information about that block. * @since 3.0.0 */ @Experimental diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index f46b3c495b52f..c6312ef76a331 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -48,30 +48,28 @@ private[spark] class BlockStoreShuffleReader[K, C]( private val dep = handle.dependency + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = + val streamsIterator = shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { override def iterator: Iterator[ShuffleBlockInfo] = { - /** Read the combined key-values for this reduce task */ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) - .flatMap(blockManagerIdInfo => { - blockManagerIdInfo._2.map( - blockInfo => { - val block = blockInfo._1.asInstanceOf[ShuffleBlockId] - new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2) - } - ) - }) + .flatMap { blockManagerIdInfo => + blockManagerIdInfo._2.map { blockInfo => + val block = blockInfo._1.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2) + } + } } }.asJava).iterator() val retryingWrappedStreams = new Iterator[InputStream] { - override def hasNext: Boolean = wrappedStreams.hasNext + override def hasNext: Boolean = streamsIterator.hasNext override def next(): InputStream = { var returnStream: InputStream = null - while (wrappedStreams.hasNext && returnStream == null) { - val nextStream = wrappedStreams.next() + while (streamsIterator.hasNext && returnStream == null) { + val nextStream = streamsIterator.next() val blockInfo = nextStream.getShuffleBlockInfo val blockId = ShuffleBlockId( blockInfo.getShuffleId, @@ -81,13 +79,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the whole block at once to detect any corruption, which could increase - // the memory usage tne potential increase the chance of OOM. + // the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
Utils.copyStream(in, out, closeStreams = true) returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) } catch { case e: IOException => - wrappedStreams.retryLastBlock(e) + streamsIterator.retryLastBlock(e) } } if (returnStream == null) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index c81c6db3bce64..12d404d259a82 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -51,12 +51,10 @@ class DefaultShuffleReadSupport( override def iterator(): ShuffleReaderIterator = emptyIterator } } else { - val minMaxReduceIds = blockMetadata.asScala.map(block => block.getReduceId) + val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) .foldLeft(Int.MaxValue, 0) { case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) } - val minReduceId = minMaxReduceIds._1 - val maxReduceId = minMaxReduceIds._2 val shuffleId = blockMetadata.asScala.head.getShuffleId new ShuffleBlockFetcherIterable( @@ -105,7 +103,7 @@ private class ShuffleBlockFetcherIterable( shuffleMetrics) val completionIterator = innerIterator.toCompletionIterator new ShuffleReaderIterator { - override def hasNext: Boolean = innerIterator.hasNext + override def hasNext: Boolean = completionIterator.hasNext override def next(): ShuffleReaderInputStream = completionIterator.next() From 2c1272a8dc2a54bd9a71ea38220a8799665e5762 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 10 Apr 2019 15:25:52 -0700 Subject: [PATCH 42/56] address comments --- .../shuffle/BlockStoreShuffleReader.scala | 35 ++++++++++++------- .../io/DefaultShuffleReadSupport.scala | 4 --- .../storage/ShuffleBlockFetcherIterator.scala | 24 +++++-------- .../ShuffleBlockFetcherIteratorSuite.scala | 9 ----- 4 files changed, 31 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c6312ef76a331..f59b0a2ae68cd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.{CompletionIterator, Utils} @@ -43,11 +43,16 @@ private[spark] class BlockStoreShuffleReader[K, C]( readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager, shuffleReadSupport: ShuffleReadSupport, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + sparkConf: SparkConf = SparkEnv.get.conf) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency + private val detectCorrupt = sparkConf.get(config.SHUFFLE_DETECT_CORRUPT) + + private val maxBytesInFlight = sparkConf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val streamsIterator = 
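Note: the next hunk makes the eager corruption check conditional on spark.shuffle.detectCorrupt and on block size. Blocks smaller than maxBytesInFlight are fully decompressed into memory inside the try/catch, so a corrupt block fails immediately and can be retried; blocks at or above that threshold are wrapped lazily. A standalone sketch of the eager path, using plain JDK gzip streams in place of Spark's SerializerManager and ChunkedByteBufferOutputStream (illustrative only):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
    import java.util.zip.GZIPInputStream

    // Force full decompression up front so corruption surfaces here, inside a
    // catchable scope, rather than later during record iteration.
    def eagerlyDecompress(compressed: Array[Byte]): InputStream = {
      val in = new GZIPInputStream(new ByteArrayInputStream(compressed))
      val out = new ByteArrayOutputStream()
      val buf = new Array[Byte](64 * 1024)
      var n = in.read(buf)
      while (n != -1) { // a truncated or corrupt block throws IOException here
        out.write(buf, 0, n)
        n = in.read(buf)
      }
      in.close()
      new ByteArrayInputStream(out.toByteArray)
    }
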
@@ -75,17 +80,21 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockInfo.getShuffleId, blockInfo.getMapId, blockInfo.getReduceId) - try { - val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - // Decompress the whole block at once to detect any corruption, which could increase - // the memory usage and potentially increase the chance of OOM. - // TODO: manage the memory used here, and spill it into disk in case of OOM. - Utils.copyStream(in, out, closeStreams = true) - returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) - } catch { - case e: IOException => - streamsIterator.retryLastBlock(e) + if (detectCorrupt && blockInfo.getLength < maxBytesInFlight) { + try { + val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + // Decompress the whole block at once to detect any corruption, which could increase + // the memory usage and potentially increase the chance of OOM. + // TODO: manage the memory used here, and spill it into disk in case of OOM. + Utils.copyStream(in, out, closeStreams = true) + returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) + } catch { + case e: IOException => + streamsIterator.retryLastBlock(e) + } + } else { + returnStream = serializerManager.wrapStream(blockId, nextStream.getInputStream) } } if (returnStream == null) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 12d404d259a82..8b167576af0b2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -36,7 +36,6 @@ class DefaultShuffleReadSupport( private val maxBlocksInFlightPerAddress = conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) - private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders( blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): ShuffleReaderIterable = { @@ -64,7 +63,6 @@ class DefaultShuffleReadSupport( maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, - detectCorrupt, shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), minReduceId, maxReduceId, @@ -82,7 +80,6 @@ private class ShuffleBlockFetcherIterable( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, minReduceId: Int, maxReduceId: Int, @@ -99,7 +96,6 @@ private class ShuffleBlockFetcherIterable( maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, - detectCorrupt, shuffleMetrics) val completionIterator = innerIterator.toCompletionIterator new ShuffleReaderIterator { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 9e7622bcc4b06..f2358e380c4bf 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -56,7 +56,6 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} 
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. - * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param shuffleMetrics used to report shuffle metrics. */ private[spark] @@ -69,7 +68,6 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[ShuffleReaderInputStream] with DownloadFileManager with Logging { @@ -480,20 +478,16 @@ final class ShuffleBlockFetcherIterator( def retryLast(t: Throwable): Unit = { val blockId = currentResult.blockId - if (detectCorrupt && currentResult.size < maxBytesInFlight) { - if (corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, currentResult.address, t) - } else { - logWarning(s"got a corrupted block $blockId from $currentResult.address, fetch again", t) - corruptedBlocks += blockId - fetchRequests += FetchRequest(currentResult.address, - Array((currentResult.blockId, currentResult.size))) - // Send fetch requests up to maxBytesInFlight - numBlocksToFetch += 1 - fetchUpToMaxBytes() - } - } else { + if (corruptedBlocks.contains(blockId)) { throwFetchFailedException(blockId, currentResult.address, t) + } else { + logWarning(s"got a corrupted block $blockId from $currentResult.address, fetch again", t) + corruptedBlocks += blockId + fetchRequests += FetchRequest(currentResult.address, + Array((currentResult.blockId, currentResult.size))) + // Send fetch requests up to maxBytesInFlight + numBlocksToFetch += 1 + fetchUpToMaxBytes() } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 3299b497934ed..bb80386ceacf3 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -116,7 +116,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, metrics) // 3 local blocks fetched in initialization @@ -200,7 +199,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -267,7 +265,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, taskContext.taskMetrics.createTempShuffleReadMetrics()) @@ -326,7 +323,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -396,7 +392,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -484,7 +479,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, 
taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. assert(Set(iterator.next().getShuffleBlockInfo, iterator.next().getShuffleBlockInfo) === @@ -538,7 +532,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - false, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -599,7 +592,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, - detectCorrupt = true, taskContext.taskMetrics.createTempShuffleReadMetrics()) } @@ -645,7 +637,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: From bd349ca7dabd1a4971c5ec43001dd239922d5e57 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 15:03:30 -0700 Subject: [PATCH 43/56] resolve conflicts --- .../spark/api/shuffle/ShuffleBlockInfo.java | 13 +++++- .../spark/api/shuffle/ShuffleLocation.java | 1 + .../io/DefaultShuffleExecutorComponents.java | 6 +-- .../org/apache/spark/MapOutputTracker.scala | 9 +++- .../shuffle/BlockStoreShuffleReader.scala | 43 +++++-------------- .../io/DefaultShuffleReadSupport.scala | 8 +++- .../storage/ShuffleBlockFetcherIterator.scala | 8 ++-- .../BlockStoreShuffleReaderSuite.scala | 32 ++++---------- .../ShuffleBlockFetcherIteratorSuite.scala | 14 +++--- 9 files changed, 60 insertions(+), 74 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index f0d457c8d7cc1..7cb2e62a07c5a 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -18,6 +18,7 @@ package org.apache.spark.api.shuffle; import java.util.Objects; +import java.util.Optional; /** * :: Experimental :: @@ -29,12 +30,18 @@ public class ShuffleBlockInfo { private final int mapId; private final int reduceId; private final long length; + private final Optional shuffleLocation; - public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) { + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, ShuffleLocation shuffleLocation) { this.shuffleId = shuffleId; this.mapId = mapId; this.reduceId = reduceId; this.length = length; + if (shuffleLocation == ShuffleLocation.EMPTY_LOCATION) { + this.shuffleLocation = Optional.empty(); + } else { + this.shuffleLocation = Optional.of(shuffleLocation); + } } public int getShuffleId() { @@ -53,6 +60,10 @@ public long getLength() { return length; } + public Optional getShuffleLocation() { + return shuffleLocation; + } + @Override public boolean equals(Object other) { return other instanceof ShuffleBlockInfo diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java index 87eb497098e0c..41d474dd6eef8 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -22,4 +22,5 @@ * and writers are expected to cast this down to an 
implementation-specific representation. */ public interface ShuffleLocation { + ShuffleLocation EMPTY_LOCATION = new ShuffleLocation() {}; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 9c4c9f81a81a4..4451280bf4b07 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -48,7 +48,7 @@ public void initializeExecutor(String appId, String execId) { @Override public ShuffleWriteSupport writes() { checkInitialized(); - return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); } @Override @@ -65,9 +65,5 @@ private void checkInitialized() { throw new IllegalStateException( "Executor components must be initialized before getting writers."); } -<<<<<<< HEAD -======= - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); ->>>>>>> spark-25299 } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 74975019e7480..7ac28656aeee5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -885,9 +885,14 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) - splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += + if (status.mapShuffleLocations == null) { + splitsByAddress.getOrElseUpdate(ShuffleLocation.EMPTY_LOCATION, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) + } else { + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 3a868c6e7bf8b..0d405525ef093 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -26,14 +26,8 @@ import org.apache.spark._ import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager -<<<<<<< HEAD import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.{CompletionIterator, Utils} -======= -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} -import org.apache.spark.util.CompletionIterator ->>>>>>> spark-25299 import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -61,41 +55,24 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { -<<<<<<< HEAD val streamsIterator = shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { 
override def iterator: Iterator[ShuffleBlockInfo] = { - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) - .flatMap { blockManagerIdInfo => - blockManagerIdInfo._2.map { blockInfo => + mapOutputTracker + .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .flatMap { shuffleLocationInfo => + shuffleLocationInfo._2.map { blockInfo => val block = blockInfo._1.asInstanceOf[ShuffleBlockId] - new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2) + new ShuffleBlockInfo( + block.shuffleId, + block.mapId, + block.reduceId, + blockInfo._2, + shuffleLocationInfo._1) } } } }.asJava).iterator() -======= - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) - .map { - case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) => - (loc.getBlockManagerId, blocks) - case _ => - throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" + - " locations yet.") - }, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - readMetrics).toCompletionIterator ->>>>>>> spark-25299 val retryingWrappedStreams = new Iterator[InputStream] { override def hasNext: Boolean = streamsIterator.hasNext diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 8b167576af0b2..a0c6bbfc727cb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream, import org.apache.spark.api.shuffle.ShuffleReaderIterable.ShuffleReaderIterator import org.apache.spark.internal.config import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} class DefaultShuffleReadSupport( @@ -91,7 +92,12 @@ private class ShuffleBlockFetcherIterable( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1), + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) + .map { shuffleLocationInfo => + val defaultShuffleLocation = shuffleLocationInfo._1 + .asInstanceOf[DefaultMapShuffleLocations] + (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) + }, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f2358e380c4bf..6d4a55a217b74 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ 
-17,10 +17,10 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} import java.util.concurrent.LinkedBlockingQueue -import javax.annotation.concurrent.GuardedBy +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -472,7 +473,8 @@ final class ShuffleBlockFetcherIterator( currentResult = result.asInstanceOf[SuccessFetchResult] val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] new ShuffleReaderInputStream( - new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size), + new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size, + DefaultMapShuffleLocations.get(currentResult.address)), new BufferReleasingInputStream(input, this)) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 90cca963747e2..5493eb6ced48e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -22,20 +22,18 @@ import java.nio.ByteBuffer import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.{Answer} +import org.mockito.stubbing.Answer import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -<<<<<<< HEAD import org.apache.spark.shuffle.io.DefaultShuffleReadSupport -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockId} -======= import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} ->>>>>>> spark-25299 +import org.apache.spark.storage.BlockId /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -113,31 +111,17 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. -<<<<<<< HEAD val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) -======= - val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByShuffleLocation( - shuffleId, reduceId, reduceId + 1)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. 
- val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq( - (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) - .toIterator ->>>>>>> spark-25299 } - val blocksToRetrieve = Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + val blocksToRetrieve = Seq( + (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)) - .thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] { + when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) + .thenAnswer(new Answer[Iterator[(ShuffleLocation, Seq[(BlockId, Long)])]] { def answer(invocationOnMock: InvocationOnMock): - Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { blocksToRetrieve.iterator } }) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index bb80386ceacf3..d244082b4ff6a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.Utils @@ -393,13 +394,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, taskContext.taskMetrics.createTempShuffleReadMetrics()) + val shuffleLocation = DefaultMapShuffleLocations.get(remoteBmId) // Continue only after the mock calls onBlockFetchFailure sem.acquire() // The first block should be returned without an exception val id1 = iterator.next().getShuffleBlockInfo - assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) + assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, shuffleLocation)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -482,7 +484,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. 
assert(Set(iterator.next().getShuffleBlockInfo, iterator.next().getShuffleBlockInfo) === - Set(new ShuffleBlockInfo(0, 0, 0, size), new ShuffleBlockInfo(0, 1, 0, size))) + Set(new ShuffleBlockInfo(0, 0, 0, size, DefaultMapShuffleLocations.get(localBmId)), + new ShuffleBlockInfo(0, 1, 0, size, DefaultMapShuffleLocations.get(remoteBmId)))) } test("retry corrupt blocks (disabled)") { @@ -537,13 +540,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Continue only after the mock calls onBlockFetchFailure sem.acquire() + val remoteShuffleLocation = DefaultMapShuffleLocations.get(remoteBmId) // The first block should be returned without an exception val id1 = iterator.next().getShuffleBlockInfo - assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1)) + assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, remoteShuffleLocation)) val id2 = iterator.next().getShuffleBlockInfo - assert(id2 === new ShuffleBlockInfo(0, 1, 0, 1)) + assert(id2 === new ShuffleBlockInfo(0, 1, 0, 1, remoteShuffleLocation)) val id3 = iterator.next().getShuffleBlockInfo - assert(id3 === new ShuffleBlockInfo(0, 2, 0, 1)) + assert(id3 === new ShuffleBlockInfo(0, 2, 0, 1, remoteShuffleLocation)) } test("Blocks should be shuffled to disk when size of the request is above the" + From 653f67c0e4f2fcc8a358fc009ebd7faf90a0385f Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 16:14:47 -0700 Subject: [PATCH 44/56] style --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6d4a55a217b74..88976861a22fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,10 +17,10 @@ package org.apache.spark.storage -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue - import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} @@ -30,8 +30,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} /** From 9f53839399fc216357c4d11ed98a6079bdfb5c6e Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 16:56:35 -0700 Subject: [PATCH 45/56] address some comments --- .../spark/api/shuffle/ShuffleBlockInfo.java | 17 ++++++++--------- .../spark/api/shuffle/ShuffleLocation.java | 4 +--- .../org/apache/spark/MapOutputTracker.scala | 16 ++++++++-------- .../spark/shuffle/BlockStoreShuffleReader.scala | 3 ++- .../shuffle/io/DefaultShuffleReadSupport.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 7 ++++--- .../shuffle/BlockStoreShuffleReaderSuite.scala | 6 +++--- .../sort/BlockStoreShuffleReaderBenchmark.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 15 ++++++++++----- 9 files changed, 38 
insertions(+), 34 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java index 7cb2e62a07c5a..a312831cb6282 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -17,8 +17,9 @@ package org.apache.spark.api.shuffle; +import org.apache.spark.api.java.Optional; + import java.util.Objects; -import java.util.Optional; /** * :: Experimental :: @@ -32,16 +33,13 @@ public class ShuffleBlockInfo { private final long length; private final Optional shuffleLocation; - public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, ShuffleLocation shuffleLocation) { + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, + Optional shuffleLocation) { this.shuffleId = shuffleId; this.mapId = mapId; this.reduceId = reduceId; this.length = length; - if (shuffleLocation == ShuffleLocation.EMPTY_LOCATION) { - this.shuffleLocation = Optional.empty(); - } else { - this.shuffleLocation = Optional.of(shuffleLocation); - } + this.shuffleLocation = shuffleLocation; } public int getShuffleId() { @@ -70,11 +68,12 @@ public boolean equals(Object other) { && shuffleId == ((ShuffleBlockInfo) other).shuffleId && mapId == ((ShuffleBlockInfo) other).mapId && reduceId == ((ShuffleBlockInfo) other).reduceId - && length == ((ShuffleBlockInfo) other).length; + && length == ((ShuffleBlockInfo) other).length + && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation); } @Override public int hashCode() { - return Objects.hash(shuffleId, mapId, reduceId, length); + return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation); } } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java index 41d474dd6eef8..d06c11b3c01ee 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -21,6 +21,4 @@ * Marker interface representing a location of a shuffle block. Implementations of shuffle readers * and writers are expected to cast this down to an implementation-specific representation. */ -public interface ShuffleLocation { - ShuffleLocation EMPTY_LOCATION = new ShuffleLocation() {}; -} +public interface ShuffleLocation {} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7ac28656aeee5..ebddf5ff6f6e0 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -283,7 +283,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } @@ -297,7 +297,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -647,7 +647,7 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -684,7 +684,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -873,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -886,11 +886,11 @@ private[spark] object MapOutputTracker extends Logging { val size = status.getSizeForBlock(part) if (size != 0) { if (status.mapShuffleLocations == null) { - splitsByAddress.getOrElseUpdate(ShuffleLocation.EMPTY_LOCATION, ListBuffer()) += + splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) } else { val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) - splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += + splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0d405525ef093..e4004ec4c65bc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import org.apache.spark._ +import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager @@ -68,7 +69,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( block.mapId, block.reduceId, blockInfo._2, - 
shuffleLocationInfo._1) + Optional.ofNullable(shuffleLocationInfo._1.orNull)) } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index a0c6bbfc727cb..f5127bd1b021b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -95,7 +95,7 @@ private class ShuffleBlockFetcherIterable( mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) .map { shuffleLocationInfo => val defaultShuffleLocation = shuffleLocationInfo._1 - .asInstanceOf[DefaultMapShuffleLocations] + .get.asInstanceOf[DefaultMapShuffleLocations] (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) }, maxBytesInFlight, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 88976861a22fc..6c7719ae38b58 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,13 +17,14 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} import java.util.concurrent.LinkedBlockingQueue -import javax.annotation.concurrent.GuardedBy +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import org.apache.spark.api.java.Optional import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream} import org.apache.spark.internal.Logging @@ -474,7 +475,7 @@ final class ShuffleBlockFetcherIterator( val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] new ShuffleReaderInputStream( new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size, - DefaultMapShuffleLocations.get(currentResult.address)), + Optional.of(DefaultMapShuffleLocations.get(currentResult.address))), new BufferReleasingInputStream(input, this)) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 5493eb6ced48e..5df872895146e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -116,12 +116,12 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext (shuffleBlockId, byteOutputStream.size().toLong) } val blocksToRetrieve = Seq( - (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) + (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes)) val mapOutputTracker = mock(classOf[MapOutputTracker]) when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) - .thenAnswer(new Answer[Iterator[(ShuffleLocation, Seq[(BlockId, Long)])]] { + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { def answer(invocationOnMock: InvocationOnMock): - Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { blocksToRetrieve.iterator } }) diff 
--git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index a23e9e80b1fbc..61dda2dfa177b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -199,7 +199,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((DefaultMapShuffleLocations.get(dataBlockId), shuffleBlockIdsAndSizes)).toIterator + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)).toIterator } when(dependency.serializer).thenReturn(serializer) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index d244082b4ff6a..4c9610c73464a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -30,8 +30,9 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester +import org.apache.spark.api.java.Optional import org.apache.spark.{SparkFunSuite, TaskContext} -import org.apache.spark.api.shuffle.ShuffleBlockInfo +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleLocation} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} @@ -401,7 +402,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // The first block should be returned without an exception val id1 = iterator.next().getShuffleBlockInfo - assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, shuffleLocation)) + assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, Optional.of(shuffleLocation))) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -484,8 +485,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. 
assert(Set(iterator.next().getShuffleBlockInfo, iterator.next().getShuffleBlockInfo) === - Set(new ShuffleBlockInfo(0, 0, 0, size, DefaultMapShuffleLocations.get(localBmId)), - new ShuffleBlockInfo(0, 1, 0, size, DefaultMapShuffleLocations.get(remoteBmId)))) + Set( + new ShuffleBlockInfo(0, 0, 0, size, + Optional.of(DefaultMapShuffleLocations.get(localBmId))), + new ShuffleBlockInfo(0, 1, 0, size, + Optional.of(DefaultMapShuffleLocations.get(remoteBmId))))) } test("retry corrupt blocks (disabled)") { @@ -540,7 +544,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Continue only after the mock calls onBlockFetchFailure sem.acquire() - val remoteShuffleLocation = DefaultMapShuffleLocations.get(remoteBmId) + val remoteShuffleLocation: Optional[ShuffleLocation] = + Optional.of(DefaultMapShuffleLocations.get(remoteBmId)) // The first block should be returned without an exception val id1 = iterator.next().getShuffleBlockInfo assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, remoteShuffleLocation)) From 94275fd6ba4b37e311d3d7218b473b0a1ccf6dd8 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 17:00:11 -0700 Subject: [PATCH 46/56] style --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 6 +++--- .../shuffle/sort/BlockStoreShuffleReaderBenchmark.scala | 3 ++- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6c7719ae38b58..649612bff9de7 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.storage -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue - import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import org.apache.spark.api.java.Optional import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 61dda2dfa177b..4f3e757392c6d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -199,7 +199,8 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)).toIterator + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator } when(dependency.serializer).thenReturn(serializer) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 4c9610c73464a..65b18a55d6f90 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -30,8 +30,8 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.api.java.Optional import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleLocation} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} From 26e97c12ba380c9897ecdcd77c0d5a5316006dc3 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 17:18:12 -0700 Subject: [PATCH 47/56] refactor API --- .../spark/api/shuffle/ShuffleReadSupport.java | 5 +- .../api/shuffle/ShuffleReaderIterable.java | 46 ------------------- .../shuffle/BlockStoreShuffleReader.scala | 8 ++-- .../io/DefaultShuffleReadSupport.scala | 33 ++++--------- 4 files changed, 17 insertions(+), 75 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 3e3d04a268320..7ca593cff359c 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Experimental; import java.io.IOException; +import java.util.Iterator; /** * :: Experimental :: @@ -29,9 +30,9 @@ @Experimental public interface ShuffleReadSupport { /** - * Returns an underlying {@link ShuffleReaderIterable} that will iterate through shuffle data, + * Returns an underlying {@link Iterable} that will iterate through shuffle data, * given an iterable for the shuffle blocks to fetch. */ - ShuffleReaderIterable getPartitionReaders(Iterable blockMetadata) + Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java deleted file mode 100644 index 2307d22342ed5..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderIterable.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.util.Iterator; - -/** - * :: Experimental :: - * An interface for iterating through shuffle blocks to read. - * @since 3.0.0 - */ -@Experimental -public interface ShuffleReaderIterable extends Iterable { - - interface ShuffleReaderIterator extends Iterator { - /** - * Instructs the shuffle iterator to fetch the last block again. This is useful - * if the block is determined to be corrupt after decryption or decompression. - * - * @throws Exception if current block cannot be retried. - */ - default void retryLastBlock(Throwable t) throws Exception { - throw new UnsupportedOperationException("Cannot retry fetching bad blocks", t); - } - } - - @Override - ShuffleReaderIterator iterator(); -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index e4004ec4c65bc..7ab6fbdcd3785 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -27,7 +27,8 @@ import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils} import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -87,7 +88,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockInfo.getShuffleId, blockInfo.getMapId, blockInfo.getReduceId) - if (detectCorrupt && blockInfo.getLength < maxBytesInFlight) { + if (detectCorrupt && blockInfo.getLength < maxBytesInFlight + && shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { try { val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) @@ -98,7 +100,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) } catch { case e: IOException => - streamsIterator.retryLastBlock(e) + streamsIterator.asInstanceOf[ShuffleBlockFetcherIterator].retryLast(e) } } else { returnStream = serializerManager.wrapStream(blockId, nextStream.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index f5127bd1b021b..61bf1305301ba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -20,8 +20,7 @@ package org.apache.spark.shuffle.io import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream, ShuffleReaderIterable, ShuffleReadSupport} -import org.apache.spark.api.shuffle.ShuffleReaderIterable.ShuffleReaderIterator +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream, ShuffleReadSupport} import org.apache.spark.internal.config import 
org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations @@ -38,18 +37,11 @@ class DefaultShuffleReadSupport( conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) - override def getPartitionReaders( - blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): ShuffleReaderIterable = { + override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + java.lang.Iterable[ShuffleReaderInputStream] = { if (blockMetadata.asScala.isEmpty) { - val emptyIterator = new ShuffleReaderIterator { - override def hasNext: Boolean = Iterator.empty.hasNext - - override def next(): ShuffleReaderInputStream = Iterator.empty.next() - } - return new ShuffleReaderIterable { - override def iterator(): ShuffleReaderIterator = emptyIterator - } + return Iterable.empty.asJava } else { val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) .foldLeft(Int.MaxValue, 0) { @@ -69,7 +61,7 @@ class DefaultShuffleReadSupport( maxReduceId, shuffleId, mapOutputTracker - ) + ).asJava } } } @@ -85,10 +77,10 @@ private class ShuffleBlockFetcherIterable( minReduceId: Int, maxReduceId: Int, shuffleId: Int, - mapOutputTracker: MapOutputTracker) extends ShuffleReaderIterable { + mapOutputTracker: MapOutputTracker) extends Iterable[ShuffleReaderInputStream] { - override def iterator: ShuffleReaderIterator = { - val innerIterator = new ShuffleBlockFetcherIterator( + override def iterator: Iterator[ShuffleReaderInputStream] = { + new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, blockManager, @@ -102,15 +94,8 @@ private class ShuffleBlockFetcherIterable( maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, - shuffleMetrics) - val completionIterator = innerIterator.toCompletionIterator - new ShuffleReaderIterator { - override def hasNext: Boolean = completionIterator.hasNext - - override def next(): ShuffleReaderInputStream = completionIterator.next() + shuffleMetrics).toCompletionIterator - override def retryLastBlock(t: Throwable): Unit = innerIterator.retryLast(t) - } } } From 91db77600b2407f6869573323bf552da6328f312 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 19 Apr 2019 17:30:18 -0700 Subject: [PATCH 48/56] cleanup --- .../shuffle/sort/io/DefaultShuffleExecutorComponents.java | 5 +---- .../org/apache/spark/shuffle/BlockStoreShuffleReader.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 1 - .../apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala | 2 +- .../shuffle/sort/BlockStoreShuffleReaderBenchmark.scala | 2 +- 5 files changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 4451280bf4b07..ae652e5b9223b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -54,10 +54,7 @@ public ShuffleWriteSupport writes() { @Override public ShuffleReadSupport reads() { checkInitialized(); - return new DefaultShuffleReadSupport( - blockManager, - mapOutputTracker, - sparkConf); + return new DefaultShuffleReadSupport(blockManager, mapOutputTracker, sparkConf); } private void checkInitialized() { diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7ab6fbdcd3785..ef8dcce1e2f9f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -43,8 +43,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - serializerManager: SerializerManager, shuffleReadSupport: ShuffleReadSupport, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, sparkConf: SparkConf = SparkEnv.get.conf) extends ShuffleReader[K, C] with Logging { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 2fb0f4f78284d..38495ae523d86 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -128,7 +128,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - SparkEnv.get.serializerManager, shuffleExecutorComponents.reads()) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 5df872895146e..1bec58c623a7c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -153,8 +153,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext reduceId + 1, taskContext, metrics, - serializerManager, shuffleReadSupport, + serializerManager, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 4f3e757392c6d..002d1dce69273 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -215,8 +215,8 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), - serializerManager, readSupport, + serializerManager, mapOutputTracker ) } From f0fa7b852effdb2953d990af5fa74b702de4a072 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 22 Apr 2019 10:50:12 -0700 Subject: [PATCH 49/56] fix tests and style --- .../spark/api/shuffle/ShuffleReadSupport.java | 1 - .../spark/scheduler/DAGSchedulerSuite.scala | 34 +++++++++++-------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 7ca593cff359c..28d074f9e6978 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -20,7 +20,6 @@ import org.apache.spark.annotation.Experimental; import java.io.IOException; -import java.util.Iterator; /** * 
:: Experimental :: diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 14b93957734e4..21b4e56c9e801 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -703,7 +703,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -731,7 +731,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // we can see both result blocks now assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -774,7 +774,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } else { assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) } } } @@ -1069,7 +1069,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1201,11 +1201,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 1) - .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. 
@@ -1397,7 +1397,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1803,7 +1803,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. assert(taskSets.size === 3) @@ -2066,7 +2066,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2112,7 +2112,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"))) + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2276,7 +2276,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2292,7 +2292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -2302,7 +2302,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostD"))) + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD"))) 
assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2341,7 +2341,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. @@ -2367,7 +2367,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2923,6 +2923,10 @@ object DAGSchedulerSuite { def makeShuffleLocation(host: String): MapShuffleLocations = { DefaultMapShuffleLocations.get(makeBlockManagerId(host)) } + + def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = { + Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + } } object FailThisAttempt { From 50c8fc3442689afc28af0e5e981f5c2640eaf9d1 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 22 Apr 2019 11:27:55 -0700 Subject: [PATCH 50/56] style --- .../java/org/apache/spark/api/shuffle/ShuffleReadSupport.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 28d074f9e6978..997a504f2d8bd 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -29,8 +29,8 @@ @Experimental public interface ShuffleReadSupport { /** - * Returns an underlying {@link Iterable} that will iterate through shuffle data, - * given an iterable for the shuffle blocks to fetch. + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. */ Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; From 4aa4b6e470102d7d2a27cf5d5c458f6b74876497 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 22 Apr 2019 12:10:30 -0700 Subject: [PATCH 51/56] reorder result for test? 
--- .../org/apache/spark/MapOutputTrackerSuite.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 0a77c4f6d5838..8fcbc845d1a7b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -71,9 +71,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -155,7 +155,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -324,12 +324,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByShuffleLocation(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4) + .map(x => (x._1.get, x._2)).toSeq === Seq( - (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), - Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), - Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))) ) ) From 7d23f472afb17426d6b5156091c84b253f69e210 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 26 Apr 2019 16:18:43 -0700 Subject: [PATCH 52/56] wip --- .../spark/api/shuffle/ShuffleReadSupport.java | 5 +++-- .../io/DefaultShuffleReadSupport.scala | 19 ++++++++++--------- .../storage/ShuffleBlockFetcherIterator.scala | 13 +++++-------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java index 997a504f2d8bd..9cd8fde09064b 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Experimental; import java.io.IOException; +import java.io.InputStream; /** * :: Experimental :: @@ -29,9 +30,9 @@ @Experimental public interface ShuffleReadSupport { /** - * Returns an underlying {@link Iterable} that will iterate + * Returns an underlying {@link Iterable} 
that will iterate * through shuffle data, given an iterable for the shuffle blocks to fetch. */ - Iterable getPartitionReaders(Iterable blockMetadata) + Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 61bf1305301ba..a9be5aa1d7931 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -17,10 +17,12 @@ package org.apache.spark.shuffle.io +import java.io.InputStream + import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream, ShuffleReadSupport} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations @@ -38,17 +40,16 @@ class DefaultShuffleReadSupport( private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): - java.lang.Iterable[ShuffleReaderInputStream] = { + java.lang.Iterable[InputStream] = { - if (blockMetadata.asScala.isEmpty) { - return Iterable.empty.asJava + val iterableToReturn = if (blockMetadata.asScala.nonEmpty) { + Iterable.empty } else { val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) .foldLeft(Int.MaxValue, 0) { case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) } val shuffleId = blockMetadata.asScala.head.getShuffleId - new ShuffleBlockFetcherIterable( TaskContext.get(), blockManager, @@ -61,8 +62,9 @@ class DefaultShuffleReadSupport( maxReduceId, shuffleId, mapOutputTracker - ).asJava + ) } + iterableToReturn.asJava } } @@ -77,9 +79,9 @@ private class ShuffleBlockFetcherIterable( minReduceId: Int, maxReduceId: Int, shuffleId: Int, - mapOutputTracker: MapOutputTracker) extends Iterable[ShuffleReaderInputStream] { + mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] { - override def iterator: Iterator[ShuffleReaderInputStream] = { + override def iterator: Iterator[InputStream] = { new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, @@ -95,7 +97,6 @@ private class ShuffleBlockFetcherIterable( maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, shuffleMetrics).toCompletionIterator - } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 649612bff9de7..e998c777041c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -71,7 +71,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[ShuffleReaderInputStream] with DownloadFileManager with Logging { + extends Iterator[InputStream] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -394,7 +394,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could 
not be fetched. */ - override def next(): ShuffleReaderInputStream = { + override def next(): InputStream = { if (!hasNext) { throw new NoSuchElementException() } @@ -473,10 +473,7 @@ final class ShuffleBlockFetcherIterator( } currentResult = result.asInstanceOf[SuccessFetchResult] val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] - new ShuffleReaderInputStream( - new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, currentResult.size, - Optional.of(DefaultMapShuffleLocations.get(currentResult.address))), - new BufferReleasingInputStream(input, this)) + new BufferReleasingInputStream(input, this) } def retryLast(t: Throwable): Unit = { @@ -494,8 +491,8 @@ final class ShuffleBlockFetcherIterator( } } - def toCompletionIterator: Iterator[ShuffleReaderInputStream] = { - CompletionIterator[ShuffleReaderInputStream, this.type](this, + def toCompletionIterator: Iterator[InputStream] = { + CompletionIterator[InputStream, this.type](this, onCompleteCallback.onComplete(context)) } From 363d4abfde148c9f7113fa19e55a1389548daabd Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Mon, 29 Apr 2019 13:59:27 -0700 Subject: [PATCH 53/56] address comments --- .../io/DefaultShuffleExecutorComponents.java | 5 +- .../shuffle/BlockStoreShuffleReader.scala | 43 +++---- .../io/DefaultShuffleReadSupport.scala | 11 +- .../storage/ShuffleBlockFetcherIterator.scala | 59 +++++++--- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../BlockStoreShuffleReaderBenchmark.scala | 6 +- .../ShuffleBlockFetcherIteratorSuite.scala | 105 ++++++++---------- 7 files changed, 124 insertions(+), 107 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index ae652e5b9223b..57e5b1d0eea13 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -23,6 +23,7 @@ import org.apache.spark.api.shuffle.ShuffleExecutorComponents; import org.apache.spark.api.shuffle.ShuffleReadSupport; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; import org.apache.spark.storage.BlockManager; @@ -33,6 +34,7 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponen private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; private MapOutputTracker mapOutputTracker; + private SerializerManager serializerManager; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { this.sparkConf = sparkConf; @@ -42,6 +44,7 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) { public void initializeExecutor(String appId, String execId) { blockManager = SparkEnv.get().blockManager(); mapOutputTracker = SparkEnv.get().mapOutputTracker(); + serializerManager = SparkEnv.get().serializerManager(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); } @@ -54,7 +57,7 @@ public ShuffleWriteSupport writes() { @Override public ShuffleReadSupport reads() { checkInitialized(); - return new DefaultShuffleReadSupport(blockManager, mapOutputTracker, sparkConf); + return new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf); } private void checkInitialized() { 
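The executor-components wiring above makes the read-side plugin seam concrete: after this patch, a ShuffleReadSupport implementation only has to map an Iterable of ShuffleBlockInfo descriptors to an Iterable of raw InputStreams, and the BlockStoreShuffleReader change below applies decryption and decompression on behalf of any non-default implementation. As a rough sketch of what a third-party plugin could look like (the class name, constructor argument, and file layout are invented for this illustration and appear nowhere in the patch series), a shared-filesystem reader might be:

    import java.io.InputStream
    import java.nio.file.{Files, Paths}

    import scala.collection.JavaConverters._

    import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}

    // Hypothetical third-party plugin: serves every shuffle block from a shared
    // filesystem instead of fetching it from remote executors over the network.
    class SharedFsShuffleReadSupport(rootDir: String) extends ShuffleReadSupport {

      override def getPartitionReaders(
          blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {
        // One raw stream per requested block, in request order; the caller wraps
        // these for decryption/decompression when the plugin is not the default.
        blockMetadata.asScala.map { block =>
          val path = Paths.get(rootDir,
            s"shuffle_${block.getShuffleId}_${block.getMapId}_${block.getReduceId}.data")
          Files.newInputStream(path): InputStream
        }.asJava
      }
    }

Returning plain InputStream, rather than the ShuffleReaderInputStream wrapper that this series removes, is what keeps the reader agnostic about where the bytes came from.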
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index ef8dcce1e2f9f..530c3694ad1ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,8 +17,7 @@ package org.apache.spark.shuffle -import java.io.{InputStream, IOException} -import java.nio.ByteBuffer +import java.io.InputStream import scala.collection.JavaConverters._ @@ -26,12 +25,12 @@ import org.apache.spark._ import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by @@ -51,9 +50,9 @@ private[spark] class BlockStoreShuffleReader[K, C]( private val dep = handle.dependency - private val detectCorrupt = sparkConf.get(config.SHUFFLE_DETECT_CORRUPT) + private val compressionCodec = CompressionCodec.createCodec(sparkConf) - private val maxBytesInFlight = sparkConf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS) /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { @@ -82,28 +81,18 @@ private[spark] class BlockStoreShuffleReader[K, C]( override def next(): InputStream = { var returnStream: InputStream = null while (streamsIterator.hasNext && returnStream == null) { - val nextStream = streamsIterator.next() - val blockInfo = nextStream.getShuffleBlockInfo - val blockId = ShuffleBlockId( - blockInfo.getShuffleId, - blockInfo.getMapId, - blockInfo.getReduceId) - if (detectCorrupt && blockInfo.getLength < maxBytesInFlight - && shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { - try { - val in = serializerManager.wrapStream(blockId, nextStream.getInputStream) - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - // Decompress the whole block at once to detect any corruption, which could increase - // the memory usage and potentially increase the chance of OOM. - // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- Utils.copyStream(in, out, closeStreams = true) - returnStream = out.toChunkedByteBuffer.toInputStream(dispose = true) - } catch { - case e: IOException => - streamsIterator.asInstanceOf[ShuffleBlockFetcherIterator].retryLast(e) - } + if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { + // The default implementation checks for corrupt streams, so it will already have + // decompressed/decrypted the bytes + returnStream = streamsIterator.next() } else { - returnStream = serializerManager.wrapStream(blockId, nextStream.getInputStream) + val nextStream = streamsIterator.next() + returnStream = if (compressShuffle) { + compressionCodec.compressedInputStream( + serializerManager.wrapForEncryption(nextStream)) + } else { + serializerManager.wrapForEncryption(nextStream) + } } } if (returnStream == null) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index a9be5aa1d7931..9b9b8508e88aa 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} @@ -31,6 +32,7 @@ import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} class DefaultShuffleReadSupport( blockManager: BlockManager, mapOutputTracker: MapOutputTracker, + serializerManager: SerializerManager, conf: SparkConf) extends ShuffleReadSupport { private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 @@ -38,11 +40,12 @@ class DefaultShuffleReadSupport( private val maxBlocksInFlightPerAddress = conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { - val iterableToReturn = if (blockMetadata.asScala.nonEmpty) { + val iterableToReturn = if (blockMetadata.asScala.isEmpty) { Iterable.empty } else { val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) @@ -53,10 +56,12 @@ class DefaultShuffleReadSupport( new ShuffleBlockFetcherIterable( TaskContext.get(), blockManager, + serializerManager, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, + detectCorrupt, shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), minReduceId, maxReduceId, @@ -71,10 +76,12 @@ class DefaultShuffleReadSupport( private class ShuffleBlockFetcherIterable( context: TaskContext, blockManager: BlockManager, + serializerManager: SerializerManager, maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, + detectCorruption: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, minReduceId: Int, maxReduceId: Int, @@ -92,10 +99,12 @@ private class 
ShuffleBlockFetcherIterable( .get.asInstanceOf[DefaultMapShuffleLocations] (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) }, + serializerManager.wrapStream, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, + detectCorruption, shuffleMetrics).toCompletionIterator } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e998c777041c8..287ffdd6e10e6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} +import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -25,15 +26,13 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.api.java.Optional -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReaderInputStream} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -53,11 +52,13 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * order to throttle the memory usage. Note that zero-sized blocks are * already excluded, which happened in * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. + * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param shuffleMetrics used to report shuffle metrics. 
 */
private[spark]
@@ -66,10 +67,12 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
+    detectCorrupt: Boolean,
     shuffleMetrics: ShuffleReadMetricsReporter)
   extends Iterator[InputStream] with DownloadFileManager with Logging {
@@ -450,7 +453,7 @@ final class ShuffleBlockFetcherIterator(
           throwFetchFailedException(blockId, address, new IOException(msg))
         }
 
-        input = try {
+        val in = try {
           buf.createInputStream()
         } catch {
           // The exception could only be thrown by local shuffle block
@@ -460,6 +463,38 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
+        var isStreamCopied: Boolean = false
+        try {
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, and the size of
+          // the block is small (the decompressed block is smaller than maxBytesInFlight)
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            isStreamCopied = true
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            // Decompress the whole block at once to detect any corruption, which could increase
+            // the memory usage and potentially increase the chance of OOM.
+            // TODO: manage the memory used here, and spill it into disk in case of OOM.
+            Utils.copyStream(input, out, closeStreams = true)
+            input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+          }
+        } catch {
+          case e: IOException =>
+            buf.release()
+            if (buf.isInstanceOf[FileSegmentManagedBuffer]
+              || corruptedBlocks.contains(blockId)) {
+              throwFetchFailedException(blockId, address, e)
+            } else {
+              logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+              corruptedBlocks += blockId
+              fetchRequests += FetchRequest(address, Array((blockId, size)))
+              result = null
+            }
+        } finally {
+          // TODO: release the buf here to free memory earlier
+          if (isStreamCopied) {
+            in.close()
+          }
+        }
 
       case FailureFetchResult(blockId, address, e) =>
         throwFetchFailedException(blockId, address, e)
     }
@@ -476,19 +511,9 @@ final class ShuffleBlockFetcherIterator(
     new BufferReleasingInputStream(input, this)
   }
 
-  def retryLast(t: Throwable): Unit = {
-    val blockId = currentResult.blockId
-    if (corruptedBlocks.contains(blockId)) {
-      throwFetchFailedException(blockId, currentResult.address, t)
-    } else {
-      logWarning(s"got a corrupted block $blockId from $currentResult.address, fetch again", t)
-      corruptedBlocks += blockId
-      fetchRequests += FetchRequest(currentResult.address,
-        Array((currentResult.blockId, currentResult.size)))
-      // Send fetch requests up to maxBytesInFlight
-      numBlocksToFetch += 1
-      fetchUpToMaxBytes()
-    }
+  // for testing only
+  def getCurrentBlock(): ShuffleBlockId = {
+    currentResult.blockId.asInstanceOf[ShuffleBlockId]
   }
 
   def toCompletionIterator: Iterator[InputStream] = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 1bec58c623a7c..6468914bf3185 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -146,7 +146,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with
LocalSparkContext val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, testConf) + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 002d1dce69273..4dc1251a4ca84 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -207,7 +207,11 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new DefaultShuffleReadSupport(blockManager, mapOutputTracker, defaultConf) + val readSupport = new DefaultShuffleReadSupport( + blockManager, + mapOutputTracker, + serializerManager, + defaultConf) new BlockStoreShuffleReader[String, String]( shuffleHandle, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 65b18a55d6f90..b77622d0dcc3b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,23 +21,20 @@ import java.io.{File, InputStream, IOException} import java.util.UUID import java.util.concurrent.Semaphore -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future - import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import org.apache.spark.{SparkFunSuite, TaskContext} -import org.apache.spark.api.java.Optional -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleLocation} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} +import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.Utils @@ -114,10 +111,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, Int.MaxValue, + true, metrics) // 3 local blocks fetched in initialization @@ -125,15 +124,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val shuffleInputStream = iterator.next() - val shuffleBlockInfo = shuffleInputStream.getShuffleBlockInfo - val inputStream = shuffleInputStream.getInputStream + val inputStream = iterator.next() + val blockId = iterator.getCurrentBlock() // Make sure we release buffers when a wrapped 
       // Make sure we release buffers when a wrapped input stream is closed.
-      val blockId = ShuffleBlockId(
-        shuffleBlockInfo.getShuffleId,
-        shuffleBlockInfo.getMapId,
-        shuffleBlockInfo.getReduceId)
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
       // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
       val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
@@ -197,18 +191,20 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
-    iterator.next().getInputStream.close() // close() first block's input stream
+    iterator.next().close() // close() first block's input stream
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
 
     // Get the 2nd block but do not exhaust the iterator
-    val subIter = iterator.next().getInputStream
+    val subIter = iterator.next()
 
     // Complete the task; then the 2nd block buffer should be exhausted
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
@@ -263,10 +259,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
@@ -321,10 +319,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -390,19 +390,21 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
-    val shuffleLocation = DefaultMapShuffleLocations.get(remoteBmId)
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
 
     // The first block should be returned without an exception
-    val id1 = iterator.next().getShuffleBlockInfo
-    assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, Optional.of(shuffleLocation)))
+    iterator.next()
+    val id1 = iterator.getCurrentBlock()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
 
     when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
@@ -417,39 +419,16 @@
         }
       })
 
-    // This should fail to read the bytes and call for a retry
-    val readByte = readNextStreamAndRetryOnError(iterator)
-    assert(readByte === None)
+    // The next block is a corrupt local block (the second one is corrupt and retried)
+    intercept[FetchFailedException] { iterator.next() }
 
     sem.acquire()
-    // The next call should fail and not call for a retry because the stream wasn't
-    // corrupt, but the fetch itself failed
-    intercept[FetchFailedException] {
-      readNextStreamAndRetryOnError(iterator)
-    }
-    // The next call is the retry of the second block, which fails
-    intercept[FetchFailedException] {
-      readNextStreamAndRetryOnError(iterator)
-    }
-  }
-
-  def readNextStreamAndRetryOnError(iterator: ShuffleBlockFetcherIterator): Option[Byte] = {
-    try {
-      val stream = iterator.next()
-      val readByte = Array[Byte](1)
-      stream.getInputStream.read(readByte, 0, 1)
-      Some(readByte)
-    } catch {
-      case e: IOException =>
-        iterator.retryLast(e)
-    }
-    None
+    intercept[FetchFailedException] { iterator.next() }
   }
 
   test("big blocks are not checked for corruption") {
-    val size = 10000L
-    val corruptBuffer = mockCorruptBuffer(size)
+    val corruptBuffer = mockCorruptBuffer(10000L)
 
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -478,18 +457,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
       2048,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     // Blocks should be returned without exceptions.
-    assert(Set(iterator.next().getShuffleBlockInfo, iterator.next().getShuffleBlockInfo) ===
-      Set(
-        new ShuffleBlockInfo(0, 0, 0, size,
-          Optional.of(DefaultMapShuffleLocations.get(localBmId))),
-        new ShuffleBlockInfo(0, 1, 0, size,
-          Optional.of(DefaultMapShuffleLocations.get(remoteBmId)))))
+    iterator.next()
+    val blockId1 = iterator.getCurrentBlock()
+    iterator.next()
+    val blockId2 = iterator.getCurrentBlock()
+    assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
   }
 
   test("retry corrupt blocks (disabled)") {
@@ -535,24 +515,27 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
 
-    val remoteShuffleLocation: Optional[ShuffleLocation] =
-      Optional.of(DefaultMapShuffleLocations.get(remoteBmId))
     // The first block should be returned without an exception
-    val id1 = iterator.next().getShuffleBlockInfo
-    assert(id1 === new ShuffleBlockInfo(0, 0, 0, 1, remoteShuffleLocation))
-    val id2 = iterator.next().getShuffleBlockInfo
-    assert(id2 === new ShuffleBlockInfo(0, 1, 0, 1, remoteShuffleLocation))
-    val id3 = iterator.next().getShuffleBlockInfo
-    assert(id3 === new ShuffleBlockInfo(0, 2, 0, 1, remoteShuffleLocation))
+    iterator.next()
+    val id1 = iterator.getCurrentBlock()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+    iterator.next()
+    val id2 = iterator.getCurrentBlock()
+    assert(id2 === ShuffleBlockId(0, 1, 0))
+    iterator.next()
+    val id3 = iterator.getCurrentBlock()
+    assert(id3 === ShuffleBlockId(0, 2, 0))
   }
 
   test("Blocks should be shuffled to disk when size of the request is above the" +
@@ -597,10 +580,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       maxBytesInFlight = Int.MaxValue,
       maxReqsInFlight = Int.MaxValue,
       maxBlocksInFlightPerAddress = Int.MaxValue,
       maxReqSizeShuffleToMem = 200,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
   }
@@ -642,10 +627,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress.toIterator,
+      (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
+      true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     // All blocks fetched return zero length and should trigger a receive-side error:
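
Taken together, this patch replaces the reader-driven retryLast flow with two constructor parameters threaded through every ShuffleBlockFetcherIterator call site: streamWrapper lets the caller decide how a fetched block is decompressed or decrypted, and detectCorrupt makes the iterator eagerly copy small wrapped streams so corruption surfaces as an IOException at fetch time rather than in the middle of record iteration. A minimal sketch of how a read-support implementation might supply the wrapper, assuming a SerializerManager is in scope as in DefaultShuffleReadSupport (the helper name makeStreamWrapper is illustrative, not part of the patch):

    import java.io.InputStream

    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.BlockId

    // Adapt SerializerManager.wrapStream (which layers the configured
    // compression codec and, when enabled, IO encryption over a raw block
    // stream) to the (BlockId, InputStream) => InputStream shape that the
    // iterator's streamWrapper parameter expects.
    def makeStreamWrapper(
        serializerManager: SerializerManager): (BlockId, InputStream) => InputStream =
      (blockId, in) => serializerManager.wrapStream(blockId, in)

The identity wrapper used throughout the tests, (_, in) => in, also opts out of the corruption check, because the iterator only buffers and verifies a stream when the wrapper returned a different object than the one it was given.
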
From bb7fa4c572a431ee1b0d376475566049d290cf6c Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Mon, 29 Apr 2019 14:23:41 -0700
Subject: [PATCH 54/56] style

---
 .../shuffle/sort/io/DefaultShuffleExecutorComponents.java | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
index 57e5b1d0eea13..91a5d7f7945ee 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -57,7 +57,10 @@ public ShuffleWriteSupport writes() {
   @Override
   public ShuffleReadSupport reads() {
     checkInitialized();
-    return new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf);
+    return new DefaultShuffleReadSupport(blockManager,
+        mapOutputTracker,
+        serializerManager,
+        sparkConf);
   }
 
   private void checkInitialized() {

From 711109bc558c5055ba89b6850d96875bd71a17ab Mon Sep 17 00:00:00 2001
From: Yifei Huang
Date: Mon, 29 Apr 2019 14:28:23 -0700
Subject: [PATCH 55/56] cleanup tests

---
 .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index b77622d0dcc3b..a4b6920be04c0 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -457,7 +457,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 100),
+      (_, in) => new LimitedInputStream(in, 10000),
       2048,
       Int.MaxValue,
       Int.MaxValue,
@@ -515,12 +515,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
-      (_, in) => in,
+      (_, in) => new LimitedInputStream(in, 100),
       48 * 1024 * 1024,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
-      true,
+      false,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -585,7 +585,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       maxReqsInFlight = Int.MaxValue,
       maxBlocksInFlightPerAddress = Int.MaxValue,
       maxReqSizeShuffleToMem = 200,
-      true,
+      detectCorrupt = true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
   }
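
This cleanup aligns the tests with the gate the iterator now applies before buffering a block: "big blocks are not checked for corruption" switches to a wrapper that truncates at the full block size, so the stream is genuinely wrapped and only the size check exempts it, while "retry corrupt blocks (disabled)" keeps a truncating wrapper but turns detectCorrupt off. A condensed restatement of that gate, as a sketch with parameter names borrowed from the diff:

    import java.io.InputStream

    // The iterator only performs the eager copy-and-verify when corruption
    // detection is enabled, the wrapper actually returned a new stream, and
    // the block is small enough to buffer in memory.
    def shouldEagerlyCopy(
        detectCorrupt: Boolean,
        wrapped: InputStream,
        raw: InputStream,
        size: Long,
        maxBytesInFlight: Long): Boolean =
      detectCorrupt && !wrapped.eq(raw) && size < maxBytesInFlight / 3

With maxBytesInFlight = 2048 and a 10000-byte block the size test fails, so the corrupt buffer is handed back unverified; with detectCorrupt = false the short-circuit makes the wrapper irrelevant.
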
From 04a135c44cf2e73b9efea5f6be3aaf15eacc277e Mon Sep 17 00:00:00 2001
From: mcheah
Date: Tue, 30 Apr 2019 13:12:19 -0700
Subject: [PATCH 56/56] Remove unused class

---
 .../api/shuffle/ShuffleReaderInputStream.java | 47 -------------------
 1 file changed, 47 deletions(-)
 delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java

diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java
deleted file mode 100644
index 0d203ea1ebdcc..0000000000000
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReaderInputStream.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.shuffle;
-
-import org.apache.spark.annotation.Experimental;
-
-import java.io.InputStream;
-
-/**
- * :: Experimental ::
- * An object containing the shuffle block's input stream and information about that block.
- * @since 3.0.0
- */
-@Experimental
-public class ShuffleReaderInputStream {
-
-  private final ShuffleBlockInfo shuffleBlockInfo;
-  private final InputStream inputStream;
-
-  public ShuffleReaderInputStream(ShuffleBlockInfo shuffleBlockInfo, InputStream inputStream) {
-    this.shuffleBlockInfo = shuffleBlockInfo;
-    this.inputStream = inputStream;
-  }
-
-  public ShuffleBlockInfo getShuffleBlockInfo() {
-    return shuffleBlockInfo;
-  }
-
-  public InputStream getInputStream() {
-    return inputStream;
-  }
-}
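
With ShuffleReaderInputStream removed, ShuffleBlockFetcherIterator is once again a plain Iterator[InputStream]: next() returns only the buffer-releasing stream, and block identity travels out of band through the test-only getCurrentBlock() accessor. A sketch of the consumption pattern this leaves callers with; the consume helper and readBlock callback are illustrative, not part of the patch:

    import java.io.InputStream

    // Drain a fetch iterator, closing each stream so that its underlying
    // ManagedBuffer is released even if reading the records fails.
    def consume(streams: Iterator[InputStream])(readBlock: InputStream => Unit): Unit =
      while (streams.hasNext) {
        val in = streams.next()
        try {
          readBlock(in) // e.g. deserialize shuffle records from the stream
        } finally {
          in.close()
        }
      }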