diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 21abe9a57cd25..0167002ceedb8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.annotation.Private; +import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage; /** * :: Private :: @@ -60,10 +61,15 @@ public interface ShuffleMapOutputWriter { *

* This can also close any resources and clean up temporary state if necessary. *

- * The returned array should contain, for each partition from (0) to (numPartitions - 1), the - * number of bytes written by the partition writer for that partition id. + * The returned commit message is a structure with two components: + *

+ * 1) An array of longs, which should contain, for each partition from (0) to + * (numPartitions - 1), the number of bytes written by the partition writer + * for that partition id. + *

+ * 2) An optional metadata blob that can be used by shuffle readers. */ - long[] commitAllPartitions() throws IOException; + MapOutputCommitMessage commitAllPartitions() throws IOException; /** * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java new file mode 100644 index 0000000000000..7050690aaddf2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api.metadata; + +import java.util.Optional; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * + * Represents the result of writing map outputs for a shuffle map task. + *

+ * Partition lengths represent the length of each block written in the map task. This can + * be used by downstream readers to allocate resources, such as in-memory buffers. + *

+ * Map output writers can choose to attach arbitrary metadata tags to register with a + * shuffle output tracker (a module that has not been built yet and is planned for a future + * iteration of the shuffle storage APIs). + */ +@Private +public final class MapOutputCommitMessage { + + private final long[] partitionLengths; + private final Optional<MapOutputMetadata> mapOutputMetadata; + + private MapOutputCommitMessage( + long[] partitionLengths, Optional<MapOutputMetadata> mapOutputMetadata) { + this.partitionLengths = partitionLengths; + this.mapOutputMetadata = mapOutputMetadata; + } + + public static MapOutputCommitMessage of(long[] partitionLengths) { + return new MapOutputCommitMessage(partitionLengths, Optional.empty()); + } + + public static MapOutputCommitMessage of( + long[] partitionLengths, MapOutputMetadata mapOutputMetadata) { + return new MapOutputCommitMessage(partitionLengths, Optional.of(mapOutputMetadata)); + } + + public long[] getPartitionLengths() { + return partitionLengths; + } + + public Optional<MapOutputMetadata> getMapOutputMetadata() { + return mapOutputMetadata; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java new file mode 100644 index 0000000000000..6f0e5da4ffeb1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputMetadata.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api.metadata; + +import java.io.Serializable; + +/** + * :: Private :: + * + * An opaque metadata tag for registering the result of committing the output of a + * shuffle map task. + *

+ * All implementations must be serializable since this is sent from the executors to + * the driver. + */ +public interface MapOutputMetadata extends Serializable {} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index dc157eaa3b253..256789b8c7827 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -130,7 +130,7 @@ public void write(Iterator> records) throws IOException { .createMapOutputWriter(shuffleId, mapId, numPartitions); try { if (!records.hasNext()) { - partitionLengths = mapOutputWriter.commitAllPartitions(); + partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths(); mapStatus = MapStatus$.MODULE$.apply( blockManager.shuffleServerId(), partitionLengths, mapId); return; @@ -219,7 +219,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro } partitionWriters = null; } - return mapOutputWriter.commitAllPartitions(); + return mapOutputWriter.commitAllPartitions().getPartitionLengths(); } private void writePartitionedDataWithChannel( diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index d09282e61a9c7..5515a85295d78 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -266,7 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { if (spills.length == 0) { final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); - return mapWriter.commitAllPartitions(); + return mapWriter.commitAllPartitions().getPartitionLengths(); 
} else if (spills.length == 1) { Optional maybeSingleFileWriter = shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); @@ -327,7 +327,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - partitionLengths = mapWriter.commitAllPartitions(); + partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths(); } catch (Exception e) { try { mapWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index a6529fd76188a..eea6c762f5c63 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -35,6 +35,7 @@ import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.internal.config.package$; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage; import org.apache.spark.util.Utils; /** @@ -97,7 +98,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I } @Override - public long[] commitAllPartitions() throws IOException { + public MapOutputCommitMessage commitAllPartitions() throws IOException { // Check the position after transferTo loop to see if it is in the right position and raise a // exception if it is incorrect. The position will not be increased to the expected length // after calling transferTo in kernel version 2.6.32. This issue is described at @@ -113,7 +114,7 @@ public long[] commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return partitionLengths; + return MapOutputCommitMessage.of(partitionLengths); } @Override diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index a391bdf2db44e..83ebe3e12946c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val partitionLengths = mapOutputWriter.commitAllPartitions() + val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index f92455912f510..d2c7d195e06fe 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -136,7 +136,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA } private def verifyWrittenRecords(): Unit = { - val committedLengths = mapOutputWriter.commitAllPartitions() + val committedLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths assert(partitionSizesInMergedFile === partitionLengths) assert(committedLengths === partitionLengths) assert(mergedOutputFile.length() === partitionLengths.sum)