diff --git a/core/pom.xml b/core/pom.xml
index 14b217d7fb22e..02adae68b8a38 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -229,6 +229,10 @@
       <groupId>org.scala-lang.modules</groupId>
       <artifactId>scala-xml_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-java8-compat_${scala.binary.version}</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
index e4554bda8acab..b8633c08c4932 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.api;
+import java.util.Map;
+
import org.apache.spark.annotation.Private;
/**
@@ -44,12 +46,18 @@ public interface ShuffleDataIO {
/**
* Called once on executor processes to bootstrap the shuffle data storage modules that
* are only invoked on the executors.
+ *
+ * @param appId The Spark application id
+ * @param execId The unique identifier of the executor being initialized
+ * @param extraConfigs Extra configs that were returned by
+ * {@link ShuffleDriverComponents#getAddedExecutorSparkConf()}
*/
- ShuffleExecutorComponents executor();
+ ShuffleExecutorComponents initializeShuffleExecutorComponents(
+ String appId, String execId, Map<String, String> extraConfigs);
/**
* Called once on driver process to bootstrap the shuffle metadata modules that
* are maintained by the driver.
*/
- ShuffleDriverComponents driver();
+ ShuffleDriverComponents initializeShuffleDriverComponents();
}
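
To make the new lifecycle concrete, here is a minimal sketch of a third-party plugin bound to the revised entry points (the My* classes are hypothetical, not part of this change):

import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDataIO;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;

public final class MyShuffleDataIO implements ShuffleDataIO {
  private final SparkConf sparkConf;

  public MyShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleDriverComponents initializeShuffleDriverComponents() {
    return new MyShuffleDriverComponents(sparkConf);
  }

  @Override
  public ShuffleExecutorComponents initializeShuffleExecutorComponents(
      String appId, String execId, Map<String, String> extraConfigs) {
    // extraConfigs carries whatever the driver returned from
    // ShuffleDriverComponents#getAddedExecutorSparkConf(), replacing the
    // separate initializeExecutor() bootstrap call that this change removes.
    return new MyShuffleExecutorComponents(sparkConf, appId, execId, extraConfigs);
  }
}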
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java
index b4cec17b85b32..e8b664de25eab 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java
@@ -20,6 +20,8 @@
import java.util.Map;
import org.apache.spark.annotation.Private;
+import org.apache.spark.shuffle.api.metadata.NoOpShuffleOutputTracker;
+import org.apache.spark.shuffle.api.metadata.ShuffleOutputTracker;
/**
* :: Private ::
@@ -29,36 +31,21 @@
public interface ShuffleDriverComponents {
/**
- * Called once in the driver to bootstrap this module that is specific to this application.
- * This method is called before submitting executor requests to the cluster manager.
- *
- * This method should prepare the module with its shuffle components i.e. registering against
- * an external file servers or shuffle services, or creating tables in a shuffle
- * storage data database.
+ * Provide additional configuration for the executors when their plugin system is initialized
+ * via {@link ShuffleDataIO#initializeShuffleExecutorComponents(String, String, Map)}.
*
* @return additional SparkConf settings necessary for initializing the executor components.
* This would include configurations that cannot be statically set on the application, like
* the host:port of external services for shuffle storage.
*/
- Map<String, String> initializeApplication();
+ Map<String, String> getAddedExecutorSparkConf();
/**
* Called once at the end of the Spark application to clean up any existing shuffle state.
*/
void cleanupApplication();
- /**
- * Called once per shuffle id when the shuffle id is first generated for a shuffle stage.
- *
- * @param shuffleId The unique identifier for the shuffle stage.
- */
- default void registerShuffle(int shuffleId) {}
-
- /**
- * Removes shuffle data associated with the given shuffle.
- *
- * @param shuffleId The unique identifier for the shuffle stage.
- * @param blocking Whether this call should block on the deletion of the data.
- */
- default void removeShuffle(int shuffleId, boolean blocking) {}
+ default ShuffleOutputTracker shuffleOutputTracker() {
+ return new NoOpShuffleOutputTracker();
+ }
}
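
A matching driver-side sketch, assuming a hypothetical external shuffle store whose endpoint is only known at runtime (the class name, config keys, and endpoint are illustrative only):

import java.util.Collections;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;

public final class MyShuffleDriverComponents implements ShuffleDriverComponents {
  private final SparkConf sparkConf;

  public MyShuffleDriverComponents(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public Map<String, String> getAddedExecutorSparkConf() {
    // Delivered to executors via ShuffleDataIO#initializeShuffleExecutorComponents.
    return Collections.singletonMap(
        "storage.endpoint",
        sparkConf.get("spark.myplugin.endpoint", "shuffle-store.internal:7788"));
  }

  @Override
  public void cleanupApplication() {
    // Release any application-scoped state in the external store here.
  }

  // shuffleOutputTracker() is not overridden, so the default
  // NoOpShuffleOutputTracker is used and no shuffle output state is tracked.
}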
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
index 30ca177545789..abb5c3b76e2e7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
@@ -18,7 +18,6 @@
package org.apache.spark.shuffle.api;
import java.io.IOException;
-import java.util.Map;
import java.util.Optional;
import org.apache.spark.annotation.Private;
@@ -32,17 +31,6 @@
@Private
public interface ShuffleExecutorComponents {
- /**
- * Called once per executor to bootstrap this module with state that is specific to
- * that executor, specifically the application ID and executor ID.
- *
- * @param appId The Spark application id
- * @param execId The unique identifier of the executor being initialized
- * @param extraConfigs Extra configs that were returned by
- * {@link ShuffleDriverComponents#initializeApplication()}
- */
- void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs);
-
/**
* Called once per map task to create a writer that will be responsible for persisting all the
* partitioned bytes written by that map task.
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
index cad8dcfda52bc..d4f5ef9ef730a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import org.apache.spark.annotation.Private;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
/**
* Optional extension for partition writing that is optimized for transferring a single
@@ -32,5 +33,6 @@ public interface SingleSpillShuffleMapOutputWriter {
/**
* Transfer a file that contains the bytes of all the partitions written by this map task.
*/
- void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+ MapOutputCommitMessage transferMapSpillFile(File mapOutputFile, long[] partitionLengths)
+ throws IOException;
}
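
An implementation with no custom metadata to report can satisfy the new return type by wrapping just the partition lengths, as the local-disk writer further down does; a sketch (PassthroughSpillWriter is a hypothetical name):

import java.io.File;
import java.io.IOException;

import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;

final class PassthroughSpillWriter implements SingleSpillShuffleMapOutputWriter {
  @Override
  public MapOutputCommitMessage transferMapSpillFile(
      File mapSpillFile, long[] partitionLengths) throws IOException {
    // A real writer would move or upload mapSpillFile here before committing.
    return MapOutputCommitMessage.of(partitionLengths);
  }
}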
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/NoOpShuffleOutputTracker.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/NoOpShuffleOutputTracker.java
new file mode 100644
index 0000000000000..d922b421a7f53
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/NoOpShuffleOutputTracker.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api.metadata;
+
+/**
+ * An implementation of shuffle output tracking that does not keep track of any shuffle state.
+ */
+public class NoOpShuffleOutputTracker implements ShuffleOutputTracker {
+
+ @Override
+ public void registerShuffle(int shuffleId) {}
+
+ @Override
+ public void unregisterShuffle(int shuffleId, boolean blocking) {}
+
+ @Override
+ public void registerMapOutput(
+ int shuffleId, int mapIndex, long mapId, MapOutputMetadata mapOutputMetadata) {}
+
+ @Override
+ public void removeMapOutput(int shuffleId, int mapIndex, long mapId) {}
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/ShuffleOutputTracker.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/ShuffleOutputTracker.java
new file mode 100644
index 0000000000000..55b8457398ff1
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/ShuffleOutputTracker.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api.metadata;
+
+/**
+ * :: Private ::
+ *
+ * A plugin that can monitor the storage of shuffle data from map tasks, and can provide
+ * metadata to shuffle readers to aid their reading of shuffle blocks in reduce tasks.
+ *
+ * {@link MapOutputMetadata} instances provided from the plugin tree's implementation of
+ * {@link org.apache.spark.shuffle.api.ShuffleMapOutputWriter} are sent to this module's map output
+ * metadata registration method in {@link #registerMapOutput(int, int, long, MapOutputMetadata)}.
+ *
+ * Implementations MUST be thread-safe. Spark will invoke methods in this module in parallel.
+ *
+ * A singleton instance of this module is instantiated on the driver via
+ * {@link ShuffleDriverComponents#shuffleOutputTracker()}.
+ */
+public interface ShuffleOutputTracker {
+
+ /**
+ * Called when a new shuffle stage is going to be run.
+ *
+ * @param shuffleId the unique identifier for the new shuffle stage
+ */
+ void registerShuffle(int shuffleId);
+
+ /**
+ * Called when the shuffle with the given id is unregistered because it will no longer
+ * be used by Spark tasks.
+ *
+ * @param shuffleId the unique identifier for the shuffle stage to be unregistered
+ * @param blocking whether this call should block until the shuffle state is removed
+ */
+ void unregisterShuffle(int shuffleId, boolean blocking);
+
+ /**
+ * Called when a map task completes, and the map output writer has provided metadata to be
+ * persisted by this shuffle output tracker.
+ *
+ * @param shuffleId the unique identifier for the shuffle stage that the map task is a
+ * part of
+ * @param mapIndex the map index of the map task in its shuffle map stage - not
+ * necessarily unique across multiple attempts of this task
+ * @param mapId the identifier for this map task, which is unique even across
+ * multiple attempts at this task
+ * @param mapOutputMetadata metadata about the map output data's storage returned by the map
+ * task's writer
+ *
+ */
+ void registerMapOutput(
+ int shuffleId, int mapIndex, long mapId, MapOutputMetadata mapOutputMetadata);
+
+ /**
+ * Called when the given map output is discarded, and will no longer be used in future Spark
+ * shuffles.
+ *
+ * @param shuffleId the unique identifier for the shuffle stage that the map task is a
+ * part of
+ * @param mapIndex the map index of the map task which is having its output being
+ * discarded - not necessarily unique across multiple attempts of this
+ * task
+ * @param mapId the identifier for the map task which is having its output being
+ * discarded, which is unique even across multiple attempts at this task
+ */
+ void removeMapOutput(int shuffleId, int mapIndex, long mapId);
+}
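
For a concrete feel of the callback lifecycle, a hypothetical tracker that mirrors registrations into a concurrent map — thread-safe, as the contract above requires:

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.spark.shuffle.api.metadata.MapOutputMetadata;
import org.apache.spark.shuffle.api.metadata.ShuffleOutputTracker;

public final class InMemoryShuffleOutputTracker implements ShuffleOutputTracker {
  private final Map<Integer, Set<Long>> liveOutputs = new ConcurrentHashMap<>();

  @Override
  public void registerShuffle(int shuffleId) {
    liveOutputs.putIfAbsent(shuffleId, ConcurrentHashMap.newKeySet());
  }

  @Override
  public void unregisterShuffle(int shuffleId, boolean blocking) {
    liveOutputs.remove(shuffleId);
  }

  @Override
  public void registerMapOutput(
      int shuffleId, int mapIndex, long mapId, MapOutputMetadata mapOutputMetadata) {
    Set<Long> mapIds = liveOutputs.get(shuffleId);
    if (mapIds != null) {
      mapIds.add(mapId);  // mapId is unique even across attempts of the same task
    }
  }

  @Override
  public void removeMapOutput(int shuffleId, int mapIndex, long mapId) {
    Set<Long> mapIds = liveOutputs.get(shuffleId);
    if (mapIds != null) {
      mapIds.remove(mapId);
    }
  }
}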
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 256789b8c7827..eb91c577bc19c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -30,6 +30,7 @@
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
+import scala.compat.java8.OptionConverters;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
@@ -39,17 +40,18 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.scheduler.MapTaskResult;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.internal.config.package$;
-import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -92,8 +94,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
private FileSegment[] partitionWriterSegments;
- @Nullable private MapStatus mapStatus;
- private long[] partitionLengths;
+ @Nullable private MapTaskResult taskResult;
+ private MapOutputCommitMessage mapOutputCommitMessage;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -130,9 +132,13 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
- partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapOutputCommitMessage = mapOutputWriter.commitAllPartitions();
+ taskResult = new MapTaskResult(
+ MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapOutputCommitMessage.getPartitionLengths(),
+ mapId),
+ OptionConverters.toScala(mapOutputCommitMessage.getMapOutputMetadata()));
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -164,9 +170,13 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
}
- partitionLengths = writePartitionedData(mapOutputWriter);
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapOutputCommitMessage = writePartitionedData(mapOutputWriter);
+ taskResult = new MapTaskResult(
+ MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapOutputCommitMessage.getPartitionLengths(),
+ mapId),
+ OptionConverters.toScala(mapOutputCommitMessage.getMapOutputMetadata()));
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -179,8 +189,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
@VisibleForTesting
- long[] getPartitionLengths() {
- return partitionLengths;
+ MapOutputCommitMessage getMapOutputCommitMessage() {
+ return mapOutputCommitMessage;
}
/**
@@ -188,7 +198,8 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
+ private MapOutputCommitMessage writePartitionedData(ShuffleMapOutputWriter mapOutputWriter)
+ throws IOException {
// Track location of the partition starts in the output file
if (partitionWriters != null) {
final long writeStartTime = System.nanoTime();
@@ -219,7 +230,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
}
partitionWriters = null;
}
- return mapOutputWriter.commitAllPartitions().getPartitionLengths();
+ return mapOutputWriter.commitAllPartitions();
}
private void writePartitionedDataWithChannel(
@@ -259,16 +270,16 @@ private void writePartitionedDataWithStream(File file, ShufflePartitionWriter wr
}
@Override
- public Option<MapStatus> stop(boolean success) {
+ public Option<MapTaskResult> stop(boolean success) {
if (stopping) {
return None$.empty();
} else {
stopping = true;
if (success) {
- if (mapStatus == null) {
+ if (taskResult == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
- return Option.apply(mapStatus);
+ return Option.apply(taskResult);
} else {
// The map task failed, so delete our output data.
if (partitionWriters != null) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 79e38a824fea4..d3f0feee377a2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -29,6 +29,7 @@
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
+import scala.compat.java8.OptionConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
@@ -46,8 +47,8 @@
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.scheduler.MapTaskResult;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
@@ -57,6 +58,7 @@
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
@@ -86,7 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int initialSortBufferSize;
private final int inputBufferSizeInBytes;
- @Nullable private MapStatus mapStatus;
+ @Nullable private MapTaskResult taskResult;
@Nullable private ShuffleExternalSorter sorter;
private long peakMemoryUsedBytes = 0;
@@ -219,9 +221,9 @@ void closeAndWriteOutput() throws IOException {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
- final long[] partitionLengths;
+ final MapOutputCommitMessage mapOutputCommitMessage;
try {
- partitionLengths = mergeSpills(spills);
+ mapOutputCommitMessage = mergeSpills(spills);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
@@ -229,8 +231,12 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ taskResult = new MapTaskResult(
+ MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapOutputCommitMessage.getPartitionLengths(),
+ mapId),
+ OptionConverters.toScala(mapOutputCommitMessage.getMapOutputMetadata()));
}
@VisibleForTesting
@@ -262,33 +268,35 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills) throws IOException {
- long[] partitionLengths;
+ private MapOutputCommitMessage mergeSpills(SpillInfo[] spills) throws IOException {
+ MapOutputCommitMessage mapOutputCommitMessage;
if (spills.length == 0) {
final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
- return mapWriter.commitAllPartitions().getPartitionLengths();
+ mapOutputCommitMessage = mapWriter.commitAllPartitions();
} else if (spills.length == 1) {
Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
if (maybeSingleFileWriter.isPresent()) {
// Here, we don't need to perform any metrics updates because the bytes written to this
// output file would have already been counted as shuffle bytes written.
- partitionLengths = spills[0].partitionLengths;
- logger.debug("Merge shuffle spills for mapId {} with length {}", mapId,
+ long[] partitionLengths = spills[0].partitionLengths;
+ logger.debug("Transfer shuffle spills for mapId {} with length {}", mapId,
partitionLengths.length);
- maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+ mapOutputCommitMessage = maybeSingleFileWriter.get().transferMapSpillFile(
+ spills[0].file, partitionLengths);
} else {
- partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ mapOutputCommitMessage = mergeSpillsUsingStandardWriter(spills);
}
} else {
- partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ mapOutputCommitMessage = mergeSpillsUsingStandardWriter(spills);
}
- return partitionLengths;
+ return mapOutputCommitMessage;
}
- private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
- long[] partitionLengths;
+ private MapOutputCommitMessage mergeSpillsUsingStandardWriter(SpillInfo[] spills)
+ throws IOException {
+ MapOutputCommitMessage commitMessage;
final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
@@ -330,7 +338,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
+ commitMessage = mapWriter.commitAllPartitions();
} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -340,7 +348,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
}
throw e;
}
- return partitionLengths;
+ return commitMessage;
}
/**
@@ -492,7 +500,7 @@ private void mergeSpillsWithTransferTo(
}
@Override
- public Option<MapStatus> stop(boolean success) {
+ public Option<MapTaskResult> stop(boolean success) {
try {
taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
@@ -501,10 +509,10 @@ public Option<MapStatus> stop(boolean success) {
} else {
stopping = true;
if (success) {
- if (mapStatus == null) {
+ if (taskResult == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
- return Option.apply(mapStatus);
+ return Option.apply(taskResult);
} else {
return Option.apply(null);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
index 50eb2f1813714..b8e010a06b57d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort.io;
+import java.util.Map;
+
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDataIO;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
@@ -35,12 +37,13 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) {
}
@Override
- public ShuffleExecutorComponents executor() {
- return new LocalDiskShuffleExecutorComponents(sparkConf);
+ public ShuffleDriverComponents initializeShuffleDriverComponents() {
+ return new LocalDiskShuffleDriverComponents(sparkConf);
}
@Override
- public ShuffleDriverComponents driver() {
- return new LocalDiskShuffleDriverComponents();
+ public ShuffleExecutorComponents initializeShuffleExecutorComponents(
+ String appId, String execId, Map<String, String> extraConfigs) {
+ return new LocalDiskShuffleExecutorComponents(sparkConf);
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDriverComponents.java
index 92b4b318c552d..a18fcb283cecc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDriverComponents.java
@@ -20,17 +20,20 @@
import java.util.Collections;
import java.util.Map;
-import org.apache.spark.SparkEnv;
+import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
-import org.apache.spark.storage.BlockManagerMaster;
+import org.apache.spark.shuffle.api.metadata.ShuffleOutputTracker;
public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents {
- private BlockManagerMaster blockManagerMaster;
+ private final LocalDiskShuffleOutputTracker outputTracker;
+
+ public LocalDiskShuffleDriverComponents(SparkConf sparkConf) {
+ this.outputTracker = new LocalDiskShuffleOutputTracker(sparkConf);
+ }
@Override
- public Map<String, String> initializeApplication() {
- blockManagerMaster = SparkEnv.get().blockManager().master();
+ public Map<String, String> getAddedExecutorSparkConf() {
return Collections.emptyMap();
}
@@ -40,10 +43,7 @@ public void cleanupApplication() {
}
@Override
- public void removeShuffle(int shuffleId, boolean blocking) {
- if (blockManagerMaster == null) {
- throw new IllegalStateException("Driver components must be initialized before using");
- }
- blockManagerMaster.removeShuffle(shuffleId, blocking);
+ public ShuffleOutputTracker shuffleOutputTracker() {
+ return outputTracker;
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
index eb4d9d9abc8e3..934b67032417c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -17,46 +17,48 @@
package org.apache.spark.shuffle.sort.io;
-import java.util.Map;
import java.util.Optional;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
-import org.apache.spark.storage.BlockManager;
public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents {
private final SparkConf sparkConf;
- private BlockManager blockManager;
- private IndexShuffleBlockResolver blockResolver;
+ private final Supplier<IndexShuffleBlockResolver> blockResolver;
public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) {
- this.sparkConf = sparkConf;
+ this(
+ sparkConf,
+ Suppliers.memoize(() -> {
+ if (SparkEnv.get() == null) {
+ throw new IllegalStateException("SparkEnv must be initialized before using the" +
+ " local disk executor components/");
+ }
+ return new IndexShuffleBlockResolver(sparkConf, SparkEnv.get().blockManager());
+ }));
}
@VisibleForTesting
public LocalDiskShuffleExecutorComponents(
SparkConf sparkConf,
- BlockManager blockManager,
IndexShuffleBlockResolver blockResolver) {
- this.sparkConf = sparkConf;
- this.blockManager = blockManager;
- this.blockResolver = blockResolver;
+ this(sparkConf, () -> blockResolver);
}
- @Override
- public void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs) {
- blockManager = SparkEnv.get().blockManager();
- if (blockManager == null) {
- throw new IllegalStateException("No blockManager available from the SparkEnv.");
- }
- blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+ private LocalDiskShuffleExecutorComponents(
+ SparkConf sparkConf,
+ Supplier<IndexShuffleBlockResolver> blockResolver) {
+ this.sparkConf = sparkConf;
+ this.blockResolver = blockResolver;
}
@Override
@@ -64,22 +66,15 @@ public ShuffleMapOutputWriter createMapOutputWriter(
int shuffleId,
long mapTaskId,
int numPartitions) {
- if (blockResolver == null) {
- throw new IllegalStateException(
- "Executor components must be initialized before getting writers.");
- }
return new LocalDiskShuffleMapOutputWriter(
- shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf);
+ shuffleId, mapTaskId, numPartitions, blockResolver.get(), sparkConf);
}
@Override
public Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
int shuffleId,
long mapId) {
- if (blockResolver == null) {
- throw new IllegalStateException(
- "Executor components must be initialized before getting writers.");
- }
- return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver));
+ return Optional.of(
+ new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver.get()));
}
}
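
The Guava memoized supplier is what replaces the removed initializeExecutor() lifecycle call: the block resolver is now created lazily, exactly once, on first use. The pattern in isolation (hypothetical names; requires Guava on the classpath):

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

final class LazilyInitialized {
  // The lambda runs at most once; later get() calls return the cached instance.
  private final Supplier<String> resource = Suppliers.memoize(() -> {
    System.out.println("initializing");
    return "ready";
  });

  String use() {
    return resource.get();
  }

  public static void main(String[] args) {
    LazilyInitialized l = new LazilyInitialized();
    l.use();  // prints "initializing" once
    l.use();  // cached; prints nothing
  }
}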
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleOutputTracker.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleOutputTracker.java
new file mode 100644
index 0000000000000..97bd05f802b67
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleOutputTracker.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.shuffle.api.metadata.MapOutputMetadata;
+import org.apache.spark.shuffle.api.metadata.ShuffleOutputTracker;
+import org.apache.spark.storage.BlockManagerMaster;
+
+public final class LocalDiskShuffleOutputTracker implements ShuffleOutputTracker {
+
+ private final Supplier<BlockManagerMaster> blockManagerMaster = Suppliers.memoize(
+ () -> {
+ SparkEnv env = SparkEnv.get();
+ if (env == null) {
+ throw new IllegalStateException("SparkEnv should not be null here.");
+ }
+ return env.blockManager().master();
+ });
+ private final SparkConf sparkConf;
+
+ public LocalDiskShuffleOutputTracker(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public void registerShuffle(int shuffleId) {}
+
+ @Override
+ public void unregisterShuffle(int shuffleId, boolean blocking) {
+ blockManagerMaster.get().removeShuffle(shuffleId, blocking);
+ }
+
+ @Override
+ public void registerMapOutput(
+ int shuffleId, int mapIndex, long mapId, MapOutputMetadata mapOutputMetadata) {}
+
+ @Override
+ public void removeMapOutput(int shuffleId, int mapIndex, long mapId) {}
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
index c8b41992a8919..d964a8c6bef24 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -23,6 +23,7 @@
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
import org.apache.spark.util.Utils;
public class LocalDiskSingleSpillMapOutputWriter
@@ -42,14 +43,14 @@ public LocalDiskSingleSpillMapOutputWriter(
}
@Override
- public void transferMapSpillFile(
- File mapSpillFile,
- long[] partitionLengths) throws IOException {
+ public MapOutputCommitMessage transferMapSpillFile(
+ File mapSpillFile, long[] partitionLengths) throws IOException {
// The map spill file already has the proper format, and it contains all of the partition data.
// So just transfer it directly to the destination without any merging.
File outputFile = blockResolver.getDataFile(shuffleId, mapId);
File tempFile = Utils.tempFileWith(outputFile);
Files.move(mapSpillFile.toPath(), tempFile.toPath());
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+ return MapOutputCommitMessage.of(partitionLengths);
}
}
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index cfa1139140025..bad4781d3ecea 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -59,9 +59,7 @@ private class CleanupTaskWeakReference(
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
-private[spark] class ContextCleaner(
- sc: SparkContext,
- shuffleDriverComponents: ShuffleDriverComponents) extends Logging {
+private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/**
* A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
@@ -222,14 +220,11 @@ private[spark] class ContextCleaner(
/** Perform shuffle cleanup. */
def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = {
try {
- if (mapOutputTrackerMaster.containsShuffle(shuffleId)) {
- logDebug("Cleaning shuffle " + shuffleId)
- mapOutputTrackerMaster.unregisterShuffle(shuffleId)
- shuffleDriverComponents.removeShuffle(shuffleId, blocking)
+ logDebug("Cleaning shuffle " + shuffleId)
+ val shuffleRemoved = mapOutputTrackerMaster.unregisterShuffle(shuffleId, blocking)
+ if (shuffleRemoved) {
listeners.asScala.foreach(_.shuffleCleaned(shuffleId))
logDebug("Cleaned shuffle " + shuffleId)
- } else {
- logDebug("Asked to cleanup non-existent shuffle (maybe it was already removed)")
}
} catch {
case e: Exception => logError("Error cleaning shuffle " + shuffleId, e)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ba8e4d69ba755..f0ac9acd90156 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -96,7 +96,6 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
shuffleId, this)
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
- _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index c3152d9225107..a1212205575e6 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,8 +35,9 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
-import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus}
+import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus, MapTaskResult}
import org.apache.spark.shuffle.MetadataFetchFailedException
+import org.apache.spark.shuffle.api.metadata.{MapOutputMetadata, ShuffleOutputTracker}
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
@@ -49,7 +50,10 @@ import org.apache.spark.util._
*
* All public methods of this class are thread-safe.
*/
-private class ShuffleStatus(numPartitions: Int) extends Logging {
+private class ShuffleStatus(
+ shuffleId: Int,
+ numPartitions: Int,
+ val shuffleOutputTracker: ShuffleOutputTracker) extends Logging {
private val (readLock, writeLock) = {
val lock = new ReentrantReadWriteLock()
@@ -75,6 +79,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
}
}
+ withWriteLock {
+ this.shuffleOutputTracker.registerShuffle(shuffleId)
+ }
+
/**
* MapStatus for each partition. The index of the array is the map partition id.
* Each value in the array is the MapStatus for a partition, or null if the partition
@@ -113,12 +121,18 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
* Register a map output. If there is already a registered location for the map output then it
* will be replaced by the new location.
*/
- def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
- if (mapStatuses(mapIndex) == null) {
- _numAvailableOutputs += 1
- invalidateSerializedMapOutputStatusCache()
+ def addMapOutput(
+ mapIndex: Int, status: MapStatus, maybeMetadata: Option[MapOutputMetadata]): Unit = {
+ withWriteLock {
+ if (mapStatuses(mapIndex) == null) {
+ _numAvailableOutputs += 1
+ invalidateSerializedMapOutputStatusCache()
+ }
+ mapStatuses(mapIndex) = status
+ maybeMetadata.foreach { metadata =>
+ this.shuffleOutputTracker.registerMapOutput(shuffleId, mapIndex, status.mapId, metadata)
+ }
}
- mapStatuses(mapIndex) = status
}
/**
@@ -149,8 +163,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
+ val status = mapStatuses(mapIndex)
_numAvailableOutputs -= 1
mapStatuses(mapIndex) = null
+ this.shuffleOutputTracker.removeMapOutput(shuffleId, mapIndex, status.mapId)
invalidateSerializedMapOutputStatusCache()
}
}
@@ -369,8 +385,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* Deletes map output status information for the specified shuffle stage.
+ *
+ * @return true if a shuffle status was present and was removed
*/
- def unregisterShuffle(shuffleId: Int): Unit
+ def unregisterShuffle(shuffleId: Int, blocking: Boolean): Boolean
def stop(): Unit = {}
}
@@ -386,6 +404,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
private[spark] class MapOutputTrackerMaster(
conf: SparkConf,
+ shuffleOutputTracker: ShuffleOutputTracker,
private[spark] val broadcastManager: BroadcastManager,
private[spark] val isLocal: Boolean)
extends MapOutputTracker(conf) {
@@ -483,11 +502,18 @@ private[spark] class MapOutputTrackerMaster(
}
def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
- if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
+ if (shuffleStatuses.put(
+ shuffleId,
+ new ShuffleStatus(shuffleId, numMaps, shuffleOutputTracker)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
+ def registerMapOutput(shuffleId: Int, mapIndex: Int, mapTaskResult: MapTaskResult): Unit = {
+ shuffleStatuses(shuffleId).addMapOutput(
+ mapIndex, mapTaskResult.mapStatus, mapTaskResult.metadata)
+ }
+
def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
@@ -497,10 +523,6 @@ private[spark] class MapOutputTrackerMaster(
}
}
- def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
- shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
- }
-
/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = {
shuffleStatuses.get(shuffleId) match {
@@ -525,10 +547,13 @@ private[spark] class MapOutputTrackerMaster(
}
/** Unregister shuffle data */
- def unregisterShuffle(shuffleId: Int): Unit = {
- shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
+ def unregisterShuffle(shuffleId: Int, blocking: Boolean): Boolean = {
+ val removedStatus = shuffleStatuses.remove(shuffleId)
+ removedStatus.foreach { shuffleStatus =>
shuffleStatus.invalidateSerializedMapOutputStatusCache()
+ shuffleOutputTracker.unregisterShuffle(shuffleId, blocking)
}
+ removedStatus.isDefined
}
/**
@@ -858,8 +883,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
/** Unregister shuffle data. */
- def unregisterShuffle(shuffleId: Int): Unit = {
- mapStatuses.remove(shuffleId)
+ def unregisterShuffle(shuffleId: Int, blocking: Boolean): Boolean = {
+ mapStatuses.remove(shuffleId).isDefined
}
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 501e865c4105a..32325d576634c 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
import scala.collection.Map
import scala.collection.immutable
import scala.collection.mutable.HashMap
+import scala.compat.java8.OptionConverters._
import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
@@ -225,7 +226,6 @@ class SparkContext(config: SparkConf) extends Logging {
private var _statusStore: AppStatusStore = _
private var _heartbeater: Heartbeater = _
private var _resources: immutable.Map[String, ResourceInformation] = _
- private var _shuffleDriverComponents: ShuffleDriverComponents = _
private var _plugins: Option[PluginContainer] = None
private var _resourceProfileManager: ResourceProfileManager = _
@@ -321,8 +321,6 @@ class SparkContext(config: SparkConf) extends Logging {
_dagScheduler = ds
}
- private[spark] def shuffleDriverComponents: ShuffleDriverComponents = _shuffleDriverComponents
-
/**
* A unique identifier for the Spark application.
* Its format depends on the scheduler implementation.
@@ -528,9 +526,12 @@ class SparkContext(config: SparkConf) extends Logging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser
- _shuffleDriverComponents = ShuffleDataIOUtils.loadShuffleDataIO(config).driver()
- _shuffleDriverComponents.initializeApplication().asScala.foreach { case (k, v) =>
- _conf.set(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX + k, v)
+ _env.shuffleDataIO
+ .driver()
+ .getAddedExecutorSparkConf()
+ .asScala
+ .foreach { case (k, v) =>
+ _conf.set(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX + k, v)
}
// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
@@ -595,7 +596,7 @@ class SparkContext(config: SparkConf) extends Logging {
_cleaner =
if (_conf.get(CLEANER_REFERENCE_TRACKING)) {
- Some(new ContextCleaner(this, _shuffleDriverComponents))
+ Some(new ContextCleaner(this))
} else {
None
}
@@ -2020,10 +2021,8 @@ class SparkContext(config: SparkConf) extends Logging {
}
_heartbeater = null
}
- if (_shuffleDriverComponents != null) {
- Utils.tryLogNonFatalError {
- _shuffleDriverComponents.cleanupApplication()
- }
+ Utils.tryLogNonFatalError {
+ env.shuffleDataIO.driver().cleanupApplication()
}
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index d543359f4dedf..30fb1db6fbb68 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -43,7 +43,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{MemoizingShuffleDataIO, ShuffleDataIOUtils, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
@@ -69,7 +69,8 @@ class SparkEnv (
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
- val conf: SparkConf) extends Logging {
+ val conf: SparkConf,
+ val shuffleDataIO: MemoizingShuffleDataIO) extends Logging {
@volatile private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -315,9 +316,14 @@ object SparkEnv extends Logging {
}
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
+ val shuffleDataIo = new MemoizingShuffleDataIO(ShuffleDataIOUtils.loadShuffleDataIO(conf))
val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
+ new MapOutputTrackerMaster(
+ conf,
+ shuffleDataIo.driver().shuffleOutputTracker(),
+ broadcastManager,
+ isLocal)
} else {
new MapOutputTrackerWorker(conf)
}
@@ -415,7 +421,6 @@ object SparkEnv extends Logging {
val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
-
val envInstance = new SparkEnv(
executorId,
rpcEnv,
@@ -430,7 +435,8 @@ object SparkEnv extends Logging {
metricsSystem,
memoryManager,
outputCommitCoordinator,
- conf)
+ conf,
+ shuffleDataIo)
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
// called, and we only need to do it for driver. Because driver may run as a service, and if we
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 080e0e7f1552f..ff26179005e51 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1618,8 +1618,8 @@ private[spark] class DAGScheduler(
case smt: ShuffleMapTask =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
shuffleStage.pendingPartitions -= task.partitionId
- val status = event.result.asInstanceOf[MapStatus]
- val execId = status.location.executorId
+ val mapTaskResult = event.result.asInstanceOf[MapTaskResult]
+ val execId = mapTaskResult.mapStatus.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (executorFailureEpoch.contains(execId) &&
smt.epoch <= executorFailureEpoch(execId)) {
@@ -1629,7 +1629,7 @@ private[spark] class DAGScheduler(
// recent failure we're aware of for the executor), so mark the task's output as
// available.
mapOutputTracker.registerMapOutput(
- shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+ shuffleStage.shuffleDep.shuffleId, smt.partitionId, mapTaskResult)
}
if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapTaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/MapTaskResult.scala
new file mode 100644
index 0000000000000..d52dc8d69dfb9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapTaskResult.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.shuffle.api.metadata.MapOutputMetadata
+
+private[spark] case class MapTaskResult(
+ mapStatus: MapStatus, metadata: Option[MapOutputMetadata])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index a0ba9208ea647..c1d03cb400e6b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -61,7 +61,7 @@ private[spark] class ShuffleMapTask(
appId: Option[String] = None,
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
+ extends Task[MapTaskResult](stageId, stageAttemptId, partition.index, localProperties,
serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
with Logging {
@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.distinct
}
- override def runTask(context: TaskContext): MapStatus = {
+ override def runTask(context: TaskContext): MapTaskResult = {
// Deserialize the RDD using the broadcast variable.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTimeNs = System.nanoTime()
diff --git a/core/src/main/scala/org/apache/spark/shuffle/MemoizingShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/MemoizingShuffleDataIO.scala
new file mode 100644
index 0000000000000..60cb390e6274a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/MemoizingShuffleDataIO.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents}
+
+/**
+ * Thin wrapper around {@link ShuffleDataIO} that ensures the given components are
+ * initialized only once, returning the same instance on every call.
+ *
+ * Used to ensure the SparkEnv only instantiates the given components once lazily
+ * and then reuses them throughout the lifetime of the SparkEnv.
+ */
+class MemoizingShuffleDataIO(delegate: ShuffleDataIO) {
+ private lazy val _driver = delegate.initializeShuffleDriverComponents()
+ private lazy val _executor = {
+ val env = SparkEnv.get
+ delegate.initializeShuffleExecutorComponents(
+ env.conf.getAppId,
+ env.executorId,
+ env.conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap.asJava)
+ }
+
+ def driver(): ShuffleDriverComponents = _driver
+
+ def executor(): ShuffleExecutorComponents = _executor
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
index 1429144c6f6e2..84b889968ad5e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle
import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.MapTaskResult
/**
* The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
@@ -46,7 +46,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
dep: ShuffleDependency[_, _, _],
mapId: Long,
context: TaskContext,
- partition: Partition): MapStatus = {
+ partition: Partition): MapTaskResult = {
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index 4cc4ef5f1886e..ea1777be9f57c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle
import java.io.IOException
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.MapTaskResult
/**
* Obtained inside a map task to write out records to the shuffle system.
@@ -30,5 +30,5 @@ private[spark] abstract class ShuffleWriter[K, V] {
def write(records: Iterator[Product2[K, V]]): Unit
/** Close this writer, passing along whether the map completed */
- def stop(success: Boolean): Option[MapStatus]
+ def stop(success: Boolean): Option[MapTaskResult]
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 72460180f5908..95674c4ed09c8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -87,7 +87,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
*/
private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]()
- private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+ private lazy val shuffleExecutorComponents = SparkEnv.get.shuffleDataIO.executor()
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
@@ -240,17 +240,6 @@ private[spark] object SortShuffleManager extends Logging {
true
}
}
-
- private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
- val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
- val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
- .toMap
- executorComponents.initializeExecutor(
- conf.getAppId,
- SparkEnv.get.executorId,
- extraConfigs.asJava)
- executorComponents
- }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 83ebe3e12946c..19584352d708d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,9 +17,11 @@
package org.apache.spark.shuffle.sort
+import scala.compat.java8.OptionConverters._
+
import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{MapStatus, MapTaskResult}
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.util.collection.ExternalSorter
@@ -43,7 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C](
// we don't try deleting files, etc twice.
private var stopping = false
- private var mapStatus: MapStatus = null
+ private var taskResult: MapTaskResult = null
private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
@@ -67,19 +69,24 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+ val mapOutputCommitMessage = mapOutputWriter.commitAllPartitions()
+ taskResult = MapTaskResult(
+ MapStatus(
+ blockManager.shuffleServerId,
+ mapOutputCommitMessage.getPartitionLengths,
+ mapId),
+ mapOutputCommitMessage.getMapOutputMetadata.asScala)
}
/** Close this writer, passing along whether the map completed */
- override def stop(success: Boolean): Option[MapStatus] = {
+ override def stop(success: Boolean): Option[MapTaskResult] = {
try {
if (stopping) {
return None
}
stopping = true
if (success) {
- return Option(mapStatus)
+ return Option(taskResult)
} else {
return None
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
index a69bebc23c661..55734defe553e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
@@ -56,7 +56,7 @@ class BlockManagerStorageEndpoint(
case RemoveShuffle(shuffleId) =>
doAsync[Boolean]("removing shuffle " + shuffleId, context) {
if (mapOutputTracker != null) {
- mapOutputTracker.unregisterShuffle(shuffleId)
+ mapOutputTracker.unregisterShuffle(shuffleId, true)
}
SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index ee8e38c24b47f..2863a95921c8f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -50,7 +50,7 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapTaskResult;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -181,7 +181,7 @@ private UnsafeShuffleWriter
diff --git a/pom.xml b/pom.xml
--- a/pom.xml
+++ b/pom.xml
+      <dependency>
+        <groupId>org.scala-lang.modules</groupId>
+        <artifactId>scala-java8-compat_${scala.binary.version}</artifactId>
+        <version>0.9.1</version>
+      </dependency>
       <dependency>
         <groupId>jline</groupId>
         <artifactId>jline</artifactId>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 98769d951b6ac..5e7fe38bec8e3 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,6 +38,25 @@ object MimaExcludes {
lazy val v31excludes = v30excludes ++ Seq(
// mima plugin update caused new incompatibilities to be detected
// core module
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO.executor"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO.driver"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleDriverComponents.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleDriverComponents.initializeApplication"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleDriverComponents.removeShuffle"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents.initializeExecutor"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.sort.io.LocalDiskSingleSpillMapOutputWriter.transferMapSpillFile"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDriverComponents.initializeApplication"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDriverComponents.registerShuffle"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDriverComponents.removeShuffle"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDriverComponents.getAddedExecutorSparkConf"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleExecutorComponents.initializeExecutor"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDataIO.executor"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDataIO.driver"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDataIO.initializeShuffleExecutorComponents"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleDataIO.initializeShuffleDriverComponents"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 86c20f5a46b9a..2792dd5670824 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -72,7 +72,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
val streamId = 1
val securityMgr = new SecurityManager(conf, encryptionKey)
val broadcastManager = new BroadcastManager(true, conf, securityMgr)
- val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf, null, broadcastManager, true)
val shuffleManager = new SortShuffleManager(conf)
val serializer = new KryoSerializer(conf)
var serializerManager = new SerializerManager(serializer, conf, encryptionKey)