From 7b40e2edbdd9813bb42664da4a9bb190476adf13 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 20 Jul 2021 16:14:17 +0800 Subject: [PATCH 01/48] init --- .../spark/network/corruption/Cause.java | 22 ++++ .../network/shuffle/BlockStoreClient.java | 21 ++++ .../shuffle/ExternalBlockStoreClient.java | 11 ++ .../protocol/BlockTransferMessage.java | 4 +- .../shuffle/protocol/CorruptionCause.java | 73 +++++++++++++ .../shuffle/protocol/DiagnoseCorruption.java | 100 ++++++++++++++++++ .../checksum/ShuffleChecksumHelper.java | 48 ++++++++- .../spark/network/BlockDataManager.scala | 6 ++ .../network/netty/NettyBlockRpcServer.scala | 5 + .../netty/NettyBlockTransferService.scala | 42 +++++++- .../apache/spark/storage/BlockManager.scala | 53 +++++++++- .../storage/ShuffleBlockFetcherIterator.scala | 100 +++++++++++++++--- .../scala/org/apache/spark/ShuffleSuite.scala | 39 ++++++- 13 files changed, 502 insertions(+), 22 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java new file mode 100644 index 0000000000000..2019a0b842a8c --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.corruption; + +public enum Cause { + DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 829884645d9d5..4c7322bf25e0a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -33,6 +33,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors; import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors; @@ -47,6 +48,26 @@ public abstract class BlockStoreClient implements Closeable { protected volatile TransportClientFactory clientFactory; protected String appId; + /** + * Send the diagnosis request for the corrupted shuffle block to the server. + * + * @param host the host of the remote node. + * @param port the port of the remote node. + * @param execId the executor id. + * @param blockId the blockId of the corrupted shuffle block + * @param checksum the shuffle checksum which calculated at client side for the corrupted + * shuffle block + * @return The cause of the shuffle block corruption + */ + public Cause diagnoseCorruption( + String host, + int port, + String execId, + String blockId, + long checksum) { + return Cause.UNKNOWN_ISSUE; + } + /** * Fetch a sequence of blocks from a remote node asynchronously, * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index eb2d118b7d4fa..2144096ed6909 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -35,6 +35,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; @@ -83,6 +84,16 @@ public void init(String appId) { clientFactory = context.createClientFactory(bootstraps); } + @Override + public Cause diagnoseCorruption( + String host, + int port, + String execId, + String blockId, + long checksum) { + return Cause.UNKNOWN_ISSUE; + } + @Override public void fetchBlocks( String host, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a55a6cf7ed939..453791da7bba2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -49,7 +49,7 @@ public enum Type { HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), - FETCH_SHUFFLE_BLOCK_CHUNKS(15); + FETCH_SHUFFLE_BLOCK_CHUNKS(15), DIAGNOSE_CORRUPTION(16), CORRUPTION_CAUSE(17); private final byte id; @@ -84,6 +84,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); case 15: return FetchShuffleBlockChunks.decode(buf); + case 16: return DiagnoseCorruption.decode(buf); + case 17: return CorruptionCause.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java new file mode 100644 index 0000000000000..4bb7a3aef012a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.corruption.Cause; + +/** Response to the {@link DiagnoseCorruption} */ +public class CorruptionCause extends BlockTransferMessage { + public Cause cause; + + public CorruptionCause(Cause cause) { + this.cause = cause; + } + + @Override + protected Type type() { + return Type.CORRUPTION_CAUSE; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("cause", cause) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CorruptionCause that = (CorruptionCause) o; + return cause == that.cause; + } + + @Override + public int hashCode() { + return cause.hashCode(); + } + + @Override + public int encodedLength() { + return 4; /* encoded length of cause */ + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(cause.ordinal()); + } + + public static CorruptionCause decode(ByteBuf buf) { + int ordinal = buf.readByte(); + return new CorruptionCause(Cause.values()[ordinal]); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java new file mode 100644 index 0000000000000..119497809c596 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** Request to get the cause of a corrupted block. 
Returns {@link CorruptionCause} */ +public class DiagnoseCorruption extends BlockTransferMessage { + private final String appId; + private final String execId; + public final String blockId; + public final long checksum; + + public DiagnoseCorruption(String appId, String execId, String blockId, long checksum) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.checksum = checksum; + } + + @Override + protected Type type() { + return Type.DIAGNOSE_CORRUPTION; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("blockId", blockId) + .append("checksum", checksum) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DiagnoseCorruption that = (DiagnoseCorruption) o; + + if (checksum != that.checksum) return false; + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; + if (!blockId.equals(that.blockId)) return false; + return true; + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + blockId.hashCode(); + result = 31 * result + Long.hashCode(checksum); + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + 8; /* encoded length of checksum */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + buf.writeLong(checksum); + } + + public static DiagnoseCorruption decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + long checksum = buf.readLong(); + return new DiagnoseCorruption(appId, execId, blockId, checksum); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java index a368836d2bb1d..cca6b8ba31d46 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java @@ -17,15 +17,25 @@ package org.apache.spark.shuffle.checksum; +import java.io.*; +import java.nio.channels.Channels; +import java.nio.channels.SeekableByteChannel; +import java.nio.file.Files; import java.util.zip.Adler32; import java.util.zip.CRC32; +import java.util.zip.CheckedInputStream; import java.util.zip.Checksum; import org.apache.spark.SparkConf; import org.apache.spark.SparkException; import org.apache.spark.annotation.Private; import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleBlockResolver; +import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.storage.ShuffleChecksumBlockId; +import org.apache.spark.util.Utils; +import scala.Option; /** * A set of utility functions for the shuffle checksum. @@ -33,9 +43,12 @@ @Private public class ShuffleChecksumHelper { - /** Used when the checksum is disabled for shuffle. */ + /** + * Used when the checksum is disabled for shuffle. 
+ */ private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; + public static final int CHECKSUM_CALCULATION_BUFFER = 8192; public static boolean isShuffleChecksumEnabled(SparkConf conf) { return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED()); @@ -52,19 +65,19 @@ public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, Sp } private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) - throws SparkException { + throws SparkException { Checksum[] checksums; switch (algorithm) { case "ADLER32": checksums = new Adler32[num]; - for (int i = 0; i < num; i ++) { + for (int i = 0; i < num; i++) { checksums[i] = new Adler32(); } return checksums; case "CRC32": checksums = new CRC32[num]; - for (int i = 0; i < num; i ++) { + for (int i = 0; i < num; i++) { checksums[i] = new CRC32(); } return checksums; @@ -77,7 +90,7 @@ private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) public static long[] getChecksumValues(Checksum[] partitionChecksums) { int numPartitions = partitionChecksums.length; long[] checksumValues = new long[numPartitions]; - for (int i = 0; i < numPartitions; i ++) { + for (int i = 0; i < numPartitions; i++) { checksumValues[i] = partitionChecksums[i].getValue(); } return checksumValues; @@ -93,8 +106,33 @@ public static Checksum getChecksumByFileExtension(String fileName) throws SparkE return getChecksumByAlgorithm(1, algorithm)[0]; } + public static Checksum getChecksumByConf(SparkConf conf) throws SparkException { + String algorithm = shuffleChecksumAlgorithm(conf); + return getChecksumByAlgorithm(1, algorithm)[0]; + } + public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) { // append the shuffle checksum algorithm as the file extension return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf)); } + + public static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { + try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { + in.skip(reduceId * 8L); + return in.readLong(); + } + } + + public static long calculateChecksumForPartition( + ShuffleBlockId blockId, + IndexShuffleBlockResolver resolver) throws IOException, SparkException { + InputStream in = resolver.getBlockData(blockId, Option.empty()).createInputStream(); + File checksumFile = resolver.getChecksumFile(blockId.shuffleId(), blockId.reduceId(), Option.empty()); + Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); + byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER]; + try(CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) { + while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {} + return checksumAlgo.getValue(); + } + } } diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index cafb39ea82ad9..d35f6770f497e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,11 +22,17 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID +import org.apache.spark.network.corruption.Cause import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { + 
/** + * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums + */ + def diagnoseShuffleBlockCorruption(blockId: BlockId, checksumByReader: Long): Cause + /** * Get the local directories that used by BlockManager to save the blocks to disk */ diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 5f831dc666ca5..daa60caf9e189 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -133,6 +133,11 @@ class NettyBlockRpcServer( Map(actualExecId -> blockManager.getLocalDiskDirs).asJava).toByteBuffer) } } + + case diagnose: DiagnoseCorruption => + val cause = blockManager + .diagnoseShuffleBlockCorruption(BlockId.apply(diagnose.blockId), diagnose.checksum) + responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 4e0beeaec97ad..600116d50eb37 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -20,9 +20,11 @@ package org.apache.spark.network.netty import java.io.IOException import java.nio.ByteBuffer import java.util.{HashMap => JHashMap, Map => JMap} +import java.util.concurrent.TimeoutException import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} +import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.{Success, Try} @@ -31,15 +33,17 @@ import com.codahale.metrics.{Metric, MetricSet} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.ExecutorDeadException import org.apache.spark.internal.config +import org.apache.spark.internal.config.Network import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap} +import org.apache.spark.network.corruption.Cause import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockTransferListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockTransferor} -import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, CorruptionCause, DiagnoseCorruption, UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive @@ -104,6 +108,40 @@ private[spark] class NettyBlockTransferService( } } + override def diagnoseCorruption( + host: String, + port: Int, + execId: String, + blockId: String, + checksum: Long): Cause = { + // A monitor for the thread to wait on. 
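+ // The callback below completes this promise from the network thread; the calling thread
+ // then blocks on it (bounded by the spark.network.timeout value) so the asynchronous RPC
+ // behaves like a synchronous diagnosis call.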
+ val result = Promise[Cause]()
+ val client = clientFactory.createClient(host, port)
+ client.sendRpc(new DiagnoseCorruption(appId, execId, blockId, checksum).toByteBuffer,
+ new RpcResponseCallback {
+ override def onSuccess(response: ByteBuffer): Unit = {
+ val cause = BlockTransferMessage.Decoder
+ .fromByteBuffer(response).asInstanceOf[CorruptionCause]
+ result.success(cause.cause)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logger.warn("Failed to get the corruption cause.", e)
+ result.success(Cause.UNKNOWN_ISSUE)
+ }
+ })
+ val timeout = new RpcTimeout(
+ conf.get(Network.NETWORK_TIMEOUT).seconds,
+ Network.NETWORK_TIMEOUT.key)
+ try {
+ timeout.awaitResult(result.future)
+ } catch {
+ case _: TimeoutException =>
+ logger.warn("Failed to get the corruption cause due to timeout.")
+ Cause.UNKNOWN_ISSUE
+ }
+ }
+ override def fetchBlocks(
 host: String,
 port: Int,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index b81b3b60520c1..d87c4368bc5de 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -47,6 +47,7 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.client.StreamCallbackWithID
+import org.apache.spark.network.corruption.Cause
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -54,7 +55,8 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.{MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter}
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter}
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper._
 import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, ReplicateBlock}
 import org.apache.spark.storage.memory._
 import org.apache.spark.unsafe.Platform
@@ -282,6 +284,55 @@ private[spark] class BlockManager(
 override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString
+ /**
+ * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums.
+ *
+ * There are three different kinds of checksums for the same shuffle partition:
+ * - checksum (c1) calculated by the shuffle data reader
+ * - checksum (c2) calculated by the shuffle data writer and stored in the checksum file
+ * - checksum (c3) recalculated during diagnosis
+ *
+ * And the diagnosis mechanism works like this:
+ * If c2 != c3, we suspect the corruption is caused by DISK_ISSUE. Otherwise, if c1 != c3,
+ * we suspect the corruption is caused by NETWORK_ISSUE. Otherwise, the cause is
+ * CHECKSUM_VERIFY_PASS. In case of any other failure, the cause is UNKNOWN_ISSUE.
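+ * (A c2/c3 mismatch means the data changed after it was written, pointing to the disk;
+ * c2 == c3 with a c1 mismatch means the data was intact on disk but changed in transit,
+ * pointing to the network.)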
+ *
+ * @param blockId The shuffle block Id
+ * @param checksumByReader The checksum value calculated by the shuffle data reader
+ * @return The cause of data corruption
+ */
+ override def diagnoseShuffleBlockCorruption(blockId: BlockId, checksumByReader: Long): Cause = {
+ assert(blockId.isInstanceOf[ShuffleBlockId],
+ s"Corruption diagnosis currently only supports shuffle blocks, but got $blockId")
+ val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId]
+ val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
+ val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId)
+ val reduceId = shuffleBlock.reduceId
+ if (checksumFile.exists()) {
+ try {
+ val checksumByWriter = readChecksumByReduceId(checksumFile, reduceId)
+ val (checksumByReCalculation, t) =
+ Utils.timeTakenMs(calculateChecksumForPartition(shuffleBlock, resolver))
+ logInfo(s"Checksum recalculation for shuffle block $shuffleBlock took $t ms")
+ if (checksumByWriter != checksumByReCalculation) {
+ Cause.DISK_ISSUE
+ } else if (checksumByWriter != checksumByReader) {
+ Cause.NETWORK_ISSUE
+ } else {
+ Cause.CHECKSUM_VERIFY_PASS
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Exception thrown while diagnosing shuffle block corruption.", e)
+ Cause.UNKNOWN_ISSUE
+ }
+ } else {
+ // Even if checksum is enabled, a checksum file may not exist if an error occurred during writing.
+ logWarning(s"Checksum file ${checksumFile.getName} doesn't exist")
+ Cause.UNKNOWN_ISSUE
+ }
+ }
+ /**
 * Abstraction for storing blocks from bytes, whether they start in memory or on disk.
 *
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index fd87f5e568d0c..d819a06782566 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -21,6 +21,7 @@ import java.io.{InputStream, IOException}
 import java.nio.channels.ClosedByInterruptException
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.zip.CheckedInputStream
 import javax.annotation.concurrent.GuardedBy
 import scala.collection.mutable
@@ -31,13 +32,15 @@ import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.commons.io.IOUtils
 import org.roaringbitmap.RoaringBitmap
-import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.corruption.Cause
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.util.{NettyUtils, TransportConf}
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
 /**
@@ -161,6 +164,8 @@ final class ShuffleBlockFetcherIterator(
 */
 private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+ private[this] val checksumEnabled = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED)
+ /**
 * Whether the iterator is still active.
If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -732,6 +737,8 @@ final class ShuffleBlockFetcherIterator( var result: FetchResult = null var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null var streamCompressedOrEncrypted: Boolean = false // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch @@ -787,7 +794,13 @@ final class ShuffleBlockFetcherIterator( } val in = try { - buf.createInputStream() + var bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByConf(SparkEnv.get.conf) + checkedIn = new CheckedInputStream(bufIn, checksum) + bufIn = checkedIn + } + bufIn } catch { // The exception could only be throwed by local shuffle block case e: IOException => @@ -822,8 +835,8 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: IOException => - buf.release() if (blockId.isShuffleChunk) { + buf.release() // Retrying a corrupt block may result again in a corrupt block. For shuffle // chunks, we opt to fallback on the original shuffle blocks that belong to that // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt @@ -837,13 +850,28 @@ final class ShuffleBlockFetcherIterator( } else { if (buf.isInstanceOf[FileSegmentManagedBuffer] || corruptedBlocks.contains(blockId)) { + buf.release() throwFetchFailedException(blockId, mapIndex, address, e) } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest( - address, Array(FetchBlockInfo(blockId, size, mapIndex))) - result = null + logWarning(s"Got an corrupted block $blockId from $address", e) + // A disk issue indicates the data on disk has already corrupted, so it's + // meaningless to retry on this case. We'll give a retry in the case of + // network issue and other unknown issues (in order to keep the same + // behavior as previously) + val allowRetry = !checksumEnabled || + diagnoseCorruption(checkedIn, address, blockId) != Cause.DISK_ISSUE + buf.release() + if (allowRetry) { + logInfo(s"Will retry the block $blockId") + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } else { + logError(s"Block $blockId is corrupted due to disk issue, won't retry.") + throwFetchFailedException(blockId, mapIndex, address, e, + Some(s"Block $blockId is corrupted due to disk issue")) + } } } } finally { @@ -975,7 +1003,48 @@ final class ShuffleBlockFetcherIterator( currentResult.mapIndex, currentResult.address, detectCorrupt && streamCompressedOrEncrypted, - currentResult.isNetworkReqDone)) + currentResult.isNetworkReqDone, + Option(checkedIn))) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked + * when checksum is enabled. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the + * checksum of the block. Then, it will raise a synchronized RPC call along with the + * checksum to ask the server(where the corrupted block is fetched from) to diagnose the + * cause of corruption and return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. 
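+ * (Consuming the remaining bytes ensures the computed checksum covers the whole block,
+ * so it can be compared against the writer's per-partition checksum.)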
+ * + * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address the address where the corrupted block is fetched from. + * @param blockId the blockId of the corrupted block. + * @return the cause of corruption, which should be one of the [[Cause]]. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): Cause = { + logInfo("Start corruption diagnosis.") + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + try { + while (checkedIn.read(buffer, 0, 8192) != -1) {} + } catch { + case e: IOException => + logWarning("IOException throws while consuming the rest stream of the corrupted block", e) + return Cause.UNKNOWN_ISSUE + } + val checksum = checkedIn.getChecksum.getValue + val cause = shuffleClient.diagnoseCorruption( + address.host, address.port, address.executorId, blockId.toString, checksum) + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + logInfo(s"Finished corruption diagnosis in ${duration} ms, cause: $cause") + cause } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -1158,7 +1227,8 @@ private class BufferReleasingInputStream( private val mapIndex: Int, private val address: BlockManagerId, private val detectCorruption: Boolean, - private val isNetworkReqDone: Boolean) + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) extends InputStream { private[this] var closed = false @@ -1207,8 +1277,14 @@ private class BufferReleasingInputStream( block } catch { case e: IOException if detectCorruption => + val message = checkedInOpt.map { checkedIn => + val cause = iterator.diagnoseCorruption(checkedIn, address, blockId) + s"Block $blockId is corrupted due to $cause" + }.orNull IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, mapIndex, address, e) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. 
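+ // The diagnosed cause is therefore only attached to the exception message to help
+ // explain the fetch failure.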
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e, Some(message)) } } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index fd75d91d8dd2b..91a454e6a218e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileOutputStream} +import java.nio.ByteBuffer import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService } @@ -447,6 +448,42 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi } } } + + test("SPARK-18188: shuffle checksum detect disk corruption") { + conf + .set(config.SHUFFLE_CHECKSUM_ENABLED, true) + .set(TEST_NO_STAGE_RETRY, false) + .set("spark.stage.maxConsecutiveAttempts", "1") + .set(config.SHUFFLE_SERVICE_ENABLED, true) + sc = new SparkContext("local-cluster[2, 1, 2048]", "test", conf) + val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) + // materialize the shuffle map outputs + rdd.count() + + sc.parallelize(1 to 10, 2).barrier().mapPartitions { iter => + var dataFile = SparkEnv.get.blockManager + .diskBlockManager.getFile(ShuffleDataBlockId(0, 0, 0)) + if (!dataFile.exists()) { + dataFile = SparkEnv.get.blockManager + .diskBlockManager.getFile(ShuffleDataBlockId(0, 1, 0)) + } + + if (dataFile.exists()) { + val f = new FileOutputStream(dataFile, true) + val ch = f.getChannel + // corrupt the shuffle data files by writing some arbitrary bytes + ch.write(ByteBuffer.wrap(Array[Byte](12)), 0) + ch.close() + } + BarrierTaskContext.get().barrier() + iter + }.collect() + + val e = intercept[SparkException] { + rdd.count() + } + assert(e.getMessage.contains("corrupted due to DISK_ISSUE")) + } } /** From 3cdd858ce2ec8f4f57882d152b9a752f91227244 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 21 Jul 2021 11:57:40 +0800 Subject: [PATCH 02/48] update --- .../network/shuffle/BlockStoreClient.java | 30 ++++- .../network/shuffle/ExternalBlockHandler.java | 9 ++ .../shuffle/ExternalBlockStoreClient.java | 32 ++--- .../shuffle/ExternalShuffleBlockResolver.java | 28 ++++ .../ShuffleCorruptionDiagnosisHelper.java | 123 ++++++++++++++++++ .../shuffle/protocol/CorruptionCause.java | 2 +- .../shuffle/protocol/DiagnoseCorruption.java | 46 +++++-- .../shuffle/ExternalBlockHandlerSuite.java | 73 ++++++++++- .../checksum/ShuffleChecksumHelper.java | 39 +----- .../network/netty/NettyBlockRpcServer.scala | 5 +- .../netty/NettyBlockTransferService.scala | 48 ++----- .../shuffle/IndexShuffleBlockResolver.scala | 14 +- .../apache/spark/storage/BlockManager.scala | 27 +--- .../storage/ShuffleBlockFetcherIterator.scala | 13 +- .../sort/UnsafeShuffleWriterSuite.java | 5 +- .../scala/org/apache/spark/ShuffleSuite.scala | 2 +- .../shuffle/ShuffleChecksumTestHelper.scala | 5 +- .../BypassMergeSortShuffleWriterSuite.scala | 2 +- 18 files changed, 340 insertions(+), 163 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 4c7322bf25e0a..dc621cfa3c680 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java 
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -34,9 +34,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.corruption.Cause; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors; -import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors; +import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.TransportConf; /** * Provides an interface for reading both shuffle files and RDD blocks, either from an Executor @@ -47,6 +46,7 @@ public abstract class BlockStoreClient implements Closeable { protected volatile TransportClientFactory clientFactory; protected String appId; + protected TransportConf transportConf; /** * Send the diagnosis request for the corrupted shuffle block to the server. @@ -54,7 +54,9 @@ public abstract class BlockStoreClient implements Closeable { * @param host the host of the remote node. * @param port the port of the remote node. * @param execId the executor id. - * @param blockId the blockId of the corrupted shuffle block + * @param shuffleId the shuffleId of the corrupted shuffle block + * @param mapId the mapId of the corrupted shuffle block + * @param reduceId the reduceId of the corrupted shuffle block * @param checksum the shuffle checksum which calculated at client side for the corrupted * shuffle block * @return The cause of the shuffle block corruption @@ -63,9 +65,23 @@ public Cause diagnoseCorruption( String host, int port, String execId, - String blockId, - long checksum) { - return Cause.UNKNOWN_ISSUE; + int shuffleId, + long mapId, + int reduceId, + long checksum) throws IOException, InterruptedException { + TransportClient client = clientFactory.createClient(host, port); + try { + ByteBuffer response = client.sendRpcSync( + new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum).toByteBuffer(), + transportConf.connectionTimeoutMs() + ); + CorruptionCause cause = + (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response); + return cause.cause; + } catch (Exception e) { + logger.warn("Failed to get the corruption cause."); + return Cause.UNKNOWN_ISSUE; + } } /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index cfabcd5ba4a28..f942ab3b8d9b2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -37,6 +37,7 @@ import com.codahale.metrics.Counter; import com.google.common.collect.Sets; import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.network.corruption.Cause; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -223,6 +224,14 @@ protected void handleMessage( } finally { responseDelayContext.stop(); } + } else if (msgObj instanceof DiagnoseCorruption) { + DiagnoseCorruption msg = (DiagnoseCorruption) msgObj; + checkAuth(client, msg.appId); + Cause cause = blockManager.diagnoseShuffleBlockCorruption( + msg.appId, msg.execId, msg.shuffleId, msg.mapId, msg.reduceId, msg.checksum); + // In any cases of the error, diagnoseShuffleBlockCorruption should return UNKNOWN_ISSUE, + // so it 
should always reply as success. + callback.onSuccess(new CorruptionCause(cause).toByteBuffer()); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 2144096ed6909..f53215f744ad4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -50,7 +50,6 @@ public class ExternalBlockStoreClient extends BlockStoreClient { private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler(); - private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final long registrationTimeoutMs; @@ -64,7 +63,7 @@ public ExternalBlockStoreClient( SecretKeyHolder secretKeyHolder, boolean authEnabled, long registrationTimeoutMs) { - this.conf = conf; + this.transportConf = conf; this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; this.registrationTimeoutMs = registrationTimeoutMs; @@ -76,24 +75,15 @@ public ExternalBlockStoreClient( */ public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + TransportContext context = new TransportContext( + transportConf, new NoOpRpcHandler(), true, true); List bootstraps = Lists.newArrayList(); if (authEnabled) { - bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + bootstraps.add(new AuthClientBootstrap(transportConf, appId, secretKeyHolder)); } clientFactory = context.createClientFactory(bootstraps); } - @Override - public Cause diagnoseCorruption( - String host, - int port, - String execId, - String blockId, - long checksum) { - return Cause.UNKNOWN_ISSUE; - } - @Override public void fetchBlocks( String host, @@ -105,7 +95,7 @@ public void fetchBlocks( checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - int maxRetries = conf.maxIORetries(); + int maxRetries = transportConf.maxIORetries(); RetryingBlockTransferor.BlockTransferStarter blockFetchStarter = (inputBlockId, inputListener) -> { // Unless this client is closed. @@ -114,7 +104,7 @@ public void fetchBlocks( "Expecting a BlockFetchingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port, maxRetries > 0); new OneForOneBlockFetcher(client, appId, execId, inputBlockId, - (BlockFetchingListener) inputListener, conf, downloadFileManager).start(); + (BlockFetchingListener) inputListener, transportConf, downloadFileManager).start(); } else { logger.info("This clientFactory was closed. Skipping further block fetch retries."); } @@ -123,7 +113,7 @@ public void fetchBlocks( if (maxRetries > 0) { // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's // a bug in this code. We should remove the if statement once we're sure of the stability. 
- new RetryingBlockTransferor(conf, blockFetchStarter, blockIds, listener).start(); + new RetryingBlockTransferor(transportConf, blockFetchStarter, blockIds, listener).start(); } else { blockFetchStarter.createAndStart(blockIds, listener); } @@ -157,16 +147,16 @@ public void pushBlocks( assert inputListener instanceof BlockPushingListener : "Expecting a BlockPushingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockPusher(client, appId, conf.appAttemptId(), inputBlockId, + new OneForOneBlockPusher(client, appId, transportConf.appAttemptId(), inputBlockId, (BlockPushingListener) inputListener, buffersWithId).start(); } else { logger.info("This clientFactory was closed. Skipping further block push retries."); } }; - int maxRetries = conf.maxIORetries(); + int maxRetries = transportConf.maxIORetries(); if (maxRetries > 0) { new RetryingBlockTransferor( - conf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start(); + transportConf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start(); } else { blockPushStarter.createAndStart(blockIds, listener); } @@ -189,7 +179,7 @@ public void finalizeShuffleMerge( try { TransportClient client = clientFactory.createClient(host, port); ByteBuffer finalizeShuffleMerge = - new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId, + new FinalizeShuffleMerge(appId, transportConf.appAttemptId(), shuffleId, shuffleMergeId).toByteBuffer(); client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() { @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 493edd2b34628..656c337822c01 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -45,6 +45,8 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.corruption.Cause; +import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; @@ -374,6 +376,32 @@ public Map getLocalDirs(String appId, Set execIds) { .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); } + /** + * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums + */ + public Cause diagnoseShuffleBlockCorruption( + String appId, + String execId, + int shuffleId, + long mapId, + int reduceId, + long checksumByReader) { + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum"; + File probeFile = ExecutorDiskUtils.getFile( + executor.localDirs, + executor.subDirsPerLocalDir, + fileName); + File parentFile = probeFile.getParentFile(); + + File[] checksumFiles = parentFile.listFiles(f -> f.getName().startsWith(fileName)); + assert checksumFiles.length == 1; + File checksumFile = checksumFiles[0]; + ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); + return ShuffleCorruptionDiagnosisHelper + .diagnoseCorruption(checksumFile, 
reduceId, data, checksumByReader); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java new file mode 100644 index 0000000000000..70d5b7df8de9c --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.checksum; + +import java.io.*; +import java.util.zip.Adler32; +import java.util.zip.CRC32; +import java.util.zip.CheckedInputStream; +import java.util.zip.Checksum; + +import org.apache.spark.annotation.Private; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.corruption.Cause; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A set of utility functions for the shuffle checksum. 
+ */ +@Private +public class ShuffleCorruptionDiagnosisHelper { + private static final Logger logger = + LoggerFactory.getLogger(ShuffleCorruptionDiagnosisHelper.class); + + public static final int CHECKSUM_CALCULATION_BUFFER = 8192; + + private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) + throws UnsupportedOperationException { + Checksum[] checksums; + switch (algorithm) { + case "ADLER32": + checksums = new Adler32[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new Adler32(); + } + return checksums; + + case "CRC32": + checksums = new CRC32[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new CRC32(); + } + return checksums; + + default: + throw new UnsupportedOperationException("Unsupported shuffle checksum algorithm: " + + algorithm); + } + } + + public static Checksum getChecksumByFileExtension(String fileName) + throws UnsupportedOperationException { + int index = fileName.lastIndexOf("."); + String algorithm = fileName.substring(index + 1); + return getChecksumByAlgorithm(1, algorithm)[0]; + } + + private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { + try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { + in.skip(reduceId * 8L); + return in.readLong(); + } + } + + private static long calculateChecksumForPartition( + ManagedBuffer partitionData, + Checksum checksumAlgo) throws IOException { + InputStream in = partitionData.createInputStream(); + byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER]; + try(CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) { + while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {} + return checksumAlgo.getValue(); + } + } + + public static Cause diagnoseCorruption( + File checksumFile, + int reduceId, + ManagedBuffer partitionData, + long checksumByReader) { + Cause cause; + if (checksumFile.exists()) { + try { + long diagnoseStart = System.currentTimeMillis(); + long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); + Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); + long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); + long duration = System.currentTimeMillis() - diagnoseStart; + logger.info("Shuffle corruption diagnosis took " + duration + " ms"); + if (checksumByWriter != checksumByReCalculation) { + cause = Cause.DISK_ISSUE; + } else if (checksumByWriter != checksumByReader) { + cause = Cause.NETWORK_ISSUE; + } else { + cause = Cause.CHECKSUM_VERIFY_PASS; + } + } catch (Exception e) { + logger.warn("Exception throws while diagnosing shuffle block corruption.", e); + cause = Cause.UNKNOWN_ISSUE; + } + } else { + // Even if checksum is enabled, a checksum file may not exist if error throws during writing. 
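+ // Returning UNKNOWN_ISSUE here keeps the reader's usual retry behavior when no
+ // verdict can be reached.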
+ logger.warn("Checksum file " + checksumFile.getName() + " doesn't exit"); + cause = Cause.UNKNOWN_ISSUE; + } + return cause; + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java index 4bb7a3aef012a..d9b04030946f9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java @@ -58,7 +58,7 @@ public int hashCode() { @Override public int encodedLength() { - return 4; /* encoded length of cause */ + return 1; /* encoded length of cause */ } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java index 119497809c596..8aa524c425312 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -24,15 +24,25 @@ /** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */ public class DiagnoseCorruption extends BlockTransferMessage { - private final String appId; - private final String execId; - public final String blockId; + public final String appId; + public final String execId; + public final int shuffleId; + public final long mapId; + public final int reduceId; public final long checksum; - public DiagnoseCorruption(String appId, String execId, String blockId, long checksum) { + public DiagnoseCorruption( + String appId, + String execId, + int shuffleId, + long mapId, + int reduceId, + long checksum) { this.appId = appId; this.execId = execId; - this.blockId = blockId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; this.checksum = checksum; } @@ -46,7 +56,9 @@ public String toString() { return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) .append("appId", appId) .append("execId", execId) - .append("blockId", blockId) + .append("shuffleId", shuffleId) + .append("mapId", mapId) + .append("reduceId", reduceId) .append("checksum", checksum) .toString(); } @@ -61,7 +73,9 @@ public boolean equals(Object o) { if (checksum != that.checksum) return false; if (!appId.equals(that.appId)) return false; if (!execId.equals(that.execId)) return false; - if (!blockId.equals(that.blockId)) return false; + if (shuffleId != that.shuffleId) return false; + if (mapId != that.mapId) return false; + if (reduceId != that.reduceId) return false; return true; } @@ -69,7 +83,9 @@ public boolean equals(Object o) { public int hashCode() { int result = appId.hashCode(); result = 31 * result + execId.hashCode(); - result = 31 * result + blockId.hashCode(); + result = 31 * result + Integer.hashCode(shuffleId); + result = 31 * result + Long.hashCode(mapId); + result = 31 * result + Integer.hashCode(reduceId); result = 31 * result + Long.hashCode(checksum); return result; } @@ -78,7 +94,9 @@ public int hashCode() { public int encodedLength() { return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) - + Encoders.Strings.encodedLength(blockId) + + 4 /* encoded length of shuffleId */ + + 8 /* encoded length of mapId */ + + 4 /* encoded length of reduceId */ + 8; /* encoded 
length of checksum */ } @@ -86,15 +104,19 @@ public int encodedLength() { public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); - Encoders.Strings.encode(buf, blockId); + buf.writeInt(shuffleId); + buf.writeLong(mapId); + buf.writeInt(reduceId); buf.writeLong(checksum); } public static DiagnoseCorruption decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); - String blockId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + long mapId = buf.readLong(); + int reduceId = buf.readInt(); long checksum = buf.readLong(); - return new DiagnoseCorruption(appId, execId, blockId, checksum); + return new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 9e0b3c65c9202..3098b17659af0 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.network.shuffle; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.util.Iterator; import java.util.Map; @@ -25,6 +25,9 @@ import com.codahale.metrics.Meter; import com.codahale.metrics.Metric; import com.codahale.metrics.Timer; +import com.google.common.io.Files; +import org.apache.spark.network.corruption.Cause; +import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -43,6 +46,8 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.CorruptionCause; +import org.apache.spark.network.shuffle.protocol.DiagnoseCorruption; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; @@ -108,6 +113,72 @@ public void testCompatibilityWithOldVersion() { verifyOpenBlockLatencyMetrics(2, 2); } + private void checkDiagnosisResult( + long checksumByReader, + long checksumByWriter, + Cause expectedCaused) throws IOException { + String appId = "app0"; + String execId = "execId"; + int shuffleId = 0; + long mapId = 0; + int reduceId = 0; + + // prepare the checksum file + File tmpDir = Files.createTempDir(); + tmpDir.deleteOnExit(); + File checksumFile = new File(tmpDir, + "shuffle_" + shuffleId +"_" + mapId + "_" + reduceId + ".checksum.ADLER32"); + DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile)); + if (checksumByWriter != 0) { + out.writeLong(checksumByWriter); + } + out.close(); + + // Checksum for the blockMarkers[0] using adler32 is 196609. 
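+ // (196609 = (3 << 16) | 1, i.e. Adler-32 with A = 1 and B = 3, the value produced by
+ // three zero bytes.)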
+ when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)).thenReturn(blockMarkers[0]); + Cause actualCause = ShuffleCorruptionDiagnosisHelper.diagnoseCorruption(checksumFile, reduceId, + blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); + when(blockResolver + .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader)) + .thenReturn(actualCause); + + when(client.getClientId()).thenReturn(appId); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + DiagnoseCorruption diagnoseMsg = new DiagnoseCorruption( + appId, execId, shuffleId, mapId, reduceId, checksumByReader); + handler.receive(client, diagnoseMsg.toByteBuffer(), callback); + + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure(any()); + + CorruptionCause cause = + (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(expectedCaused, cause.cause); + } + + @Test + public void testShuffleCorruptionDiagnosisDiskIssue() throws IOException { + checkDiagnosisResult(1, 1, Cause.DISK_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisNetworkIssue() throws IOException { + checkDiagnosisResult(1, 196609, Cause.NETWORK_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisUnknownIssue() throws IOException { + // Use checksumByWriter=0 to create the invalid checksum file + checkDiagnosisResult(196609, 0, Cause.UNKNOWN_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisChecksumVerifyPass() throws IOException { + checkDiagnosisResult(196609, 196609, Cause.CHECKSUM_VERIFY_PASS); + } + @Test public void testFetchShuffleBlocks() { when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]); diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java index cca6b8ba31d46..98993d0e2a7e2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java @@ -17,25 +17,15 @@ package org.apache.spark.shuffle.checksum; -import java.io.*; -import java.nio.channels.Channels; -import java.nio.channels.SeekableByteChannel; -import java.nio.file.Files; import java.util.zip.Adler32; import java.util.zip.CRC32; -import java.util.zip.CheckedInputStream; import java.util.zip.Checksum; import org.apache.spark.SparkConf; import org.apache.spark.SparkException; import org.apache.spark.annotation.Private; import org.apache.spark.internal.config.package$; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleBlockResolver; -import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.storage.ShuffleChecksumBlockId; -import org.apache.spark.util.Utils; -import scala.Option; /** * A set of utility functions for the shuffle checksum. 
@@ -48,7 +38,6 @@ public class ShuffleChecksumHelper { */ private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; - public static final int CHECKSUM_CALCULATION_BUFFER = 8192; public static boolean isShuffleChecksumEnabled(SparkConf conf) { return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED()); @@ -100,12 +89,6 @@ public static String shuffleChecksumAlgorithm(SparkConf conf) { return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); } - public static Checksum getChecksumByFileExtension(String fileName) throws SparkException { - int index = fileName.lastIndexOf("."); - String algorithm = fileName.substring(index + 1); - return getChecksumByAlgorithm(1, algorithm)[0]; - } - public static Checksum getChecksumByConf(SparkConf conf) throws SparkException { String algorithm = shuffleChecksumAlgorithm(conf); return getChecksumByAlgorithm(1, algorithm)[0]; @@ -115,24 +98,4 @@ public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkCo // append the shuffle checksum algorithm as the file extension return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf)); } - - public static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { - try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { - in.skip(reduceId * 8L); - return in.readLong(); - } - } - - public static long calculateChecksumForPartition( - ShuffleBlockId blockId, - IndexShuffleBlockResolver resolver) throws IOException, SparkException { - InputStream in = resolver.getBlockData(blockId, Option.empty()).createInputStream(); - File checksumFile = resolver.getChecksumFile(blockId.shuffleId(), blockId.reduceId(), Option.empty()); - Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); - byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER]; - try(CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) { - while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {} - return checksumAlgo.getValue(); - } - } -} +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index daa60caf9e189..c33b65deef457 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -135,8 +135,9 @@ class NettyBlockRpcServer( } case diagnose: DiagnoseCorruption => - val cause = blockManager - .diagnoseShuffleBlockCorruption(BlockId.apply(diagnose.blockId), diagnose.checksum) + val cause = blockManager.diagnoseShuffleBlockCorruption( + ShuffleBlockId(diagnose.shuffleId, diagnose.mapId, diagnose.reduceId ), + diagnose.checksum) responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 600116d50eb37..e151b4c98d49f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -20,11 +20,9 @@ package org.apache.spark.network.netty import java.io.IOException import java.nio.ByteBuffer import java.util.{HashMap => JHashMap, Map => JMap} -import 
java.util.concurrent.TimeoutException
import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
-import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.{Success, Try}
@@ -33,17 +31,15 @@ import com.codahale.metrics.{Metric, MetricSet}
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.ExecutorDeadException
import org.apache.spark.internal.config
-import org.apache.spark.internal.config.Network
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap}
-import org.apache.spark.network.corruption.Cause
import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockTransferListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockTransferor}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, CorruptionCause, DiagnoseCorruption, UploadBlock, UploadBlockStream}
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive
@@ -65,7 +66,6 @@ private[spark] class NettyBlockTransferService(
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
private val serializer = new JavaSerializer(conf)
private val authEnabled = securityManager.isAuthenticationEnabled()
- private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
@@ -74,6 +74,7 @@ private[spark] class NettyBlockTransferService(
val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
+ transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
if (authEnabled) {
serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
@@ -82,6 +83,7 @@ private[spark] class NettyBlockTransferService(
clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
server = createServer(serverBootstrap.toList)
appId = conf.getAppId
+ logger.info(s"Server created on $hostName:${server.getPort}")
}
@@ -108,40 +110,6 @@ private[spark] class NettyBlockTransferService(
}
}
- override def diagnoseCorruption(
- host: String,
- port: Int,
- execId: String,
- blockId: String,
- checksum: Long): Cause = {
- // A monitor for the thread to wait on.
- val result = Promise[Cause]() - val client = clientFactory.createClient(host, port) - client.sendRpc(new DiagnoseCorruption(appId, execId, blockId, checksum).toByteBuffer, - new RpcResponseCallback { - override def onSuccess(response: ByteBuffer): Unit = { - val cause = BlockTransferMessage.Decoder - .fromByteBuffer(response).asInstanceOf[CorruptionCause] - result.success(cause.cause) - } - - override def onFailure(e: Throwable): Unit = { - logger.warn("Failed to get the corruption cause.", e) - result.success(Cause.UNKNOWN_ISSUE) - } - }) - val timeout = new RpcTimeout( - conf.get(Network.NETWORK_TIMEOUT).seconds, - Network.NETWORK_TIMEOUT.key) - try { - timeout.awaitResult(result.future) - } catch { - case _: TimeoutException => - logger.warn("Failed to get the corruption cause due to timeout.") - Cause.UNKNOWN_ISSUE - } - } - override def fetchBlocks( host: String, port: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 07928f8c52252..5a3372cf4c923 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -543,11 +543,15 @@ private[spark] class IndexShuffleBlockResolver( dirs: Option[Array[String]] = None): File = { val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf) - dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName)) - .getOrElse { - blockManager.diskBlockManager.getFile(fileName) - } + // We should use the blockId.name as the file name first to create the file so that + // readers (e.g., shuffle external service) without knowing the checksum algorithm + // could also find the file. 
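+ // For example, with the default ADLER32 algorithm the probe name would be
+ // "shuffle_0_0_0.checksum" (for shuffleId = 0 and mapId = 0), while the file returned
+ // below carries the algorithm extension: "shuffle_0_0_0.checksum.ADLER32".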
+ val file = dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + + // Return the file with the checksum algorithm as extension + new File(file.getParentFile, fileName) } override def getBlockData( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d87c4368bc5de..576eb6786ad9a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -50,13 +50,13 @@ import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.corruption.Cause import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper.diagnoseCorruption import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper._ import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, ReplicateBlock} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform @@ -308,29 +308,8 @@ private[spark] class BlockManager( val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) val reduceId = shuffleBlock.reduceId - if (checksumFile.exists()) { - try { - val checksumByWriter = readChecksumByReduceId(checksumFile, reduceId) - val (checksumByReCalculation, t) = - Utils.timeTakenMs(calculateChecksumForPartition(shuffleBlock, resolver)) - logInfo(s"Checksum recalculation for shuffle block $shuffleBlock took $t ms") - if (checksumByWriter != checksumByReCalculation) { - Cause.DISK_ISSUE - } else if (checksumByWriter != checksumByReader) { - Cause.NETWORK_ISSUE - } else { - Cause.CHECKSUM_VERIFY_PASS - } - } catch { - case NonFatal(e) => - logWarning("Exception throws while diagnosing shuffle block corruption.", e) - Cause.UNKNOWN_ISSUE - } - } else { - // Even if checksum is enabled, a checksum file may not exist if error throws during writing. 
- logWarning(s"Checksum file ${checksumFile.getName} doesn't exit") - Cause.UNKNOWN_ISSUE - } + diagnoseCorruption( + checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d819a06782566..9051067611ea6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -38,9 +38,10 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.corruption.Cause import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper.getChecksumByConf import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} /** @@ -796,7 +797,7 @@ final class ShuffleBlockFetcherIterator( val in = try { var bufIn = buf.createInputStream() if (checksumEnabled) { - val checksum = ShuffleChecksumHelper.getChecksumByConf(SparkEnv.get.conf) + val checksum = getChecksumByConf(SparkEnv.get.conf) checkedIn = new CheckedInputStream(bufIn, checksum) bufIn = checkedIn } @@ -1030,7 +1031,9 @@ final class ShuffleBlockFetcherIterator( blockId: BlockId): Cause = { logInfo("Start corruption diagnosis.") val startTimeNs = System.nanoTime() - val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val buffer = new Array[Byte](ShuffleCorruptionDiagnosisHelper.CHECKSUM_CALCULATION_BUFFER) // consume the remaining data to calculate the checksum try { while (checkedIn.read(buffer, 0, 8192) != -1) {} @@ -1040,8 +1043,8 @@ final class ShuffleBlockFetcherIterator( return Cause.UNKNOWN_ISSUE } val checksum = checkedIn.getChecksum.getValue - val cause = shuffleClient.diagnoseCorruption( - address.host, address.port, address.executorId, blockId.toString, checksum) + val cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum) val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) logInfo(s"Finished corruption diagnosis in ${duration} ms, cause: $cause") cause diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 63220ed49f56c..43b898f940883 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -305,8 +305,7 @@ public void writeChecksumFileWithoutSpill() throws Exception { ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumFile.getName())) - .thenReturn(checksumFile); + 
when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) @@ -334,7 +333,7 @@ public void writeChecksumFileWithSpill() throws Exception { new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 91a454e6a218e..9523b2a71e2a6 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -454,7 +454,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi .set(config.SHUFFLE_CHECKSUM_ENABLED, true) .set(TEST_NO_STAGE_RETRY, false) .set("spark.stage.maxConsecutiveAttempts", "1") - .set(config.SHUFFLE_SERVICE_ENABLED, true) + .set(config.SHUFFLE_SERVICE_ENABLED, false) sc = new SparkContext("local-cluster[2, 1, 2048]", "test", conf) val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) // materialize the shuffle map outputs diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index a8f2c4088c422..38c6b28fbb28f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -20,8 +20,8 @@ package org.apache.spark.shuffle import java.io.{DataInputStream, File, FileInputStream} import java.util.zip.CheckedInputStream +import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper trait ShuffleChecksumTestHelper { @@ -55,7 +55,8 @@ trait ShuffleChecksumTestHelper { val curOffset = indexIn.readLong val limit = (curOffset - prevOffset).toInt val bytes = new Array[Byte](limit) - val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName) + val checksumCal = + ShuffleCorruptionDiagnosisHelper.getChecksumByFileExtension(checksum.getName) checkedIn = new CheckedInputStream( new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) checkedIn.read(bytes, 0, limit) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 39eef9749eac3..0697fbea4ca0c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -253,7 +253,7 @@ class BypassMergeSortShuffleWriterSuite val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) reset(diskBlockManager) - 
when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile) + when(diskBlockManager.getFile(checksumBlockId)).thenAnswer(_ => checksumFile) when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) when(diskBlockManager.createTempShuffleBlock()) From 4c01069d60cffd2cdd6a34405e6fbd648c48b861 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 21 Jul 2021 12:15:01 +0800 Subject: [PATCH 03/48] move comment --- .../ShuffleCorruptionDiagnosisHelper.java | 19 +++++++++++++++++++ .../apache/spark/storage/BlockManager.scala | 17 ----------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index 70d5b7df8de9c..3f042d6eacc97 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -88,6 +88,25 @@ private static long calculateChecksumForPartition( } } + /** + * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums. + * + * There're 3 different kinds of checksums for the same shuffle partition: + * - checksum (c1) that calculated by the shuffle data reader + * - checksum (c2) that calculated by the shuffle data writer and stored in the checksum file + * - checksum (c3) that recalculated during diagnosis + * + * And the diagnosis mechanism works like this: + * If c2 != c3, we suspect the corruption is caused by the DISK_ISSUE. Otherwise, if c1 != c3, + * we suspect the corruption is caused by the NETWORK_ISSUE. Otherwise, the cause remains + * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE. + * + * @param checksumFile The checksum file that written by the shuffle writer + * @param reduceId The reduceId of the shuffle block + * @param partitionData The partition data of the shuffle block + * @param checksumByReader The checksum value that calculated by the shuffle data reader + * @return The cause of data corruption + */ public static Cause diagnoseCorruption( File checksumFile, int reduceId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 576eb6786ad9a..3715790c39e63 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -284,23 +284,6 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString - /** - * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums. - * - * There're 3 different kinds of checksums for the same shuffle partition: - * - checksum (c1) that calculated by the shuffle data reader - * - checksum (c2) that calculated by the shuffle data writer and stored in the checksum file - * - checksum (c3) that recalculated during diagnosis - * - * And the diagnosis mechanism works like this: - * If c2 != c3, we suspect the corruption is caused by the DISK_ISSUE. Otherwise, if c1 != c3, - * we suspect the corruption is caused by the NETWORK_ISSUE. 
Otherwise, the cause remains - * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE. - * - * @param blockId The shuffle block Id - * @param checksumByReader The checksum value that calculated by the shuffle data reader - * @return The cause of data corruption - */ override def diagnoseShuffleBlockCorruption(blockId: BlockId, checksumByReader: Long): Cause = { assert(blockId.isInstanceOf[ShuffleBlockId], s"Corruption diagnosis only supports shuffle block yet, but got $blockId") From 7bd378885568801236003400ceeb050234c9bdaf Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 21 Jul 2021 13:12:01 +0800 Subject: [PATCH 04/48] fix ExternalBlockStoreClient --- .../apache/spark/network/shuffle/ExternalBlockStoreClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index f53215f744ad4..cfae74294a188 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -147,7 +147,7 @@ public void pushBlocks( assert inputListener instanceof BlockPushingListener : "Expecting a BlockPushingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockPusher(client, appId, transportConf.appAttemptId(), inputBlockId, + new OneForOneBlockPusher(client, appId, conf.appAttemptId(), inputBlockId, (BlockPushingListener) inputListener, buffersWithId).start(); } else { logger.info("This clientFactory was closed. Skipping further block push retries."); From 8eb667f2b4b8436d2218e8164dab2ec6b5b25382 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:26:21 +0800 Subject: [PATCH 05/48] mark Cause as private --- .../java/org/apache/spark/network/corruption/Cause.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java index 2019a0b842a8c..0544abfe95fe9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -17,6 +17,12 @@ package org.apache.spark.network.corruption; +import org.apache.spark.annotation.Private; + +/** + * The cause of shuffle data corruption. 
+ */ +@Private public enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS } From 8015e867894981b1a3d13a1d305ddb728d882739 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:27:13 +0800 Subject: [PATCH 06/48] remove unused Cause import --- .../apache/spark/network/shuffle/ExternalBlockStoreClient.java | 1 - 1 file changed, 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index cfae74294a188..62353b590981e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -35,7 +35,6 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; From b25914db6525039816d2add07d14fca1bc84151a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:28:25 +0800 Subject: [PATCH 07/48] fix indents of BlockStoreClient.diagnoseCorruption --- .../spark/network/shuffle/BlockStoreClient.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index dc621cfa3c680..82a5666404c9d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -62,13 +62,13 @@ public abstract class BlockStoreClient implements Closeable { * @return The cause of the shuffle block corruption */ public Cause diagnoseCorruption( - String host, - int port, - String execId, - int shuffleId, - long mapId, - int reduceId, - long checksum) throws IOException, InterruptedException { + String host, + int port, + String execId, + int shuffleId, + long mapId, + int reduceId, + long checksum) throws IOException, InterruptedException { TransportClient client = clientFactory.createClient(host, port); try { ByteBuffer response = client.sendRpcSync( From 03b26fefc557d3348c8edc3dd72ac60eb183e38a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:40:42 +0800 Subject: [PATCH 08/48] verify -> verifying --- .../shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index 3f042d6eacc97..f91e28a7b1079 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -89,7 +89,7 @@ private static long calculateChecksumForPartition( } /** - * Diagnose the possible cause of the shuffle data 
corruption by verify the shuffle checksums. + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums. * * There're 3 different kinds of checksums for the same shuffle partition: * - checksum (c1) that calculated by the shuffle data reader From 10d60f89b8b84f9c577879a8af4d5941e86198bf Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:41:56 +0800 Subject: [PATCH 09/48] fix comment --- .../shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index f91e28a7b1079..44d3409511758 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -92,9 +92,9 @@ private static long calculateChecksumForPartition( * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums. * * There're 3 different kinds of checksums for the same shuffle partition: - * - checksum (c1) that calculated by the shuffle data reader - * - checksum (c2) that calculated by the shuffle data writer and stored in the checksum file - * - checksum (c3) that recalculated during diagnosis + * - checksum (c1) that is calculated by the shuffle data reader + * - checksum (c2) that is calculated by the shuffle data writer and stored in the checksum file + * - checksum (c3) that is recalculated during diagnosis * * And the diagnosis mechanism works like this: * If c2 != c3, we suspect the corruption is caused by the DISK_ISSUE. 
Otherwise, if c1 != c3, From b0b17ab389789f7b92adb2a17b9323590c52c713 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 10:47:07 +0800 Subject: [PATCH 10/48] include checksumfile path --- .../shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index 44d3409511758..ab12cf59ce8e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -120,7 +120,8 @@ public static Cause diagnoseCorruption( Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); long duration = System.currentTimeMillis() - diagnoseStart; - logger.info("Shuffle corruption diagnosis took " + duration + " ms"); + logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}", + duration, checksumFile.getAbsolutePath()); if (checksumByWriter != checksumByReCalculation) { cause = Cause.DISK_ISSUE; } else if (checksumByWriter != checksumByReader) { From c3dcef0558f25cce5f8223716ad1c9132e6481ea Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 11:16:15 +0800 Subject: [PATCH 11/48] remove throws --- .../shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index ab12cf59ce8e6..60180d70e5a6f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -63,8 +63,7 @@ private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) } } - public static Checksum getChecksumByFileExtension(String fileName) - throws UnsupportedOperationException { + public static Checksum getChecksumByFileExtension(String fileName) { int index = fileName.lastIndexOf("."); String algorithm = fileName.substring(index + 1); return getChecksumByAlgorithm(1, algorithm)[0]; From a8f15bd54c03e4047c9f16ff60ee71c5ffad184a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 11:19:03 +0800 Subject: [PATCH 12/48] use ByteStreams.skipFully --- .../shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index 60180d70e5a6f..03c2385a2adb7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -23,6 +23,7 @@ import 
java.util.zip.CheckedInputStream; import java.util.zip.Checksum; +import com.google.common.io.ByteStreams; import org.apache.spark.annotation.Private; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.corruption.Cause; @@ -71,7 +72,7 @@ public static Checksum getChecksumByFileExtension(String fileName) { private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { - in.skip(reduceId * 8L); + ByteStreams.skipFully(in, reduceId * 8); return in.readLong(); } } From 815fdf6707b91175712bd59b6511f508ae656882 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 11:20:55 +0800 Subject: [PATCH 13/48] check cheaper fields first --- .../spark/network/shuffle/protocol/DiagnoseCorruption.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java index 8aa524c425312..f40ec69781e0b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -71,11 +71,11 @@ public boolean equals(Object o) { DiagnoseCorruption that = (DiagnoseCorruption) o; if (checksum != that.checksum) return false; - if (!appId.equals(that.appId)) return false; - if (!execId.equals(that.execId)) return false; if (shuffleId != that.shuffleId) return false; if (mapId != that.mapId) return false; if (reduceId != that.reduceId) return false; + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; return true; } From f1933682816a2872e358d5011a6b9dc70c3786dc Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 11:21:45 +0800 Subject: [PATCH 14/48] use this.transportConf --- .../apache/spark/network/netty/NettyBlockTransferService.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index e151b4c98d49f..f8d7fcef38355 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -74,7 +74,7 @@ private[spark] class NettyBlockTransferService( val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None - transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) + this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) if (authEnabled) { serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) From dfb594da9e9522e89ab4b7c42496ba0feaa9a5b2 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 12:49:07 +0800 Subject: [PATCH 15/48] add todo for pushbased shuffle --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 1 + 1 file changed, 1 insertion(+) diff --git 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 9051067611ea6..b4b446f45f04b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -837,6 +837,7 @@ final class ShuffleBlockFetcherIterator( } catch { case e: IOException => if (blockId.isShuffleChunk) { + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle buf.release() // Retrying a corrupt block may result again in a corrupt block. For shuffle // chunks, we opt to fallback on the original shuffle blocks that belong to that From 30ed389d7ac71347009224ef04475eaad61123ec Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 12:50:22 +0800 Subject: [PATCH 16/48] resolve magic number --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b4b446f45f04b..cdf67bb5fd4d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1037,7 +1037,7 @@ final class ShuffleBlockFetcherIterator( val buffer = new Array[Byte](ShuffleCorruptionDiagnosisHelper.CHECKSUM_CALCULATION_BUFFER) // consume the remaining data to calculate the checksum try { - while (checkedIn.read(buffer, 0, 8192) != -1) {} + while (checkedIn.read(buffer) != -1) {} } catch { case e: IOException => logWarning("IOException throws while consuming the rest stream of the corrupted block", e) From c70fa095c1e964c3584764253692f20656af1455 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 12:55:57 +0800 Subject: [PATCH 17/48] use fileName strip the suffix --- .../apache/spark/shuffle/IndexShuffleBlockResolver.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 5a3372cf4c923..d25fe10aac554 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -543,12 +543,13 @@ private[spark] class IndexShuffleBlockResolver( dirs: Option[Array[String]] = None): File = { val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf) - // We should use the blockId.name as the file name first to create the file so that + val fileNameWithoutChecksum = fileName.substring(0, fileName.lastIndexOf('.')) + // We should use the file name without checksum first to create the file so that // readers (e.g., shuffle external service) without knowing the checksum algorithm - // could also find the file. + // can also find the file. 
val file = dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) - .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileNameWithoutChecksum)) + .getOrElse(blockManager.diskBlockManager.getFile(fileNameWithoutChecksum)) // Return the file with the checksum algorithm as extension new File(file.getParentFile, fileName) From 94576d1cdd67265cb3fec6f86261ffc94f43ed47 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 14:50:52 +0800 Subject: [PATCH 18/48] use Files.newDirectoryStream --- .../shuffle/ExternalShuffleBlockResolver.java | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 656c337822c01..8dd450ced76f0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,7 +18,10 @@ package org.apache.spark.network.shuffle; import java.io.*; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; @@ -388,18 +391,31 @@ public Cause diagnoseShuffleBlockCorruption( long checksumByReader) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum"; - File probeFile = ExecutorDiskUtils.getFile( - executor.localDirs, - executor.subDirsPerLocalDir, - fileName); - File parentFile = probeFile.getParentFile(); - - File[] checksumFiles = parentFile.listFiles(f -> f.getName().startsWith(fileName)); - assert checksumFiles.length == 1; - File checksumFile = checksumFiles[0]; - ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); - return ShuffleCorruptionDiagnosisHelper - .diagnoseCorruption(checksumFile, reduceId, data, checksumByReader); + try { + // This's consistent with `IndexShuffleBlockResolver.getChecksumFile`. + // We firstly use `fileName` to get the location of the checksum file. + // Then, we use `Files.newDirectoryStream` to list all the files under the same directory + // with `probeFile`. Since there's only one single checksum file for a certain map task, + // so it's supposed to return one matched file too. + File probeFile = ExecutorDiskUtils.getFile( + executor.localDirs, + executor.subDirsPerLocalDir, + fileName); + Path parentPath = probeFile.getParentFile().toPath(); + // we don't the exact checksum algorithm, so we have to list all the files here. + DirectoryStream stream = + Files.newDirectoryStream(parentPath, f -> f.getFileName().startsWith(fileName)); + Iterator pathIterator = stream.iterator(); + if (pathIterator.hasNext()) { + ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); + return ShuffleCorruptionDiagnosisHelper + .diagnoseCorruption(pathIterator.next().toFile(), reduceId, data, checksumByReader); + } else { + return Cause.UNKNOWN_ISSUE; + } + } catch (IOException e) { + return Cause.UNKNOWN_ISSUE; + } } /** Simply encodes an executor's full ID, which is appId + execId. 
*/ From e606ca6134cd15c8124b5dce0b376b6e3589b6bd Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 14:55:04 +0800 Subject: [PATCH 19/48] remove checksum file existence check --- .../ShuffleCorruptionDiagnosisHelper.java | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java index 03c2385a2adb7..769386aeaaec7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java @@ -113,30 +113,28 @@ public static Cause diagnoseCorruption( ManagedBuffer partitionData, long checksumByReader) { Cause cause; - if (checksumFile.exists()) { - try { - long diagnoseStart = System.currentTimeMillis(); - long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); - Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); - long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); - long duration = System.currentTimeMillis() - diagnoseStart; - logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}", - duration, checksumFile.getAbsolutePath()); - if (checksumByWriter != checksumByReCalculation) { - cause = Cause.DISK_ISSUE; - } else if (checksumByWriter != checksumByReader) { - cause = Cause.NETWORK_ISSUE; - } else { - cause = Cause.CHECKSUM_VERIFY_PASS; - } - } catch (Exception e) { - logger.warn("Exception throws while diagnosing shuffle block corruption.", e); - cause = Cause.UNKNOWN_ISSUE; + try { + long diagnoseStart = System.currentTimeMillis(); + long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); + Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); + long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); + long duration = System.currentTimeMillis() - diagnoseStart; + logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}", + duration, checksumFile.getAbsolutePath()); + if (checksumByWriter != checksumByReCalculation) { + cause = Cause.DISK_ISSUE; + } else if (checksumByWriter != checksumByReader) { + cause = Cause.NETWORK_ISSUE; + } else { + cause = Cause.CHECKSUM_VERIFY_PASS; } - } else { + } catch (FileNotFoundException e) { // Even if checksum is enabled, a checksum file may not exist if error throws during writing. 
logger.warn("Checksum file " + checksumFile.getName() + " doesn't exit"); cause = Cause.UNKNOWN_ISSUE; + } catch (Exception e) { + logger.warn("Exception throws while diagnosing shuffle block corruption.", e); + cause = Cause.UNKNOWN_ISSUE; } return cause; } From acf20cfb945958c8d2359af7634620802c9b6b57 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 20:13:20 +0800 Subject: [PATCH 20/48] combine ShuffleChecksumHelper&ShuffleCorruptionDiagnosisHelper --- .../spark/network/corruption/Cause.java | 3 - .../shuffle/ExternalShuffleBlockResolver.java | 6 +- ...Helper.java => ShuffleChecksumHelper.java} | 28 +++-- .../shuffle/ExternalBlockHandlerSuite.java | 4 +- .../checksum/ShuffleChecksumHelper.java | 101 ------------------ .../checksum/ShuffleChecksumSupport.java | 28 +++++ .../sort/BypassMergeSortShuffleWriter.java | 13 ++- .../shuffle/sort/ShuffleExternalSorter.java | 10 +- .../shuffle/sort/UnsafeShuffleWriter.java | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 5 +- .../apache/spark/storage/BlockManager.scala | 4 +- .../storage/ShuffleBlockFetcherIterator.scala | 8 +- .../util/collection/ExternalSorter.scala | 9 +- .../sort/UnsafeShuffleWriterSuite.java | 9 +- .../shuffle/ShuffleChecksumTestHelper.scala | 4 +- .../BypassMergeSortShuffleWriterSuite.scala | 7 +- 16 files changed, 91 insertions(+), 150 deletions(-) rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/{ShuffleCorruptionDiagnosisHelper.java => ShuffleChecksumHelper.java} (84%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java index 0544abfe95fe9..d2aa3edfce19a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -17,12 +17,9 @@ package org.apache.spark.network.corruption; -import org.apache.spark.annotation.Private; - /** * The cause of shuffle data corruption. 
*/ -@Private public enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 8dd450ced76f0..19a3543436241 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -49,7 +49,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.corruption.Cause; -import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; @@ -408,8 +408,8 @@ public Cause diagnoseShuffleBlockCorruption( Iterator pathIterator = stream.iterator(); if (pathIterator.hasNext()) { ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); - return ShuffleCorruptionDiagnosisHelper - .diagnoseCorruption(pathIterator.next().toFile(), reduceId, data, checksumByReader); + return ShuffleChecksumHelper.diagnoseCorruption( + pathIterator.next().toFile(), reduceId, data, checksumByReader); } else { return Cause.UNKNOWN_ISSUE; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java similarity index 84% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index 769386aeaaec7..a53d89e07e2cb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleCorruptionDiagnosisHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -34,14 +34,19 @@ * A set of utility functions for the shuffle checksum. 
*/ @Private -public class ShuffleCorruptionDiagnosisHelper { +public class ShuffleChecksumHelper { private static final Logger logger = - LoggerFactory.getLogger(ShuffleCorruptionDiagnosisHelper.class); + LoggerFactory.getLogger(ShuffleChecksumHelper.class); public static final int CHECKSUM_CALCULATION_BUFFER = 8192; + public static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; + public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; - private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) - throws UnsupportedOperationException { + public static Checksum[] createPartitionChecksums(int numPartitions, String algorithm) { + return getChecksumsByAlgorithm(numPartitions, algorithm); + } + + private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) { Checksum[] checksums; switch (algorithm) { case "ADLER32": @@ -59,15 +64,24 @@ private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) return checksums; default: - throw new UnsupportedOperationException("Unsupported shuffle checksum algorithm: " + - algorithm); + throw new UnsupportedOperationException( + "Unsupported shuffle checksum algorithm: " + algorithm); } } + public static Checksum getChecksumByAlgorithm(String algorithm) { + return getChecksumsByAlgorithm(1, algorithm)[0]; + } + + public static String getChecksumFileName(String blockName, String algorithm) { + // append the shuffle checksum algorithm as the file extension + return String.format("%s.%s", blockName, algorithm); + } + public static Checksum getChecksumByFileExtension(String fileName) { int index = fileName.lastIndexOf("."); String algorithm = fileName.substring(index + 1); - return getChecksumByAlgorithm(1, algorithm)[0]; + return getChecksumsByAlgorithm(1, algorithm)[0]; } private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 3098b17659af0..20c558eb1431a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -27,7 +27,7 @@ import com.codahale.metrics.Timer; import com.google.common.io.Files; import org.apache.spark.network.corruption.Cause; -import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -136,7 +136,7 @@ private void checkDiagnosisResult( // Checksum for the blockMarkers[0] using adler32 is 196609. 
when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)).thenReturn(blockMarkers[0]); - Cause actualCause = ShuffleCorruptionDiagnosisHelper.diagnoseCorruption(checksumFile, reduceId, + Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(checksumFile, reduceId, blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); when(blockResolver .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader)) diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java deleted file mode 100644 index 98993d0e2a7e2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.checksum; - -import java.util.zip.Adler32; -import java.util.zip.CRC32; -import java.util.zip.Checksum; - -import org.apache.spark.SparkConf; -import org.apache.spark.SparkException; -import org.apache.spark.annotation.Private; -import org.apache.spark.internal.config.package$; -import org.apache.spark.storage.ShuffleChecksumBlockId; - -/** - * A set of utility functions for the shuffle checksum. - */ -@Private -public class ShuffleChecksumHelper { - - /** - * Used when the checksum is disabled for shuffle. 
- */ - private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; - public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; - - public static boolean isShuffleChecksumEnabled(SparkConf conf) { - return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED()); - } - - public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf) - throws SparkException { - if (!isShuffleChecksumEnabled(conf)) { - return EMPTY_CHECKSUM; - } - - String checksumAlgo = shuffleChecksumAlgorithm(conf); - return getChecksumByAlgorithm(numPartitions, checksumAlgo); - } - - private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) - throws SparkException { - Checksum[] checksums; - switch (algorithm) { - case "ADLER32": - checksums = new Adler32[num]; - for (int i = 0; i < num; i++) { - checksums[i] = new Adler32(); - } - return checksums; - - case "CRC32": - checksums = new CRC32[num]; - for (int i = 0; i < num; i++) { - checksums[i] = new CRC32(); - } - return checksums; - - default: - throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm); - } - } - - public static long[] getChecksumValues(Checksum[] partitionChecksums) { - int numPartitions = partitionChecksums.length; - long[] checksumValues = new long[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - checksumValues[i] = partitionChecksums[i].getValue(); - } - return checksumValues; - } - - public static String shuffleChecksumAlgorithm(SparkConf conf) { - return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); - } - - public static Checksum getChecksumByConf(SparkConf conf) throws SparkException { - String algorithm = shuffleChecksumAlgorithm(conf); - return getChecksumByAlgorithm(1, algorithm)[0]; - } - - public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) { - // append the shuffle checksum algorithm as the file extension - return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf)); - } -} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java new file mode 100644 index 0000000000000..b65dfa847d407 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java @@ -0,0 +1,28 @@ +package org.apache.spark.shuffle.checksum; + +import java.util.zip.Checksum; + +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; + +public interface ShuffleChecksumSupport { + + default Checksum[] createPartitionChecksums(int numPartitions, SparkConf conf) { + if ((boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED())) { + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + return ShuffleChecksumHelper.createPartitionChecksums(numPartitions, checksumAlgorithm); + } else { + return ShuffleChecksumHelper.EMPTY_CHECKSUM; + } + } + + default long[] getChecksumValues(Checksum[] partitionChecksums) { + int numPartitions = partitionChecksums.length; + long[] checksumValues = new long[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + checksumValues[i] = partitionChecksums[i].getValue(); + } + return checksumValues; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 322224053df09..53323b6eb817c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -40,10 +40,12 @@ import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkException; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -51,7 +53,6 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -76,7 +77,7 @@ *

* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter extends ShuffleWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter implements ShuffleChecksumSupport { private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -125,8 +126,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleExecutorComponents = shuffleExecutorComponents; - this.partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); } @Override @@ -230,9 +230,8 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro } partitionWriters = null; } - return mapOutputWriter.commitAllPartitions( - ShuffleChecksumHelper.getChecksumValues(partitionChecksums) - ).getPartitionLengths(); + return mapOutputWriter.commitAllPartitions(getChecksumValues(partitionChecksums)) + .getPartitionLengths(); } private void writePartitionedDataWithChannel( diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 0307027c6f264..ea08f77c3141c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -38,10 +38,11 @@ import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TooLargePageException; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -68,7 +69,7 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. 
*/ -final class ShuffleExternalSorter extends MemoryConsumer { +final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport { private static final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @@ -139,12 +140,11 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.peakMemoryUsedBytes = getMemoryUsage(); this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); - this.partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); } public long[] getChecksums() { - return ShuffleChecksumHelper.getChecksumValues(partitionChecksums); + return getChecksumValues(partitionChecksums); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 2659b172bf68c..b1779a135b786 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -45,6 +45,7 @@ import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -57,7 +58,6 @@ import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d25fe10aac554..b5076d2461eb5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -31,9 +31,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta} +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -542,7 +542,8 @@ private[spark] class IndexShuffleBlockResolver( mapId: Long, dirs: Option[Array[String]] = None): File = { val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) - val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf) + val fileName = ShuffleChecksumHelper.getChecksumFileName( + blockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) val fileNameWithoutChecksum = fileName.substring(0, fileName.lastIndexOf('.')) // We should use the file name without checksum first to create the file so that // 
readers (e.g., shuffle external service) without knowing the checksum algorithm diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3715790c39e63..627304d56a618 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -50,7 +50,7 @@ import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.corruption.Cause import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper.diagnoseCorruption +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv @@ -291,7 +291,7 @@ private[spark] class BlockManager( val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) val reduceId = shuffleBlock.reduceId - diagnoseCorruption( + ShuffleChecksumHelper.diagnoseCorruption( checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index cdf67bb5fd4d4..67e9fd2b2969b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -38,10 +38,9 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.corruption.Cause import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper.getChecksumByConf import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} /** @@ -797,7 +796,8 @@ final class ShuffleBlockFetcherIterator( val in = try { var bufIn = buf.createInputStream() if (checksumEnabled) { - val checksum = getChecksumByConf(SparkEnv.get.conf) + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm( + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) checkedIn = new CheckedInputStream(bufIn, checksum) bufIn = checkedIn } @@ -1034,7 +1034,7 @@ final class ShuffleBlockFetcherIterator( val startTimeNs = System.nanoTime() assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] - val buffer = new Array[Byte](ShuffleCorruptionDiagnosisHelper.CHECKSUM_CALCULATION_BUFFER) + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) // consume the remaining data to calculate the checksum try { while (checkedIn.read(buffer) != -1) {} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 
c63e196ddc814..eda408afa7ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils => TryUtils} @@ -97,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C]( ordering: Option[Ordering[K]] = None, serializer: Serializer = SparkEnv.get.serializer) extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) - with Logging { + with Logging with ShuffleChecksumSupport { private val conf = SparkEnv.get.conf @@ -142,10 +142,9 @@ private[spark] class ExternalSorter[K, V, C]( private val forceSpillFiles = new ArrayBuffer[SpilledFile] @volatile private var readingIterator: SpillableIterator = null - private val partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf) + private val partitionChecksums = createPartitionChecksums(numPartitions, conf) - def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums) + def getChecksums: Array[Long] = getChecksumValues(partitionChecksums) // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 43b898f940883..4da83759a0e79 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,8 +23,8 @@ import java.util.*; import org.apache.spark.*; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.shuffle.ShuffleChecksumTestHelper; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.mockito.stubbing.Answer; import scala.*; import scala.collection.Iterator; @@ -38,6 +38,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZ4CompressionCodec; import org.apache.spark.io.LZFCompressionCodec; @@ -302,7 +303,8 @@ public void writeChecksumFileWithoutSpill() throws Exception { ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); File checksumFile = new File(tempDir, - ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); @@ -330,7 +332,8 @@ public void writeChecksumFileWithSpill() throws Exception { ShuffleChecksumBlockId 
checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); File checksumFile = - new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + new File(tempDir, ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index 38c6b28fbb28f..30c8026ef10c0 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle import java.io.{DataInputStream, File, FileInputStream} import java.util.zip.CheckedInputStream -import org.apache.spark.network.shuffle.checksum.ShuffleCorruptionDiagnosisHelper +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.network.util.LimitedInputStream trait ShuffleChecksumTestHelper { @@ -56,7 +56,7 @@ trait ShuffleChecksumTestHelper { val limit = (curOffset - prevOffset).toInt val bytes = new Array[Byte](limit) val checksumCal = - ShuffleCorruptionDiagnosisHelper.getChecksumByFileExtension(checksum.getName) + ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName) checkedIn = new CheckedInputStream( new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) checkedIn.read(bytes, 0, limit) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 0697fbea4ca0c..c181ac173bc71 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -31,11 +31,12 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.internal.config import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -248,8 +249,8 @@ class BypassMergeSortShuffleWriterSuite val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) - val checksumFile = new File(tempDir, - ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)) + val checksumFile = new File(tempDir, ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))) val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) 
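The writer-side changes above all follow the same pattern: keep one Checksum per reduce partition, update it with the bytes routed to that partition, and collapse the array into plain long values when the map output is committed. A rough, self-contained sketch of that pattern (plain Java, not the Spark writers):

import java.util.zip.Adler32;
import java.util.zip.Checksum;

public class PartitionChecksumSketch {
  public static void main(String[] args) {
    int numPartitions = 3;
    // One checksum per reduce partition, as in createPartitionChecksums above.
    Checksum[] partitionChecksums = new Checksum[numPartitions];
    for (int i = 0; i < numPartitions; i++) {
      partitionChecksums[i] = new Adler32();
    }

    // Pretend these byte arrays are serialized records routed to partitions 0 and 2.
    partitionChecksums[0].update(new byte[]{1, 2, 3}, 0, 3);
    partitionChecksums[2].update(new byte[]{4, 5}, 0, 2);

    // Collapse to plain values at commit time, as in getChecksumValues above.
    long[] checksumValues = new long[numPartitions];
    for (int i = 0; i < numPartitions; i++) {
      checksumValues[i] = partitionChecksums[i].getValue();
    }
    for (long value : checksumValues) {
      System.out.println(value);
    }
  }
}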
reset(diskBlockManager) From 895aad59f0a543439575e5e0291978df978a4ebc Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 20:21:53 +0800 Subject: [PATCH 21/48] fix tests --- .../apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 4 ++-- .../shuffle/sort/BypassMergeSortShuffleWriterSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4da83759a0e79..be3c9a4199793 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -307,7 +307,7 @@ public void writeChecksumFileWithoutSpill() throws Exception { checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumBlockId.name())).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) @@ -336,7 +336,7 @@ public void writeChecksumFileWithSpill() throws Exception { checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumBlockId)).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumBlockId.name())).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index c181ac173bc71..a3b0830349029 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -254,7 +254,7 @@ class BypassMergeSortShuffleWriterSuite val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) reset(diskBlockManager) - when(diskBlockManager.getFile(checksumBlockId)).thenAnswer(_ => checksumFile) + when(diskBlockManager.getFile(checksumBlockId.name)).thenAnswer(_ => checksumFile) when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) when(diskBlockManager.createTempShuffleBlock()) From 91610b0f8380f6085344f73e1e30a19b0d4de61d Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 22:29:35 +0800 Subject: [PATCH 22/48] diagnose corruption when the block corrupted twice --- .../storage/ShuffleBlockFetcherIterator.scala | 49 ++++++++++--------- .../scala/org/apache/spark/ShuffleSuite.scala | 2 +- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 67e9fd2b2969b..893bfe1a9ec6a 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -836,9 +836,15 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + if (blockId.isShuffleChunk) { // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle - buf.release() // Retrying a corrupt block may result again in a corrupt block. For shuffle // chunks, we opt to fallback on the original shuffle blocks that belong to that // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt @@ -849,32 +855,27 @@ final class ShuffleBlockFetcherIterator( pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop. result = null - } else { - if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val cause = diagnoseCorruption(checkedIn, address, blockId) buf.release() - throwFetchFailedException(blockId, mapIndex, address, e) + val errorMsg = s"Block $blockId is corrupted due to $cause." + logError(errorMsg) + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) } else { - logWarning(s"Got an corrupted block $blockId from $address", e) - // A disk issue indicates the data on disk has already corrupted, so it's - // meaningless to retry on this case. 
We'll give a retry in the case of - // network issue and other unknown issues (in order to keep the same - // behavior as previously) - val allowRetry = !checksumEnabled || - diagnoseCorruption(checkedIn, address, blockId) != Cause.DISK_ISSUE - buf.release() - if (allowRetry) { - logInfo(s"Will retry the block $blockId") - corruptedBlocks += blockId - fetchRequests += FetchRequest( - address, Array(FetchBlockInfo(blockId, size, mapIndex))) - result = null - } else { - logError(s"Block $blockId is corrupted due to disk issue, won't retry.") - throwFetchFailedException(blockId, mapIndex, address, e, - Some(s"Block $blockId is corrupted due to disk issue")) - } + throwFetchFailedException(blockId, mapIndex, address, e) } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null } } finally { if (blockId.isShuffleChunk) { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 9523b2a71e2a6..2b558f00df8df 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -449,7 +449,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi } } - test("SPARK-18188: shuffle checksum detect disk corruption") { + test("SPARK-36206: shuffle checksum detect disk corruption") { conf .set(config.SHUFFLE_CHECKSUM_ENABLED, true) .set(TEST_NO_STAGE_RETRY, false) From 6e5d2c093b5c1fbf0e6c0956f495796eab7b6bde Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 23:44:13 +0800 Subject: [PATCH 23/48] send checksum algorithm together --- .../spark/network/corruption/Cause.java | 2 +- .../network/shuffle/BlockStoreClient.java | 6 +- .../network/shuffle/ExternalBlockHandler.java | 2 +- .../shuffle/ExternalShuffleBlockResolver.java | 36 +++------- .../checksum/ShuffleChecksumHelper.java | 7 +- .../shuffle/protocol/DiagnoseCorruption.java | 15 ++++- .../shuffle/ExternalBlockHandlerSuite.java | 66 +++++++++++++++---- .../spark/network/BlockDataManager.scala | 5 +- .../network/netty/NettyBlockRpcServer.scala | 3 +- .../shuffle/IndexShuffleBlockResolver.scala | 22 +++---- .../apache/spark/storage/BlockManager.scala | 15 ++++- .../storage/ShuffleBlockFetcherIterator.scala | 3 +- .../sort/IndexShuffleBlockResolverSuite.scala | 2 +- .../shuffle/sort/SortShuffleWriterSuite.scala | 4 +- 14 files changed, 118 insertions(+), 70 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java index d2aa3edfce19a..0e068438a13a6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -21,5 +21,5 @@ * The cause of shuffle data corruption. 
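The fetcher change in the previous commit amounts to a small per-block policy: a corrupted remote block is retried once, and on the second corruption the checksum diagnosis runs and the fetch fails with the reported cause. A simplified sketch under those assumptions (string block ids, and diagnose() standing in for the RPC to the block's host):

import java.util.HashSet;
import java.util.Set;

public class CorruptionRetryPolicySketch {
  enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS }

  private final Set<String> corruptedBlocks = new HashSet<>();

  // Returns a description of the action taken for a block that failed to decompress.
  String onCorruption(String blockId, boolean checksumEnabled) {
    if (corruptedBlocks.contains(blockId)) {
      // Second corruption of the same block: no further retries.
      if (checksumEnabled) {
        Cause cause = diagnose(blockId);
        return "fail: " + blockId + " is corrupted due to " + cause;
      }
      return "fail: " + blockId + " corrupted twice";
    }
    // First corruption: remember the block and fetch it again.
    corruptedBlocks.add(blockId);
    return "retry: " + blockId;
  }

  private Cause diagnose(String blockId) {
    return Cause.UNKNOWN_ISSUE;  // stand-in for the diagnosis request
  }

  public static void main(String[] args) {
    CorruptionRetryPolicySketch policy = new CorruptionRetryPolicySketch();
    System.out.println(policy.onCorruption("shuffle_0_1_2", true));
    System.out.println(policy.onCorruption("shuffle_0_1_2", true));
  }
}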
*/ public enum Cause { - DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS + DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS, UNSUPPORTED_CHECKSUM_ALGORITHM } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 82a5666404c9d..5e52b884c90b7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -68,11 +68,13 @@ public Cause diagnoseCorruption( int shuffleId, long mapId, int reduceId, - long checksum) throws IOException, InterruptedException { + long checksum, + String algorithm) throws IOException, InterruptedException { TransportClient client = clientFactory.createClient(host, port); try { ByteBuffer response = client.sendRpcSync( - new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum).toByteBuffer(), + new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm) + .toByteBuffer(), transportConf.connectionTimeoutMs() ); CorruptionCause cause = diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index f942ab3b8d9b2..830622ce51909 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -228,7 +228,7 @@ protected void handleMessage( DiagnoseCorruption msg = (DiagnoseCorruption) msgObj; checkAuth(client, msg.appId); Cause cause = blockManager.diagnoseShuffleBlockCorruption( - msg.appId, msg.execId, msg.shuffleId, msg.mapId, msg.reduceId, msg.checksum); + msg.appId, msg.execId, msg.shuffleId, msg.mapId, msg.reduceId, msg.checksum, msg.algorithm); // In any cases of the error, diagnoseShuffleBlockCorruption should return UNKNOWN_ISSUE, // so it should always reply as success. callback.onSuccess(new CorruptionCause(cause).toByteBuffer()); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 19a3543436241..7fc2866375c39 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -388,34 +388,16 @@ public Cause diagnoseShuffleBlockCorruption( int shuffleId, long mapId, int reduceId, - long checksumByReader) { + long checksumByReader, + String algorithm) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); - String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum"; - try { - // This's consistent with `IndexShuffleBlockResolver.getChecksumFile`. - // We firstly use `fileName` to get the location of the checksum file. - // Then, we use `Files.newDirectoryStream` to list all the files under the same directory - // with `probeFile`. Since there's only one single checksum file for a certain map task, - // so it's supposed to return one matched file too. 
- File probeFile = ExecutorDiskUtils.getFile( - executor.localDirs, - executor.subDirsPerLocalDir, - fileName); - Path parentPath = probeFile.getParentFile().toPath(); - // we don't the exact checksum algorithm, so we have to list all the files here. - DirectoryStream stream = - Files.newDirectoryStream(parentPath, f -> f.getFileName().startsWith(fileName)); - Iterator pathIterator = stream.iterator(); - if (pathIterator.hasNext()) { - ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); - return ShuffleChecksumHelper.diagnoseCorruption( - pathIterator.next().toFile(), reduceId, data, checksumByReader); - } else { - return Cause.UNKNOWN_ISSUE; - } - } catch (IOException e) { - return Cause.UNKNOWN_ISSUE; - } + String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm; + File checksumFile = ExecutorDiskUtils.getFile( + executor.localDirs, + executor.subDirsPerLocalDir, + fileName); + ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); + return ShuffleChecksumHelper.diagnoseCorruption(checksumFile, reduceId, data, checksumByReader); } /** Simply encodes an executor's full ID, which is appId + execId. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index a53d89e07e2cb..bffcde7870db2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -129,8 +129,11 @@ public static Cause diagnoseCorruption( Cause cause; try { long diagnoseStart = System.currentTimeMillis(); - long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); + // Try to get the checksum instance before reading the checksum file so that + // `UnsupportedOperationException` can be thrown first before `FileNotFoundException` + // when the checksum algorithm isn't supported. Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); + long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); long duration = System.currentTimeMillis() - diagnoseStart; logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}", @@ -142,6 +145,8 @@ public static Cause diagnoseCorruption( } else { cause = Cause.CHECKSUM_VERIFY_PASS; } + } catch (UnsupportedOperationException e) { + cause = Cause.UNSUPPORTED_CHECKSUM_ALGORITHM; } catch (FileNotFoundException e) { // Even if checksum is enabled, a checksum file may not exist if error throws during writing. 
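The checksumByReader value that diagnoseCorruption compares against is accumulated on the reader side while the fetched block is consumed: the block's input stream is wrapped in a CheckedInputStream, and whatever remains is drained once corruption is detected. A minimal sketch of that pattern (in-memory bytes stand in for a fetched block):

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;

public class ReaderChecksumSketch {
  public static void main(String[] args) throws IOException {
    // In-memory bytes stand in for the fetched shuffle block.
    InputStream block = new ByteArrayInputStream(new byte[]{1, 2, 3, 4, 5});

    // The checksum accumulates transparently as the stream is consumed.
    CheckedInputStream checkedIn = new CheckedInputStream(block, new Adler32());

    byte[] buffer = new byte[8192];           // comparable to CHECKSUM_CALCULATION_BUFFER
    while (checkedIn.read(buffer) != -1) { }  // drain whatever has not been read yet
    System.out.println("reader-side checksum: " + checkedIn.getChecksum().getValue());
  }
}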
logger.warn("Checksum file " + checksumFile.getName() + " doesn't exit"); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java index f40ec69781e0b..620b5ad71cd75 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -30,6 +30,7 @@ public class DiagnoseCorruption extends BlockTransferMessage { public final long mapId; public final int reduceId; public final long checksum; + public final String algorithm; public DiagnoseCorruption( String appId, @@ -37,13 +38,15 @@ public DiagnoseCorruption( int shuffleId, long mapId, int reduceId, - long checksum) { + long checksum, + String algorithm) { this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.reduceId = reduceId; this.checksum = checksum; + this.algorithm = algorithm; } @Override @@ -60,6 +63,7 @@ public String toString() { .append("mapId", mapId) .append("reduceId", reduceId) .append("checksum", checksum) + .append("algorithm", algorithm) .toString(); } @@ -74,6 +78,7 @@ public boolean equals(Object o) { if (shuffleId != that.shuffleId) return false; if (mapId != that.mapId) return false; if (reduceId != that.reduceId) return false; + if (!algorithm.equals(that.algorithm)) return false; if (!appId.equals(that.appId)) return false; if (!execId.equals(that.execId)) return false; return true; @@ -87,6 +92,7 @@ public int hashCode() { result = 31 * result + Long.hashCode(mapId); result = 31 * result + Integer.hashCode(reduceId); result = 31 * result + Long.hashCode(checksum); + result = 31 * result + algorithm.hashCode(); return result; } @@ -97,7 +103,8 @@ public int encodedLength() { + 4 /* encoded length of shuffleId */ + 8 /* encoded length of mapId */ + 4 /* encoded length of reduceId */ - + 8; /* encoded length of checksum */ + + 8 /* encoded length of checksum */ + + Encoders.Strings.encodedLength(algorithm); /* encoded length of algorithm */ } @Override @@ -108,6 +115,7 @@ public void encode(ByteBuf buf) { buf.writeLong(mapId); buf.writeInt(reduceId); buf.writeLong(checksum); + Encoders.Strings.encode(buf, algorithm); } public static DiagnoseCorruption decode(ByteBuf buf) { @@ -117,6 +125,7 @@ public static DiagnoseCorruption decode(ByteBuf buf) { long mapId = buf.readLong(); int reduceId = buf.readInt(); long checksum = buf.readLong(); - return new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum); + String algorithm = Encoders.Strings.decode(buf); + return new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 20c558eb1431a..8d3bafb8f1f0d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -21,10 +21,13 @@ import java.nio.ByteBuffer; import java.util.Iterator; import java.util.Map; +import java.util.zip.CheckedInputStream; +import java.util.zip.Checksum; import com.codahale.metrics.Meter; import 
com.codahale.metrics.Metric; import com.codahale.metrics.Timer; +import com.google.common.io.ByteStreams; import com.google.common.io.Files; import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; @@ -114,8 +117,7 @@ public void testCompatibilityWithOldVersion() { } private void checkDiagnosisResult( - long checksumByReader, - long checksumByWriter, + String algorithm, Cause expectedCaused) throws IOException { String appId = "app0"; String execId = "execId"; @@ -125,28 +127,54 @@ private void checkDiagnosisResult( // prepare the checksum file File tmpDir = Files.createTempDir(); - tmpDir.deleteOnExit(); File checksumFile = new File(tmpDir, - "shuffle_" + shuffleId +"_" + mapId + "_" + reduceId + ".checksum.ADLER32"); + "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum." + algorithm); DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile)); - if (checksumByWriter != 0) { - out.writeLong(checksumByWriter); + long checksumByReader = 0L; + if (expectedCaused != Cause.UNSUPPORTED_CHECKSUM_ALGORITHM) { + Checksum checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm); + CheckedInputStream checkedIn = new CheckedInputStream( + blockMarkers[0].createInputStream(), checksum); + byte[] buffer = new byte[10]; + ByteStreams.readFully(checkedIn, buffer, 0, (int) blockMarkers[0].size()); + long checksumByWriter = checkedIn.getChecksum().getValue(); + + switch (expectedCaused) { + case DISK_ISSUE: + out.writeLong(-checksumByWriter); + checksumByReader = checksumByWriter; + break; + + case NETWORK_ISSUE: + out.writeLong(checksumByWriter); + checksumByReader = -1 * checksumByWriter; + break; + + case UNKNOWN_ISSUE: + // write a int instead of a long to corrupt the checksum file + out.writeInt(0); + checksumByReader = checksumByWriter; + break; + + default: + out.writeLong(checksumByWriter); + checksumByReader = checksumByWriter; + } } out.close(); - // Checksum for the blockMarkers[0] using adler32 is 196609. 
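The test cases above pin down the attribution rule: compare the checksum the writer recorded, a fresh recomputation over the data the server still holds, and the checksum the reader computed. A compact sketch of that rule (not the helper's exact code; 196609 is the Adler32 value the old test comment mentions):

public class CauseAttributionSketch {
  enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS }

  static Cause attribute(long byWriter, long recomputedFromDisk, long byReader) {
    if (byWriter != recomputedFromDisk) {
      return Cause.DISK_ISSUE;      // data changed after the writer recorded its checksum
    } else if (recomputedFromDisk != byReader) {
      return Cause.NETWORK_ISSUE;   // data changed between the server and the reader
    } else {
      return Cause.CHECKSUM_VERIFY_PASS;
    }
  }

  public static void main(String[] args) {
    System.out.println(attribute(196609L, 196608L, 196608L)); // DISK_ISSUE
    System.out.println(attribute(196609L, 196609L, 1L));      // NETWORK_ISSUE
    System.out.println(attribute(196609L, 196609L, 196609L)); // CHECKSUM_VERIFY_PASS
  }
}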
when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)).thenReturn(blockMarkers[0]); Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(checksumFile, reduceId, blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); when(blockResolver - .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader)) + .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm)) .thenReturn(actualCause); when(client.getClientId()).thenReturn(appId); RpcResponseCallback callback = mock(RpcResponseCallback.class); DiagnoseCorruption diagnoseMsg = new DiagnoseCorruption( - appId, execId, shuffleId, mapId, reduceId, checksumByReader); + appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm); handler.receive(client, diagnoseMsg.toByteBuffer(), callback); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); @@ -156,27 +184,37 @@ private void checkDiagnosisResult( CorruptionCause cause = (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(expectedCaused, cause.cause); + tmpDir.delete(); } @Test public void testShuffleCorruptionDiagnosisDiskIssue() throws IOException { - checkDiagnosisResult(1, 1, Cause.DISK_ISSUE); + checkDiagnosisResult( "ADLER32", Cause.DISK_ISSUE); } @Test public void testShuffleCorruptionDiagnosisNetworkIssue() throws IOException { - checkDiagnosisResult(1, 196609, Cause.NETWORK_ISSUE); + checkDiagnosisResult("ADLER32", Cause.NETWORK_ISSUE); } @Test public void testShuffleCorruptionDiagnosisUnknownIssue() throws IOException { - // Use checksumByWriter=0 to create the invalid checksum file - checkDiagnosisResult(196609, 0, Cause.UNKNOWN_ISSUE); + checkDiagnosisResult("ADLER32", Cause.UNKNOWN_ISSUE); } @Test public void testShuffleCorruptionDiagnosisChecksumVerifyPass() throws IOException { - checkDiagnosisResult(196609, 196609, Cause.CHECKSUM_VERIFY_PASS); + checkDiagnosisResult("ADLER32", Cause.CHECKSUM_VERIFY_PASS); + } + + @Test + public void testShuffleCorruptionDiagnosisUnSupportedAlgorithm() throws IOException { + checkDiagnosisResult("XXX", Cause.UNSUPPORTED_CHECKSUM_ALGORITHM); + } + + @Test + public void testShuffleCorruptionDiagnosisCRC32() throws IOException { + checkDiagnosisResult("CRC32", Cause.CHECKSUM_VERIFY_PASS); } @Test diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index d35f6770f497e..594d88cf16c69 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -31,7 +31,10 @@ trait BlockDataManager { /** * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums */ - def diagnoseShuffleBlockCorruption(blockId: BlockId, checksumByReader: Long): Cause + def diagnoseShuffleBlockCorruption( + blockId: BlockId, + checksumByReader: Long, + algorithm: String): Cause /** * Get the local directories that used by BlockManager to save the blocks to disk diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index c33b65deef457..81c878d17c695 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -137,7 +137,8 @@ 
class NettyBlockRpcServer( case diagnose: DiagnoseCorruption => val cause = blockManager.diagnoseShuffleBlockCorruption( ShuffleBlockId(diagnose.shuffleId, diagnose.mapId, diagnose.reduceId ), - diagnose.checksum) + diagnose.checksum, + diagnose.algorithm) responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index b5076d2461eb5..7454a74094541 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -157,7 +157,7 @@ private[spark] class IndexShuffleBlockResolver( logWarning(s"Error deleting index ${file.getPath()}") } - file = getChecksumFile(shuffleId, mapId) + file = getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) if (file.exists() && !file.delete()) { logWarning(s"Error deleting checksum ${file.getPath()}") } @@ -339,7 +339,8 @@ private[spark] class IndexShuffleBlockResolver( val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) { assert(lengths.length == checksums.length, "The size of partition lengths and checksums should be equal") - val checksumFile = getChecksumFile(shuffleId, mapId) + val checksumFile = + getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) (Some(checksumFile), Some(Utils.tempFileWith(checksumFile))) } else { (None, None) @@ -540,20 +541,13 @@ private[spark] class IndexShuffleBlockResolver( def getChecksumFile( shuffleId: Int, mapId: Long, + algorithm: String, dirs: Option[Array[String]] = None): File = { val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) - val fileName = ShuffleChecksumHelper.getChecksumFileName( - blockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) - val fileNameWithoutChecksum = fileName.substring(0, fileName.lastIndexOf('.')) - // We should use the file name without checksum first to create the file so that - // readers (e.g., shuffle external service) without knowing the checksum algorithm - // can also find the file. 
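With the algorithm now part of the request, both getChecksumFile above and the external shuffle service can build the checksum file name directly instead of probing the directory for a matching extension. A small illustration (the local directory below is made up; the 0 in the name is NOOP_REDUCE_ID):

import java.io.File;

public class ChecksumFileLookupSketch {
  // Mirrors the name built above: shuffle_<shuffleId>_<mapId>_0.checksum.<ALGORITHM>
  static File checksumFile(File localDir, int shuffleId, long mapId, String algorithm) {
    String name = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm;
    return new File(localDir, name);
  }

  public static void main(String[] args) {
    File dir = new File("/tmp/blockmgr-example/0a");  // made-up executor local dir
    System.out.println(checksumFile(dir, 0, 5L, "ADLER32"));
  }
}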
- val file = dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileNameWithoutChecksum)) - .getOrElse(blockManager.diskBlockManager.getFile(fileNameWithoutChecksum)) - - // Return the file with the checksum algorithm as extension - new File(file.getParentFile, fileName) + val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId.name, algorithm) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName)) + .getOrElse(blockManager.diskBlockManager.getFile(fileName)) } override def getBlockData( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 627304d56a618..b7ec89f8c7f1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -284,12 +284,23 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString - override def diagnoseShuffleBlockCorruption(blockId: BlockId, checksumByReader: Long): Cause = { + /** + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums + * + * @param blockId The blockId of the corrupted shuffle block + * @param checksumByReader The checksum value of the corrupted block + * @param algorithm The checksum algorithm that is used when calculating the checksum value + */ + override def diagnoseShuffleBlockCorruption( + blockId: BlockId, + checksumByReader: Long, + algorithm: String): Cause = { assert(blockId.isInstanceOf[ShuffleBlockId], s"Corruption diagnosis only supports shuffle block yet, but got $blockId") val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] - val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) + val checksumFile = + resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId, algorithm) val reduceId = shuffleBlock.reduceId ShuffleChecksumHelper.diagnoseCorruption( checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 893bfe1a9ec6a..2de2528b43e0e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1045,8 +1045,9 @@ final class ShuffleBlockFetcherIterator( return Cause.UNKNOWN_ISSUE } val checksum = checkedIn.getChecksum.getValue + val algorithm = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, - shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum) + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, algorithm) val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) logInfo(s"Finished corruption diagnosis in ${duration} ms, cause: $cause") cause diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index abe2b5694bef5..21704b1c67325 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -262,7 +262,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9) val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9) resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp) - val checksumFile = resolver.getChecksumFile(0, 0) + val checksumFile = resolver.getChecksumFile(0, 0, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) assert(checksumFile.exists()) val checksumFileName = checksumFile.toString val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index e3457367d9baf..e8899e9a61118 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.PrivateMethodTester import org.scalatest.matchers.must.Matchers import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper} @@ -165,7 +166,8 @@ class SortShuffleWriterSuite val expectSpillSize = if (doSpill) records.size else 0 assert(sorter.numSpills === expectSpillSize) writer.stop(success = true) - val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0) + val checksumFile = shuffleBlockResolver + .getChecksumFile(shuffleId, 0, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) assert(checksumFile.exists()) assert(checksumFile.length() === 8 * numPartition) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0) From 993ea3d07f3e15c75b0b0879b2bd7fd9e1b33dcb Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 26 Jul 2021 23:50:00 +0800 Subject: [PATCH 24/48] fix tests --- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 16 ++++++++-------- .../sort/BypassMergeSortShuffleWriterSuite.scala | 7 ++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index be3c9a4199793..453702e85c9dc 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -302,12 +302,12 @@ public void writeChecksumFileWithoutSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); - File checksumFile = new File(tempDir, - ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); + String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM())); + File checksumFile = new File(tempDir, checksumFileName); File dataFile = new File(tempDir, "data"); File 
indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumBlockId.name())).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) @@ -331,12 +331,12 @@ public void writeChecksumFileWithSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); - File checksumFile = - new File(tempDir, ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()))); + String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM())); + File checksumFile = new File(tempDir, checksumFileName); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumBlockId.name())).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index a3b0830349029..0dd1998ea7afa 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -249,12 +249,13 @@ class BypassMergeSortShuffleWriterSuite val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) - val checksumFile = new File(tempDir, ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))) + val checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) + val checksumFile = new File(tempDir, checksumFileName) val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) reset(diskBlockManager) - when(diskBlockManager.getFile(checksumBlockId.name)).thenAnswer(_ => checksumFile) + when(diskBlockManager.getFile(checksumFileName)).thenAnswer(_ => checksumFile) when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) when(diskBlockManager.createTempShuffleBlock()) From e2e5fa0220d395760640f7bbfb6a0b2b326bf2c6 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 15:39:11 +0800 Subject: [PATCH 25/48] fix rat of ShuffleChecksumSupport --- .../checksum/ShuffleChecksumSupport.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java index b65dfa847d407..4f7c3f20b4c7e 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.shuffle.checksum; import java.util.zip.Checksum; From c407962e04c8f6fcc8a1972eb2c985042f2e47dc Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 15:46:37 +0800 Subject: [PATCH 26/48] add since for Cause --- .../main/java/org/apache/spark/network/corruption/Cause.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java index 0e068438a13a6..4426de033350e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -19,6 +19,8 @@ /** * The cause of shuffle data corruption. + * + * @since 3.2.0 */ public enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS, UNSUPPORTED_CHECKSUM_ALGORITHM From 1b612a2cf0d87600030c1700cbd2baf5e9c7c96e Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 15:52:46 +0800 Subject: [PATCH 27/48] add comment for fileName --- .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 1 + 1 file changed, 1 insertion(+) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 7fc2866375c39..9a23283994fe2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -391,6 +391,7 @@ public Cause diagnoseShuffleBlockCorruption( long checksumByReader, String algorithm) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + // This should be in sync with IndexShuffleBlockResolver.getChecksumFile String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." 
+ algorithm; File checksumFile = ExecutorDiskUtils.getFile( executor.localDirs, From 3a41a926378d054bd147e54cf854cb8891dae815 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 15:59:06 +0800 Subject: [PATCH 28/48] fix ExternalBlockStoreClient --- .../apache/spark/network/shuffle/ExternalBlockStoreClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 62353b590981e..826402c081cce 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -146,7 +146,7 @@ public void pushBlocks( assert inputListener instanceof BlockPushingListener : "Expecting a BlockPushingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockPusher(client, appId, conf.appAttemptId(), inputBlockId, + new OneForOneBlockPusher(client, appId, transportConf.appAttemptId(), inputBlockId, (BlockPushingListener) inputListener, buffersWithId).start(); } else { logger.info("This clientFactory was closed. Skipping further block push retries."); From cab36a257c34d8152d723e203d982f3a0860e8d1 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 16:15:20 +0800 Subject: [PATCH 29/48] use the explicit checksum algorithm --- .../network/shuffle/ExternalShuffleBlockResolver.java | 3 ++- .../shuffle/checksum/ShuffleChecksumHelper.java | 10 +++------- .../network/shuffle/ExternalBlockHandlerSuite.java | 2 +- .../network/netty/NettyBlockTransferService.scala | 5 ----- .../scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 10 ++++++---- .../spark/shuffle/ShuffleChecksumTestHelper.scala | 10 +++++++--- .../sort/BypassMergeSortShuffleWriterSuite.scala | 5 +++-- .../spark/shuffle/sort/SortShuffleWriterSuite.scala | 6 +++--- 9 files changed, 26 insertions(+), 27 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 9a23283994fe2..cde3d5500feea 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -398,7 +398,8 @@ public Cause diagnoseShuffleBlockCorruption( executor.subDirsPerLocalDir, fileName); ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); - return ShuffleChecksumHelper.diagnoseCorruption(checksumFile, reduceId, data, checksumByReader); + return ShuffleChecksumHelper.diagnoseCorruption( + algorithm, checksumFile, reduceId, data, checksumByReader); } /** Simply encodes an executor's full ID, which is appId + execId. 
*/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index bffcde7870db2..9ac31ab5330da 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -78,12 +78,6 @@ public static String getChecksumFileName(String blockName, String algorithm) { return String.format("%s.%s", blockName, algorithm); } - public static Checksum getChecksumByFileExtension(String fileName) { - int index = fileName.lastIndexOf("."); - String algorithm = fileName.substring(index + 1); - return getChecksumsByAlgorithm(1, algorithm)[0]; - } - private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { ByteStreams.skipFully(in, reduceId * 8); @@ -115,6 +109,7 @@ private static long calculateChecksumForPartition( * we suspect the corruption is caused by the NETWORK_ISSUE. Otherwise, the cause remains * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE. * + * @param algorithm The checksum algorithm that is used for calculating checksum value of partitionData * @param checksumFile The checksum file that written by the shuffle writer * @param reduceId The reduceId of the shuffle block * @param partitionData The partition data of the shuffle block @@ -122,6 +117,7 @@ private static long calculateChecksumForPartition( * @return The cause of data corruption */ public static Cause diagnoseCorruption( + String algorithm, File checksumFile, int reduceId, ManagedBuffer partitionData, @@ -132,7 +128,7 @@ public static Cause diagnoseCorruption( // Try to get the checksum instance before reading the checksum file so that // `UnsupportedOperationException` can be thrown first before `FileNotFoundException` // when the checksum algorithm isn't supported. 
- Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName()); + Checksum checksumAlgo = getChecksumByAlgorithm(algorithm); long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); long duration = System.currentTimeMillis() - diagnoseStart; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 8d3bafb8f1f0d..14073d809b5df 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -164,7 +164,7 @@ private void checkDiagnosisResult( out.close(); when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)).thenReturn(blockMarkers[0]); - Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(checksumFile, reduceId, + Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(algorithm, checksumFile, reduceId, blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); when(blockResolver .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm)) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index f8d7fcef38355..6da0cb439db1a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -36,13 +36,8 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -<<<<<<< HEAD import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockTransferListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockTransferor} -import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, CorruptionCause, DiagnoseCorruption, UploadBlock, UploadBlockStream} -======= -import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} ->>>>>>> update import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.serializer.JavaSerializer diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b7ec89f8c7f1f..65047a7b30ede 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -303,7 +303,7 @@ private[spark] class BlockManager( resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId, algorithm) val reduceId = shuffleBlock.reduceId ShuffleChecksumHelper.diagnoseCorruption( - checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) + algorithm, checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) } /** diff --git 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 453702e85c9dc..f845b4be8a0bd 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -302,8 +302,9 @@ public void writeChecksumFileWithoutSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM())); + checksumBlockId.name(), checksumAlgorithm); File checksumFile = new File(tempDir, checksumFileName); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); @@ -323,7 +324,7 @@ public void writeChecksumFileWithoutSpill() throws Exception { writer1.stop(true); assertTrue(checksumFile.exists()); assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); - compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile); } @Test @@ -331,8 +332,9 @@ public void writeChecksumFileWithSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name(), conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM())); + checksumBlockId.name(), checksumAlgorithm); File checksumFile = new File(tempDir, checksumFileName); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); @@ -358,7 +360,7 @@ public void writeChecksumFileWithSpill() throws Exception { writer1.closeAndWriteOutput(); assertTrue(checksumFile.exists()); assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); - compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile); } private void testMergingSpills( diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index 30c8026ef10c0..3db2f77fe1534 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -28,7 +28,12 @@ trait ShuffleChecksumTestHelper { /** * Ensure that the checksum values are consistent between write and read side. 
*/ - def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = { + def compareChecksums( + numPartition: Int, + algorithm: String, + checksum: File, + data: File, + index: File): Unit = { assert(checksum.exists(), "Checksum file doesn't exist") assert(data.exists(), "Data file doesn't exist") assert(index.exists(), "Index file doesn't exist") @@ -55,8 +60,7 @@ trait ShuffleChecksumTestHelper { val curOffset = indexIn.readLong val limit = (curOffset - prevOffset).toInt val bytes = new Array[Byte](limit) - val checksumCal = - ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName) + val checksumCal = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm) checkedIn = new CheckedInputStream( new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) checkedIn.read(bytes, 0, limit) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 0dd1998ea7afa..38ed702d0e4c7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -249,8 +249,9 @@ class BypassMergeSortShuffleWriterSuite val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val checksumFileName = ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) + checksumBlockId.name, checksumAlgorithm) val checksumFile = new File(tempDir, checksumFileName) val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) @@ -279,6 +280,6 @@ class BypassMergeSortShuffleWriterSuite writer.stop( /* success = */ true) assert(checksumFile.exists()) assert(checksumFile.length() === 8 * numPartition) - compareChecksums(numPartition, checksumFile, dataFile, indexFile) + compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index e8899e9a61118..6c13c7c8c3c61 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -166,13 +166,13 @@ class SortShuffleWriterSuite val expectSpillSize = if (doSpill) records.size else 0 assert(sorter.numSpills === expectSpillSize) writer.stop(success = true) - val checksumFile = shuffleBlockResolver - .getChecksumFile(shuffleId, 0, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0, checksumAlgorithm) assert(checksumFile.exists()) assert(checksumFile.length() === 8 * numPartition) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0) val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0) - compareChecksums(numPartition, checksumFile, dataFile, indexFile) + compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) localSC.stop() } } From 1292390af68f21c27c15b83cd6f99ac8a3cc693d Mon Sep 17 00:00:00 2001 
From: "yi.wu" Date: Tue, 27 Jul 2021 16:20:23 +0800 Subject: [PATCH 30/48] update comment of diagnoseCorruption --- .../spark/network/shuffle/ExternalBlockHandlerSuite.java | 4 ++-- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 14073d809b5df..11c040692f23f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -141,13 +141,13 @@ private void checkDiagnosisResult( switch (expectedCaused) { case DISK_ISSUE: - out.writeLong(-checksumByWriter); + out.writeLong(- checksumByWriter); checksumByReader = checksumByWriter; break; case NETWORK_ISSUE: out.writeLong(checksumByWriter); - checksumByReader = -1 * checksumByWriter; + checksumByReader = - checksumByWriter; break; case UNKNOWN_ISSUE: diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2de2528b43e0e..134a611abb857 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1012,7 +1012,7 @@ final class ShuffleBlockFetcherIterator( /** * Get the suspect corruption cause for the corrupted block. It should be only invoked - * when checksum is enabled. + * when checksum is enabled and corruption was detected at least once. * * This will firstly consume the rest of stream of the corrupted block to calculate the * checksum of the block. Then, it will raise a synchronized RPC call along with the From 8610333ca30b70a8d497a806ab39892f783c225c Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 20:56:16 +0800 Subject: [PATCH 31/48] improve error msg for different causes --- .../storage/ShuffleBlockFetcherIterator.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 134a611abb857..ec322f286a6a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -863,7 +863,21 @@ final class ShuffleBlockFetcherIterator( // Diagnose the cause of data corruption if shuffle checksum is enabled val cause = diagnoseCorruption(checkedIn, address, blockId) buf.release() - val errorMsg = s"Block $blockId is corrupted due to $cause." 
+ val errorMsg = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: " + + s"${SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)}" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } logError(errorMsg) throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) } else { From 47764e336f9b50c3646465cb562372c2ba2457ef Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 21:07:08 +0800 Subject: [PATCH 32/48] inline diagnosisResponse into diagnoseCorruption --- .../storage/ShuffleBlockFetcherIterator.scala | 57 ++++++++++--------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ec322f286a6a8..46fe1cabe4193 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -861,25 +861,11 @@ final class ShuffleBlockFetcherIterator( // It's the second time this block is detected corrupted if (checksumEnabled) { // Diagnose the cause of data corruption if shuffle checksum is enabled - val cause = diagnoseCorruption(checkedIn, address, blockId) + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) buf.release() - val errorMsg = cause match { - case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => - s"Block $blockId is corrupted but corruption diagnosis failed due to " + - s"unsupported checksum algorithm: " + - s"${SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)}" - - case Cause.CHECKSUM_VERIFY_PASS => - s"Block $blockId is corrupted but checksum verification passed" - - case Cause.UNKNOWN_ISSUE => - s"Block $blockId is corrupted but the cause is unknown" - - case otherCause => - s"Block $blockId is corrupted due to $otherCause" - } - logError(errorMsg) - throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + logError(diagnosisResponse) + throwFetchFailedException( + blockId, mapIndex, address, e, Some(diagnosisResponse)) } else { throwFetchFailedException(blockId, mapIndex, address, e) } @@ -1039,32 +1025,48 @@ final class ShuffleBlockFetcherIterator( * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum. * @param address the address where the corrupted block is fetched from. * @param blockId the blockId of the corrupted block. - * @return the cause of corruption, which should be one of the [[Cause]]. + * @return The corruption diagnosis response for different causes. 
*/ private[storage] def diagnoseCorruption( checkedIn: CheckedInputStream, address: BlockManagerId, - blockId: BlockId): Cause = { + blockId: BlockId): String = { logInfo("Start corruption diagnosis.") val startTimeNs = System.nanoTime() assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) // consume the remaining data to calculate the checksum + var cause: Cause = null try { while (checkedIn.read(buffer) != -1) {} } catch { case e: IOException => logWarning("IOException throws while consuming the rest stream of the corrupted block", e) - return Cause.UNKNOWN_ISSUE + cause = Cause.UNKNOWN_ISSUE } val checksum = checkedIn.getChecksum.getValue val algorithm = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, algorithm) val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) - logInfo(s"Finished corruption diagnosis in ${duration} ms, cause: $cause") - cause + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: " + + s"${SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)}" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -1297,14 +1299,13 @@ private class BufferReleasingInputStream( block } catch { case e: IOException if detectCorruption => - val message = checkedInOpt.map { checkedIn => - val cause = iterator.diagnoseCorruption(checkedIn, address, blockId) - s"Block $blockId is corrupted due to $cause" + val diagnosisResponse = checkedInOpt.map { checkedIn => + iterator.diagnoseCorruption(checkedIn, address, blockId) }.orNull IOUtils.closeQuietly(this) // We'd never retry the block whatever the cause is since the block has been // partially consumed by downstream RDDs. 
- iterator.throwFetchFailedException(blockId, mapIndex, address, e, Some(message)) + iterator.throwFetchFailedException(blockId, mapIndex, address, e, Some(diagnosisResponse)) } } } From c041f0118e1940d3a7e214bc1cb923463f5632f7 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 21:08:57 +0800 Subject: [PATCH 33/48] update warning msg --- .../spark/network/shuffle/checksum/ShuffleChecksumHelper.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index 9ac31ab5330da..39878db29922e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -148,7 +148,7 @@ public static Cause diagnoseCorruption( logger.warn("Checksum file " + checksumFile.getName() + " doesn't exit"); cause = Cause.UNKNOWN_ISSUE; } catch (Exception e) { - logger.warn("Exception throws while diagnosing shuffle block corruption.", e); + logger.warn("Unable to diagnose shuffle block corruption", e); cause = Cause.UNKNOWN_ISSUE; } return cause; From 4cd8350cd5bd81fe5645eed0743ed19815ad35a0 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 23:34:02 +0800 Subject: [PATCH 34/48] fix ShuffleBlockFetcherIteratorSuite --- .../shuffle/BlockStoreShuffleReader.scala | 2 ++ .../storage/ShuffleBlockFetcherIterator.scala | 24 ++++++++++--------- .../ShuffleBlockFetcherIteratorSuite.scala | 9 +++++-- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 818aa2ef75a9e..df06b07852905 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -83,6 +83,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, fetchContinuousBlocksInBatch).toCompletionIterator diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 46fe1cabe4193..9ce6a07f76d22 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -32,9 +32,9 @@ import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{MapOutputTracker, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.corruption.Cause import 
org.apache.spark.network.shuffle._ @@ -72,7 +72,11 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before * throwing the fetch failure. - * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param checksumEnabled whether the shuffle checksum is enabled. When enabled, Spark will try to + * diagnose the cause of the block corruption. + * @param checksumAlgorithm the checksum algorithm that is used when calculating the checksum value + * for the block data. * @param shuffleMetrics used to report shuffle metrics. * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server * side supports. @@ -92,6 +96,8 @@ final class ShuffleBlockFetcherIterator( maxAttemptsOnNettyOOM: Int, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, shuffleMetrics: ShuffleReadMetricsReporter, doBatchFetch: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { @@ -164,8 +170,6 @@ final class ShuffleBlockFetcherIterator( */ private[this] val corruptedBlocks = mutable.HashSet[BlockId]() - private[this] val checksumEnabled = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED) - /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -796,8 +800,7 @@ final class ShuffleBlockFetcherIterator( val in = try { var bufIn = buf.createInputStream() if (checksumEnabled) { - val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm( - SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) checkedIn = new CheckedInputStream(bufIn, checksum) bufIn = checkedIn } @@ -1046,15 +1049,14 @@ final class ShuffleBlockFetcherIterator( cause = Cause.UNKNOWN_ISSUE } val checksum = checkedIn.getChecksum.getValue - val algorithm = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, - shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, algorithm) + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, + checksumAlgorithm) val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) val diagnosisResponse = cause match { case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => s"Block $blockId is corrupted but corruption diagnosis failed due to " + - s"unsupported checksum algorithm: " + - s"${SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)}" + s"unsupported checksum algorithm: $checksumAlgorithm" case Cause.CHECKSUM_VERIFY_PASS => s"Block $blockId is corrupted but checksum verification passed" diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c22e1d0ca2244..e0e3c8407937d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,6 +21,7 @@ import java.io._ import java.nio.ByteBuffer import 
java.util.UUID import java.util.concurrent.{CompletableFuture, Semaphore} +import java.util.zip.CheckedInputStream import scala.collection.mutable import scala.concurrent.ExecutionContext.Implicits.global @@ -142,9 +143,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1, checksumEnabled: Boolean = true): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - val in = mock(classOf[InputStream]) + val in = if (checksumEnabled) mock(classOf[CheckedInputStream]) else mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) @@ -180,6 +181,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxAttemptsOnNettyOOM: Int = 10, detectCorrupt: Boolean = true, detectCorruptUseExtraMemory: Boolean = true, + checksumEnabled: Boolean = false, + checksumAlgorithm: String = "ADLER32", shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { val tContext = taskContext.getOrElse(TaskContext.empty()) @@ -197,6 +200,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxAttemptsOnNettyOOM, detectCorrupt, detectCorruptUseExtraMemory, + checksumEnabled, + checksumAlgorithm, shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), doBatchFetch) } From 2da1b910ac5e8699fe46bca13a20976e41d2bfa1 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 27 Jul 2021 23:39:48 +0800 Subject: [PATCH 35/48] fix java lint --- .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 3 --- .../network/shuffle/checksum/ShuffleChecksumHelper.java | 3 ++- .../spark/network/shuffle/ExternalBlockHandlerSuite.java | 6 ++++-- .../spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 4 +++- .../apache/spark/shuffle/sort/ShuffleExternalSorter.java | 1 - .../apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 1 - 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index cde3d5500feea..1ac8d10847d6a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,10 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.*; -import java.nio.file.DirectoryStream; -import java.nio.file.Files; import java.nio.charset.StandardCharsets; -import java.nio.file.Path; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index 39878db29922e..c4eaab7185257 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -109,7 +109,8 @@ 
private static long calculateChecksumForPartition( * we suspect the corruption is caused by the NETWORK_ISSUE. Otherwise, the cause remains * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE. * - * @param algorithm The checksum algorithm that is used for calculating checksum value of partitionData + * @param algorithm The checksum algorithm that is used for calculating checksum value + * of partitionData * @param checksumFile The checksum file that written by the shuffle writer * @param reduceId The reduceId of the shuffle block * @param partitionData The partition data of the shuffle block diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 11c040692f23f..5763d67954eb1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -163,11 +163,13 @@ private void checkDiagnosisResult( } out.close(); - when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)).thenReturn(blockMarkers[0]); + when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)) + .thenReturn(blockMarkers[0]); Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(algorithm, checksumFile, reduceId, blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); when(blockResolver - .diagnoseShuffleBlockCorruption(appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm)) + .diagnoseShuffleBlockCorruption( + appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm)) .thenReturn(actualCause); when(client.getClientId()).thenReturn(appId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 53323b6eb817c..9a5ac6f287beb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -77,7 +77,9 @@ *

 * There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> implements ShuffleChecksumSupport { +final class BypassMergeSortShuffleWriter<K, V> + extends ShuffleWriter<K, V> + implements ShuffleChecksumSupport { private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index ea08f77c3141c..a82f691d085d4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -38,7 +38,6 @@ import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TooLargePageException; -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index f845b4be8a0bd..87f9ab32eb585 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -43,7 +43,6 @@ import org.apache.spark.io.LZ4CompressionCodec; import org.apache.spark.io.LZFCompressionCodec; import org.apache.spark.io.SnappyCompressionCodec; -import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; From 95ef9dbf1c3cd3e7d45dc30fdd340dbc39ff933a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 00:20:57 +0800 Subject: [PATCH 36/48] fix tests --- .../ShuffleBlockFetcherIteratorSuite.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e0e3c8407937d..2e707f8dfc237 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -143,9 +143,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(size: Int = 1, checksumEnabled: Boolean = true): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - val in = if (checksumEnabled) mock(classOf[CheckedInputStream]) else mock(classOf[InputStream]) + val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) @@ -158,14 +158,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(buffer, times(0)).release() val delegateAccess = 
PrivateMethod[InputStream](Symbol("delegate")) - - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + var in = wrappedInputStream.invokePrivate(delegateAccess()) + if (in.isInstanceOf[CheckedInputStream]) { + val underlyingInputField = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in") + underlyingInputField.setAccessible(true) + in = underlyingInputField.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream] + } + verify(in, times(0)).close() wrappedInputStream.close() verify(buffer, times(1)).release() - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + verify(in, times(1)).close() wrappedInputStream.close() // close should be idempotent verify(buffer, times(1)).release() - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + verify(in, times(1)).close() } // scalastyle:off argcount @@ -181,7 +186,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxAttemptsOnNettyOOM: Int = 10, detectCorrupt: Boolean = true, detectCorruptUseExtraMemory: Boolean = true, - checksumEnabled: Boolean = false, + checksumEnabled: Boolean = true, checksumAlgorithm: String = "ADLER32", shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { From 925491c410b91912d6cec3ef74f44887204982fe Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 14:09:02 +0800 Subject: [PATCH 37/48] swallow exception from diagnoseCorruption --- .../spark/network/shuffle/BlockStoreClient.java | 4 ++-- .../spark/storage/ShuffleBlockFetcherIterator.scala | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 5e52b884c90b7..f40b8786fd522 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -69,9 +69,9 @@ public Cause diagnoseCorruption( long mapId, int reduceId, long checksum, - String algorithm) throws IOException, InterruptedException { - TransportClient client = clientFactory.createClient(host, port); + String algorithm) { try { + TransportClient client = clientFactory.createClient(host, port); ByteBuffer response = client.sendRpcSync( new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm) .toByteBuffer(), diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 9ce6a07f76d22..fa71a2d93de7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1043,15 +1043,15 @@ final class ShuffleBlockFetcherIterator( var cause: Cause = null try { while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, + checksumAlgorithm) } catch { - case e: IOException => - logWarning("IOException throws while consuming the rest stream of the corrupted block", e) + case e: Exception => + logWarning("Unable to 
diagnose the corruption cause of the corrupted block", e) cause = Cause.UNKNOWN_ISSUE } - val checksum = checkedIn.getChecksum.getValue - cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, - shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, - checksumAlgorithm) val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) val diagnosisResponse = cause match { case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => From 73b7b7003cbe6a8e2aa70c1cedaaec30a016ad12 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 14:21:22 +0800 Subject: [PATCH 38/48] ensure test stability --- .../spark/network/shuffle/ExternalBlockHandlerSuite.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 5763d67954eb1..55607d4ee396e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -140,14 +140,16 @@ private void checkDiagnosisResult( long checksumByWriter = checkedIn.getChecksum().getValue(); switch (expectedCaused) { + // when checksumByWriter != checksumRecalculated case DISK_ISSUE: - out.writeLong(- checksumByWriter); + out.writeLong(checksumByWriter - 1); checksumByReader = checksumByWriter; break; + // when checksumByWriter == checksumRecalculated and checksumByReader != checksumByWriter case NETWORK_ISSUE: out.writeLong(checksumByWriter); - checksumByReader = - checksumByWriter; + checksumByReader = checksumByWriter - 1; break; case UNKNOWN_ISSUE: From be116fba7c3a57caeed21aecec17688e31d1b9d4 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 14:25:47 +0800 Subject: [PATCH 39/48] update version to 3.2.0 --- .../main/scala/org/apache/spark/internal/config/package.scala | 4 ++-- core/src/main/scala/org/apache/spark/storage/BlockId.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 39c526cb0e8b3..60ba3aac264a5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1372,7 +1372,7 @@ package object config { ConfigBuilder("spark.shuffle.checksum.enabled") .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " + "its best to tell if shuffle data corruption is caused by network or disk or others.") - .version("3.3.0") + .version("3.2.0") .booleanConf .createWithDefault(true) @@ -1380,7 +1380,7 @@ package object config { ConfigBuilder("spark.shuffle.checksum.algorithm") .doc("The algorithm used to calculate the checksum. 
Currently, it only supports" + " built-in algorithms of JDK.") - .version("3.3.0") + .version("3.2.0") .stringConf .transform(_.toUpperCase(Locale.ROOT)) .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index ce53f08bae8ee..e450129be98f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -94,7 +94,7 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } -@Since("3.3.0") +@Since("3.2.0") @DeveloperApi case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum" From 64071ee4bec9719ecc50f54a6908c97b34e436ef Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 14:29:45 +0800 Subject: [PATCH 40/48] fix import style --- .../network/shuffle/checksum/ShuffleChecksumHelper.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index c4eaab7185257..2061457ae531b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -24,11 +24,12 @@ import java.util.zip.Checksum; import com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.annotation.Private; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.corruption.Cause; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * A set of utility functions for the shuffle checksum. 
From a994c0d94c49e5fd639619a99545b3815a870d18 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 28 Jul 2021 15:55:13 +0800 Subject: [PATCH 41/48] add tests --- .../ShuffleBlockFetcherIteratorSuite.scala | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2e707f8dfc237..8ed009882bdd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future +import com.google.common.io.ByteStreams import io.netty.util.internal.OutOfDirectMemoryError import org.apache.log4j.Level import org.mockito.ArgumentMatchers.{any, eq => meq} @@ -223,6 +224,69 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq } + test("SPARK-36206: diagnose the block when it's corrupted twice") { + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer() + ) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess(ShuffleBlockId(0, 0, 0).toString, mockCorruptBuffer()) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100) + ) + intercept[FetchFailedException](iterator.next()) + // The block will be fetched twice due to retry + verify(transfer, times(2)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + // only diagnose once + assert(logAppender.loggingEvents.count( + _.getRenderedMessage.contains("Start corruption diagnosis")) === 1) + } + } + + test("SPARK-36206: diagnose the block when it's corrupted " + + "inside BufferReleasingInputStream") { + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer() + ) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, + mockCorruptBuffer(100, 50)) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100), + maxBytesInFlight = 100 + ) + intercept[FetchFailedException] { + val inputStream = iterator.next()._2 + // Consume the data to trigger the corruption + ByteStreams.readFully(inputStream, new Array[Byte](100)) + } + // The block will be fetched only once because corruption can't be detected in + // maxBytesInFlight/3 of the data size + verify(transfer, times(1)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + // only diagnose once + assert(logAppender.loggingEvents.exists( + _.getRenderedMessage.contains("Start corruption diagnosis"))) + } + } + 
test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId From 261f5ec6348d13bc80513317a3ed29831a92c290 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 29 Jul 2021 10:26:58 +0800 Subject: [PATCH 42/48] refactor bufIn --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index fa71a2d93de7e..a54bf1ec9f103 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -802,9 +802,10 @@ final class ShuffleBlockFetcherIterator( if (checksumEnabled) { val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) checkedIn = new CheckedInputStream(bufIn, checksum) - bufIn = checkedIn + checkedIn + } else { + bufIn } - bufIn } catch { // The exception could only be throwed by local shuffle block case e: IOException => From 8d9db93254544aae206d810137e906c65ec1e558 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 29 Jul 2021 10:36:42 +0800 Subject: [PATCH 43/48] use nano time --- .../network/shuffle/checksum/ShuffleChecksumHelper.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index 2061457ae531b..d12659c2fc022 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.checksum; import java.io.*; +import java.util.concurrent.TimeUnit; import java.util.zip.Adler32; import java.util.zip.CRC32; import java.util.zip.CheckedInputStream; @@ -126,14 +127,14 @@ public static Cause diagnoseCorruption( long checksumByReader) { Cause cause; try { - long diagnoseStart = System.currentTimeMillis(); + long diagnoseStartNs = System.nanoTime(); // Try to get the checksum instance before reading the checksum file so that // `UnsupportedOperationException` can be thrown first before `FileNotFoundException` // when the checksum algorithm isn't supported. 
Checksum checksumAlgo = getChecksumByAlgorithm(algorithm); long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId); long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo); - long duration = System.currentTimeMillis() - diagnoseStart; + long duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - diagnoseStartNs); logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}", duration, checksumFile.getAbsolutePath()); if (checksumByWriter != checksumByReCalculation) { From 6f72ab16a4c4cd257886b19f07e36fd054479532 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 29 Jul 2021 10:38:02 +0800 Subject: [PATCH 44/48] Use Option --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a54bf1ec9f103..10d742691a331 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1304,11 +1304,11 @@ private class BufferReleasingInputStream( case e: IOException if detectCorruption => val diagnosisResponse = checkedInOpt.map { checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) - }.orNull + } IOUtils.closeQuietly(this) // We'd never retry the block whatever the cause is since the block has been // partially consumed by downstream RDDs. - iterator.throwFetchFailedException(blockId, mapIndex, address, e, Some(diagnosisResponse)) + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) } } } From e5e58d547d03ef4f53115b0070eccb0cc4f85c06 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 29 Jul 2021 10:53:51 +0800 Subject: [PATCH 45/48] move Cause --- .../org/apache/spark/network/shuffle/BlockStoreClient.java | 2 +- .../apache/spark/network/shuffle/ExternalBlockHandler.java | 6 +++--- .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 2 +- .../org/apache/spark/network/shuffle/checksum}/Cause.java | 4 +--- .../network/shuffle/checksum/ShuffleChecksumHelper.java | 1 - .../spark/network/shuffle/protocol/CorruptionCause.java | 3 ++- .../spark/network/shuffle/ExternalBlockHandlerSuite.java | 4 ++-- .../scala/org/apache/spark/network/BlockDataManager.scala | 2 +- .../main/scala/org/apache/spark/storage/BlockManager.scala | 3 +-- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 3 +-- 10 files changed, 13 insertions(+), 17 deletions(-) rename common/{network-common/src/main/java/org/apache/spark/network/corruption => network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum}/Cause.java (93%) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index f40b8786fd522..6dc5fd5a70f1a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -33,7 +33,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.corruption.Cause; +import 
org.apache.spark.network.shuffle.checksum.Cause; import org.apache.spark.network.shuffle.protocol.*; import org.apache.spark.network.util.TransportConf; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 830622ce51909..71741f2cba053 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -17,7 +17,6 @@ package org.apache.spark.network.shuffle; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; @@ -35,9 +34,9 @@ import com.codahale.metrics.RatioGauge; import com.codahale.metrics.Timer; import com.codahale.metrics.Counter; -import com.google.common.collect.Sets; import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.network.corruption.Cause; +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +49,7 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.shuffle.checksum.Cause; import org.apache.spark.network.shuffle.protocol.*; import org.apache.spark.network.util.TimerWithCustomTimeUnit; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 1ac8d10847d6a..d8be076ee397b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -45,7 +45,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.corruption.Cause; +import org.apache.spark.network.shuffle.checksum.Cause; import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java similarity index 93% rename from common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java index 4426de033350e..d316737a16148 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java @@ -15,12 +15,10 @@ * limitations under the License. */ -package org.apache.spark.network.corruption; +package org.apache.spark.network.shuffle.checksum; /** * The cause of shuffle data corruption. 
- * - * @since 3.2.0 */ public enum Cause { DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS, UNSUPPORTED_CHECKSUM_ALGORITHM diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index d12659c2fc022..f332f740b3f5f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -30,7 +30,6 @@ import org.apache.spark.annotation.Private; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.corruption.Cause; /** * A set of utility functions for the shuffle checksum. diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java index d9b04030946f9..5690eee53bd13 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java @@ -20,7 +20,8 @@ import io.netty.buffer.ByteBuf; import org.apache.commons.lang3.builder.ToStringBuilder; import org.apache.commons.lang3.builder.ToStringStyle; -import org.apache.spark.network.corruption.Cause; + +import org.apache.spark.network.shuffle.checksum.Cause; /** Response to the {@link DiagnoseCorruption} */ public class CorruptionCause extends BlockTransferMessage { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 55607d4ee396e..d45cbd5adcd98 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -29,8 +29,6 @@ import com.codahale.metrics.Timer; import com.google.common.io.ByteStreams; import com.google.common.io.Files; -import org.apache.spark.network.corruption.Cause; -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -48,6 +46,8 @@ import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.checksum.Cause; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.CorruptionCause; import org.apache.spark.network.shuffle.protocol.DiagnoseCorruption; diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 594d88cf16c69..0b2ee15750009 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import 
org.apache.spark.network.client.StreamCallbackWithID -import org.apache.spark.network.corruption.Cause +import org.apache.spark.network.shuffle.checksum.Cause import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 65047a7b30ede..619c5b76eb934 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -47,10 +47,9 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.client.StreamCallbackWithID -import org.apache.spark.network.corruption.Cause import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 10d742691a331..3eb8acd4f5560 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -36,9 +36,8 @@ import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.corruption.Cause import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} From 7e4c91db9d203d0b69e219ca9eee5fc9a91f0300 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 29 Jul 2021 11:28:11 +0800 Subject: [PATCH 46/48] fix tests --- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 2b558f00df8df..61cede99ddb4d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -450,12 +450,11 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi } test("SPARK-36206: shuffle checksum detect disk corruption") { - conf + val newConf = conf.clone .set(config.SHUFFLE_CHECKSUM_ENABLED, true) .set(TEST_NO_STAGE_RETRY, false) .set("spark.stage.maxConsecutiveAttempts", "1") - .set(config.SHUFFLE_SERVICE_ENABLED, false) - sc = new SparkContext("local-cluster[2, 1, 2048]", "test", conf) + sc = new SparkContext("local-cluster[2, 1, 2048]", "test", newConf) val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) // materialize the 
shuffle map outputs rdd.count() From 67262c9db3600164af82106c045ab60e73db8d78 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Fri, 30 Jul 2021 11:06:49 +0800 Subject: [PATCH 47/48] fix --- .../test/scala/org/apache/spark/ShuffleSuite.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 61cede99ddb4d..c1a964c336109 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark -import java.io.{File, FileOutputStream} -import java.nio.ByteBuffer +import java.io.{File, RandomAccessFile} import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService } @@ -468,11 +467,11 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi } if (dataFile.exists()) { - val f = new FileOutputStream(dataFile, true) - val ch = f.getChannel + val f = new RandomAccessFile(dataFile, "rw") // corrupt the shuffle data files by writing some arbitrary bytes - ch.write(ByteBuffer.wrap(Array[Byte](12)), 0) - ch.close() + f.seek(0) + f.write(Array[Byte](12)) + f.close() } BarrierTaskContext.get().barrier() iter From ca1b058b6ccac9178859c56e2f7dd05f4ff68900 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 2 Aug 2021 14:06:52 +0800 Subject: [PATCH 48/48] address comment --- .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 2 +- .../main/scala/org/apache/spark/network/BlockDataManager.scala | 2 +- core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index d8be076ee397b..73d4e6ceb1951 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -377,7 +377,7 @@ public Map getLocalDirs(String appId, Set execIds) { } /** - * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums */ public Cause diagnoseShuffleBlockCorruption( String appId, diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 0b2ee15750009..89177346a789a 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -29,7 +29,7 @@ private[spark] trait BlockDataManager { /** - * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums */ def diagnoseShuffleBlockCorruption( blockId: BlockId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 619c5b76eb934..4c646b27c270f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -284,7 +284,7 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString /** - * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums * * @param blockId The blockId of the corrupted shuffle block * @param checksumByReader The checksum value of the corrupted block
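Closing note on PATCH 47 above: the ShuffleSuite test corrupts a shuffle data file in place (overwriting a byte) rather than appending to it, so the file's length is unchanged but its checksum no longer matches. A minimal sketch of that technique, assuming a hypothetical helper name (corruptFirstByte is not from the patch):

import java.io.RandomAccessFile

// Overwrite a single byte at offset 0 so the file length stays the same
// while its contents, and hence its checksum, change.
def corruptFirstByte(path: String): Unit = {
  val f = new RandomAccessFile(path, "rw")
  try {
    f.seek(0)   // position at the start of the file
    f.write(12) // overwrite one byte with an arbitrary value
  } finally {
    f.close()
  }
}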