diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 21abe9a57cd25..0167002ceedb8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import org.apache.spark.annotation.Private;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
/**
* :: Private ::
@@ -60,10 +61,15 @@ public interface ShuffleMapOutputWriter {
*
* This can also close any resources and clean up temporary state if necessary.
*
- * The returned array should contain, for each partition from (0) to (numPartitions - 1), the
- * number of bytes written by the partition writer for that partition id.
+ * The returned commit message is a structure with two components:
+ *
+ * 1) An array of longs, which should contain, for each partition from (0) to
+ * (numPartitions - 1), the number of bytes written by the partition writer
+ * for that partition id.
+ *
+ * 2) An optional metadata blob that can be used by shuffle readers.
*/
- long[] commitAllPartitions() throws IOException;
+ MapOutputCommitMessage commitAllPartitions() throws IOException;
/**
* Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java
new file mode 100644
index 0000000000000..7050690aaddf2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api.metadata;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * <p>
+ * Represents the result of writing map outputs for a shuffle map task.
+ * <p>
+ * Partition lengths represent the length of each block written in the map task. This can
+ * be used by downstream readers to allocate resources, such as in-memory buffers.
+ * <p>
+ * Map output writers can choose to attach arbitrary metadata tags to register with a
+ * shuffle output tracker (a module that has not been built yet and is planned for a
+ * future iteration of the shuffle storage APIs).
+ */
+@Private
+public final class MapOutputCommitMessage {
+
+  private final long[] partitionLengths;
+  private final Optional<MapOutputMetadata> mapOutputMetadata;
+
+  private MapOutputCommitMessage(
+      long[] partitionLengths, Optional<MapOutputMetadata> mapOutputMetadata) {
+    this.partitionLengths = partitionLengths;
+    this.mapOutputMetadata = mapOutputMetadata;
+  }
+
+  /**
+   * Creates a commit message that carries partition lengths and no metadata.
+   *
+   * @param partitionLengths the array to report; stored as-is, without copying
+   * @return a message whose {@link #getMapOutputMetadata()} is empty
+   */
+  public static MapOutputCommitMessage of(long[] partitionLengths) {
+    return new MapOutputCommitMessage(partitionLengths, Optional.empty());
+  }
+
+  /**
+   * Creates a commit message that carries partition lengths and a metadata tag.
+   *
+   * @param partitionLengths the array to report; stored as-is, without copying
+   * @param mapOutputMetadata the metadata tag to attach; must not be null, since it is
+   *                          wrapped with {@link Optional#of(Object)}
+   * @return a message whose {@link #getMapOutputMetadata()} holds the given tag
+   */
+  public static MapOutputCommitMessage of(
+      long[] partitionLengths, MapOutputMetadata mapOutputMetadata) {
+    return new MapOutputCommitMessage(partitionLengths, Optional.of(mapOutputMetadata));
+  }
+
+  /**
+   * Returns the partition lengths array supplied at creation. The same array instance is
+   * returned on every call (no defensive copy), so callers should not modify it.
+   */
+  public long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  /** Returns the attached metadata tag, or an empty {@code Optional} if none was given. */
+  public Optional<MapOutputMetadata> getMapOutputMetadata() {
+    return mapOutputMetadata;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java
new file mode 100644
index 0000000000000..6f0e5da4ffeb1
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api.metadata;
+
+import java.io.Serializable;
+
+/**
+ * :: Private ::
+ * <p>
+ * An opaque metadata tag for registering the result of committing the output of a
+ * shuffle map task. This is a marker interface: it declares no methods of its own.
+ * <p>
+ * All implementations must be serializable since this is sent from the executors to
+ * the driver.
+ */
+public interface MapOutputMetadata extends Serializable {}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index dc157eaa3b253..256789b8c7827 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -130,7 +130,7 @@ public void write(Iterator> records) throws IOException {
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
- partitionLengths = mapOutputWriter.commitAllPartitions();
+ partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
return;
@@ -219,7 +219,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
}
partitionWriters = null;
}
- return mapOutputWriter.commitAllPartitions();
+ return mapOutputWriter.commitAllPartitions().getPartitionLengths();
}
private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index d09282e61a9c7..5515a85295d78 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -266,7 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
if (spills.length == 0) {
final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
- return mapWriter.commitAllPartitions();
+ return mapWriter.commitAllPartitions().getPartitionLengths();
} else if (spills.length == 1) {
Optional maybeSingleFileWriter =
shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
@@ -327,7 +327,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- partitionLengths = mapWriter.commitAllPartitions();
+ partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
} catch (Exception e) {
try {
mapWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index a6529fd76188a..eea6c762f5c63 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -35,6 +35,7 @@
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.internal.config.package$;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage;
import org.apache.spark.util.Utils;
/**
@@ -97,7 +98,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
}
@Override
- public long[] commitAllPartitions() throws IOException {
+ public MapOutputCommitMessage commitAllPartitions() throws IOException {
// Check the position after transferTo loop to see if it is in the right position and raise a
// exception if it is incorrect. The position will not be increased to the expected length
// after calling transferTo in kernel version 2.6.32. This issue is described at
@@ -113,7 +114,7 @@ public long[] commitAllPartitions() throws IOException {
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
- return partitionLengths;
+ return MapOutputCommitMessage.of(partitionLengths);
}
@Override
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index a391bdf2db44e..83ebe3e12946c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- val partitionLengths = mapOutputWriter.commitAllPartitions()
+ val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index f92455912f510..d2c7d195e06fe 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -136,7 +136,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
}
private def verifyWrittenRecords(): Unit = {
- val committedLengths = mapOutputWriter.commitAllPartitions()
+ val committedLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
assert(partitionSizesInMergedFile === partitionLengths)
assert(committedLengths === partitionLengths)
assert(mergedOutputFile.length() === partitionLengths.sum)