[SPARK-54663][CORE] Computes RowBasedChecksum in ShuffleWriters #50230
New file in package org.apache.spark.shuffle.checksum (@@ -0,0 +1,114 @@):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.checksum

import java.io.ObjectOutputStream
import java.util.zip.Checksum

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.util.MyByteArrayOutputStream

/**
 * A class for computing a checksum for input (key, value) pairs. The checksum is independent of
 * the order of the input (key, value) pairs. It is done by computing a checksum for each row
 * first, and then computing the XOR of all the row checksums.
 */
abstract class RowBasedChecksum() extends Serializable with Logging {
  private var hasError: Boolean = false
  private var checksumValue: Long = 0

  /**
   * Returns the checksum value computed. It returns the default checksum value (0) if there
   * are any errors encountered during the checksum computation.
   */
  def getValue: Long = {
    if (!hasError) checksumValue else 0
  }

  /** Updates the row-based checksum with the given (key, value) pair */
  def update(key: Any, value: Any): Unit = {
    if (!hasError) {
      try {
        val rowChecksumValue = calculateRowChecksum(key, value)
        checksumValue = checksumValue ^ rowChecksumValue
      } catch {
        case NonFatal(e) =>
          logError("Checksum computation encountered error: ", e)
          hasError = true
      }
    }
  }

  /** Computes and returns the checksum value for the given (key, value) pair */
  protected def calculateRowChecksum(key: Any, value: Any): Long
}

/**
 * A concrete implementation of RowBasedChecksum. The checksum for each row is
 * computed by first converting the (key, value) pair to a byte array using OutputStreams,
 * and then computing the checksum for the byte array.
 * Note that this checksum computation is very expensive, and it is used only in tests
 * in the core component. A much cheaper implementation of RowBasedChecksum is in
 * UnsafeRowChecksum.
 *
 * @param checksumAlgorithm the algorithm used for computing checksum.
 */
class OutputStreamRowBasedChecksum(checksumAlgorithm: String)
  extends RowBasedChecksum() {

  private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024

  @transient private lazy val serBuffer =
    new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE)
  @transient private lazy val objOut = new ObjectOutputStream(serBuffer)

  @transient
  protected lazy val checksum: Checksum =
    ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)

  override protected def calculateRowChecksum(key: Any, value: Any): Long = {
    assert(checksum != null, "Checksum is null")

    // Converts the (key, value) pair into a byte array.
    objOut.reset()
    serBuffer.reset()
    objOut.writeObject((key, value))
    objOut.flush()
    serBuffer.flush()

    // Computes and returns the checksum for the byte array.
    checksum.reset()
    checksum.update(serBuffer.getBuf, 0, serBuffer.size())
    checksum.getValue
  }
}

object RowBasedChecksum {
  def createPartitionRowBasedChecksums(
      numPartitions: Int,
      checksumAlgorithm: String): Array[RowBasedChecksum] = {
    Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm))
  }

  def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = {
    Option(rowBasedChecksums)
      .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue))
      .getOrElse(0L)
  }
}
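For intuition, the order-independence of RowBasedChecksum and the fold in getAggregatedChecksumValue can be seen with a small standalone sketch. This is illustrative only and not part of the PR; it assumes the OutputStreamRowBasedChecksum above and that "ADLER32" is one of the algorithm names accepted by ShuffleChecksumHelper.

```scala
import org.apache.spark.shuffle.checksum.{OutputStreamRowBasedChecksum, RowBasedChecksum}

// Assumed to be a valid algorithm name for ShuffleChecksumHelper.
val algorithm = "ADLER32"

// XOR of per-row checksums makes the per-partition value order-independent.
val forward = new OutputStreamRowBasedChecksum(algorithm)
Seq(("a", 1), ("b", 2), ("c", 3)).foreach { case (k, v) => forward.update(k, v) }

val reversed = new OutputStreamRowBasedChecksum(algorithm)
Seq(("c", 3), ("b", 2), ("a", 1)).foreach { case (k, v) => reversed.update(k, v) }

// Same rows in a different order produce the same per-partition checksum.
assert(forward.getValue == reversed.getValue)

// Per-partition values are then folded (acc * 31 + value) into a single long.
val partitionChecksums =
  RowBasedChecksum.createPartitionRowBasedChecksums(numPartitions = 2, algorithm)
partitionChecksums(0).update("a", 1)
partitionChecksums(1).update("b", 2)
val aggregated = RowBasedChecksum.getAggregatedChecksumValue(partitionChecksums)
```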
Changes to BypassMergeSortShuffleWriter<K, V> (Java):

@@ -53,6 +53,7 @@
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;

@@ -104,6 +105,14 @@ final class BypassMergeSortShuffleWriter<K, V>
   private long[] partitionLengths;
   /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
   private final Checksum[] partitionChecksums;
+  /**
+   * Checksum calculator for each partition. Different from the above Checksum,
+   * RowBasedChecksum is independent of the input row order, which is used to
+   * detect whether different task attempts of the same partition produce different
+   * output data or not.
+   */
+  private final RowBasedChecksum[] rowBasedChecksums;
+  private final SparkConf conf;

   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true

@@ -132,6 +141,8 @@ final class BypassMergeSortShuffleWriter<K, V>
     this.serializer = dep.serializer();
     this.shuffleExecutorComponents = shuffleExecutorComponents;
     this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
+    this.rowBasedChecksums = dep.rowBasedChecksums();
+    this.conf = conf;
   }

   @Override

@@ -144,7 +155,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       partitionLengths = mapOutputWriter.commitAllPartitions(
         ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
       mapStatus = MapStatus$.MODULE$.apply(
-        blockManager.shuffleServerId(), partitionLengths, mapId);
+        blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();

@@ -171,7 +182,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     while (records.hasNext()) {
       final Product2<K, V> record = records.next();
       final K key = record._1();
-      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+      final int partitionId = partitioner.getPartition(key);
+      partitionWriters[partitionId].write(key, record._2());
+      if (rowBasedChecksums.length > 0) {
+        rowBasedChecksums[partitionId].update(key, record._2());
+      }
     }

     for (int i = 0; i < numPartitions; i++) {

@@ -182,7 +197,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       partitionLengths = writePartitionedData(mapOutputWriter);
       mapStatus = MapStatus$.MODULE$.apply(
-        blockManager.shuffleServerId(), partitionLengths, mapId);
+        blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
     } catch (Exception e) {
       try {
         mapOutputWriter.abort(e);

@@ -199,6 +214,14 @@ public long[] getPartitionLengths() {
     return partitionLengths;
   }

+  public RowBasedChecksum[] getRowBasedChecksums() {
+    return rowBasedChecksums;
+  }
+
+  public long getAggregatedChecksumValue() {
+    return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+  }
+
   /**
    * Concatenate all of the per-partition files into a single combined file.
   *
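The per-record bookkeeping this writer (and UnsafeShuffleWriter below) performs can be reduced to the following hedged sketch. It is not the writer code itself: `partition`, `writeRecord`, and `writeAll` are illustrative stand-ins for the real writer internals, while the checksum array and aggregation come from the RowBasedChecksum class added above.

```scala
import org.apache.spark.shuffle.checksum.RowBasedChecksum

// Minimal sketch of the writer-side pattern (names are illustrative, not Spark APIs):
// route each record to its partition, write it, and update that partition's
// order-independent checksum; the aggregated value is what gets attached to MapStatus.
def writeAll[K, V](
    records: Iterator[(K, V)],
    partition: K => Int,                          // stand-in for partitioner.getPartition
    writeRecord: (Int, K, V) => Unit,             // stand-in for the partition writer
    rowBasedChecksums: Array[RowBasedChecksum]): Long = {
  records.foreach { case (key, value) =>
    val partitionId = partition(key)
    writeRecord(partitionId, key, value)
    if (rowBasedChecksums.nonEmpty) {             // empty when the feature is disabled
      rowBasedChecksums(partitionId).update(key, value)
    }
  }
  // The aggregated checksum is passed to MapStatus.apply(...) in the diffs above and below.
  RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums)
}
```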
Changes to UnsafeShuffleWriter<K, V> (Java):

@@ -60,9 +60,11 @@
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.MyByteArrayOutputStream;
 import org.apache.spark.util.Utils;

 @Private

@@ -94,15 +96,16 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   @Nullable private long[] partitionLengths;
   private long peakMemoryUsedBytes = 0;

-  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
-  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
-    MyByteArrayOutputStream(int size) { super(size); }
-    public byte[] getBuf() { return buf; }
-  }
-
   private MyByteArrayOutputStream serBuffer;
   private SerializationStream serOutputStream;

+  /**
+   * RowBasedChecksum calculator for each partition. RowBasedChecksum is independent
+   * of the input row order, which is used to detect whether different task attempts
+   * of the same partition produce different output data or not.
+   */
+  private final RowBasedChecksum[] rowBasedChecksums;
+
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
    * and then call stop() with success = false if they get an exception, we want to make sure

@@ -142,6 +145,7 @@ public UnsafeShuffleWriter(
       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
     this.mergeBufferSizeInBytes =
       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024;
+    this.rowBasedChecksums = dep.rowBasedChecksums();
     open();
   }

@@ -163,6 +167,13 @@ public long getPeakMemoryUsedBytes() {
     return peakMemoryUsedBytes;
   }

+  public RowBasedChecksum[] getRowBasedChecksums() {
+    return rowBasedChecksums;
+  }
+
+  public long getAggregatedChecksumValue() {
+    return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+  }
+
   /**
    * This convenience method should only be called in test code.
    */

@@ -234,7 +245,7 @@ void closeAndWriteOutput() throws IOException {
       }
     }
     mapStatus = MapStatus$.MODULE$.apply(
-      blockManager.shuffleServerId(), partitionLengths, mapId);
+      blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
   }

   @VisibleForTesting

@@ -252,6 +263,9 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     sorter.insertRecord(
       serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+    if (rowBasedChecksums.length > 0) {
+      rowBasedChecksums[partitionId].update(key, record._2());
+    }
   }

   @VisibleForTesting

@@ -330,7 +344,8 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
       logger.debug("Using slow merge");
       mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
     }
-    partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
+    partitionLengths =
+      mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
   } catch (Exception e) {
     try {
       mapWriter.abort(e);
New file in package org.apache.spark.util (@@ -0,0 +1,26 @@):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util;

import java.io.ByteArrayOutputStream;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
public final class MyByteArrayOutputStream extends ByteArrayOutputStream {
  public MyByteArrayOutputStream(int size) { super(size); }
  public byte[] getBuf() { return buf; }
}
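A brief note on why the helper exposes `buf`: ByteArrayOutputStream.toByteArray() copies the buffer on every call, whereas getBuf() together with size() lets callers checksum the bytes written so far without a copy, which is how OutputStreamRowBasedChecksum uses it. A minimal usage sketch (illustrative only, using the JDK's CRC32):

```scala
import java.nio.charset.StandardCharsets
import java.util.zip.CRC32

import org.apache.spark.util.MyByteArrayOutputStream

val out = new MyByteArrayOutputStream(32 * 1024)
out.write("example bytes".getBytes(StandardCharsets.UTF_8))

// No defensive copy: checksum the internal buffer up to the current size.
val crc = new CRC32()
crc.update(out.getBuf, 0, out.size())
val value = crc.getValue
```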
Changes to the Dependency classes (Scala):

@@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.checksum.RowBasedChecksum
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils

@@ -59,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
   override def rdd: RDD[T] = _rdd
 }

+object ShuffleDependency {
+  private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty

Review suggestion on the new constant:

-  private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty
+  private val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty

Author reply: Done

A second review suggestion at the use site:

-  Array.empty
+  EMPTY_ROW_BASED_CHECKSUMS

Author reply: done
Changes to ShuffleStatus (Scala):

@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor
 import java.util.concurrent.locks.ReentrantReadWriteLock

 import scala.collection
-import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.collection.mutable.{HashMap, ListBuffer, Map, Set}
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters._

@@ -99,6 +99,11 @@ private class ShuffleStatus(
   val mapStatusesDeleted = new Array[MapStatus](numPartitions)

+  /**
+   * Keep the indices of the map tasks whose checksums are different across retries.
+   */
+  private[this] val checksumMismatchIndices: Set[Int] = Set()
+
   /**
    * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
    * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for

@@ -169,6 +174,12 @@ private class ShuffleStatus(
   } else {
     mapIdToMapIndex.remove(currentMapStatus.mapId)
   }
+
+  val preStatus =
+    if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex)
+  if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
+    checksumMismatchIndices.add(mapIndex)
+  }
+
   mapStatuses(mapIndex) = status
   mapIdToMapIndex(status.mapId) = mapIndex
 }

Reviewer (Contributor): There are three main cases here: [...] For the latter two, we don't need to track it in [...]

Author (Contributor): Yes, for case 1, we need to track the mismatches. The usage of checksumMismatchIndices is that (in the next PR) we will roll back the downstream stages if we detect checksum mismatches for their upstream stages. For case 2, if the downstream stages have not consumed the output, it means they have not started; in this case the rollback is a no-op, and it doesn't hurt to record the mismatches here. For case 3, I think we need to record the mismatches. Assume a situation where all partitions of a stage have finished while some speculative tasks are still running. As all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and the new mapStatus will override the old mapStatus with the new data location. For the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.

Reviewer (Contributor): It is unclear how [...] That is fair, this is indeed possible.
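The detection logic itself is small; a hedged standalone restatement (with simplified, illustrative types rather than the actual ShuffleStatus code) may make the intent clearer: when a map index is re-registered, the previous status, whether still live or already deleted, is compared against the incoming one, and a checksum mismatch marks that map index for the rollback handling planned in the follow-up PR.

```scala
// Simplified restatement (illustrative only, not the ShuffleStatus implementation):
// track map indices whose checksum changed across task attempts.
import scala.collection.mutable

final case class SimpleMapStatus(mapId: Long, checksumValue: Long)  // hypothetical stand-in

class ChecksumMismatchTracker(numPartitions: Int) {
  private val statuses = new Array[SimpleMapStatus](numPartitions)
  private val deleted = new Array[SimpleMapStatus](numPartitions)
  private val mismatchIndices = mutable.Set.empty[Int]

  def register(mapIndex: Int, status: SimpleMapStatus): Unit = {
    // Compare against the live status if present, otherwise the deleted one.
    val previous = if (statuses(mapIndex) != null) statuses(mapIndex) else deleted(mapIndex)
    if (previous != null && previous.checksumValue != status.checksumValue) {
      mismatchIndices.add(mapIndex)   // a later attempt produced different output
    }
    statuses(mapIndex) = status
  }

  def hasMismatch(mapIndex: Int): Boolean = mismatchIndices.contains(mapIndex)
}
```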
Reviewer: Let's update the classdoc. We now also leverage the sum to handle duplicated values better.

Author reply: updated