diff --git a/assembly/pom.xml b/assembly/pom.xml index 74c2f44121fca..08e453077e87b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -268,5 +268,17 @@ + + + + external-shuffle-storage + + + org.apache.spark + external-shuffle-storage_${scala.binary.version} + ${project.version} + + +
diff --git a/external-shuffle-storage/README.md b/external-shuffle-storage/README.md new file mode 100644 index 0000000000000..92b578b66e3f4 --- /dev/null +++ b/external-shuffle-storage/README.md @@ -0,0 +1,55 @@
+# External Shuffle Storage
+
+This module provides support for storing shuffle files on external shuffle storage such as S3. It helps
+Dynamic Allocation on Kubernetes: the Spark driver can release idle executors without worrying about
+losing shuffle data, because the shuffle data is stored on external shuffle storage rather than on the
+executors.
+
+This module implements a new shuffle manager named StarShuffleManager and copies a lot of code
+from Spark's SortShuffleManager. This is a quick prototype; we want to use it as an example to discuss
+with the Spark community and get feedback. We will work with the community to remove the code
+duplication later and make StarShuffleManager more integrated with Spark code.
+
+## How to Build a Spark Distribution with the StarShuffleManager jar File
+
+Follow the [Building Spark](https://spark.apache.org/docs/latest/building-spark.html) instructions,
+adding the extra `-Pexternal-shuffle-storage` profile to generate the new shuffle implementation jar file.
+
+The following is an example command using `dev/make-distribution.sh` under the Spark repo root directory:
+
+```
+./dev/make-distribution.sh --name spark-with-external-shuffle-storage --pip --tgz -Phive -Phive-thriftserver -Pkubernetes -Phadoop-3.2 -Phadoop-cloud -Dhadoop.version=3.2.0 -Pexternal-shuffle-storage
+```
+
+If you want to build a Spark docker image, you can unzip the Spark distribution tgz file and run a command like the following:
+
+```
+./bin/docker-image-tool.sh -t spark-with-external-shuffle-storage build
+```
+
+The distribution build places an `external-shuffle-storage_xxx.jar` file for StarShuffleManager
+under the `jars` directory of the generated Spark distribution. You can now use this Spark
+distribution to run your Spark application with external shuffle storage.
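+
+To verify that the module was included, you can check that the jar landed in the distribution's `jars`
+directory, for example (the exact output location depends on how you ran the build; `dist/` is the
+staging directory used by `make-distribution.sh`):
+
+```
+ls dist/jars/ | grep external-shuffle-storage
+```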
+
+## How to Run a Spark Application With External Shuffle Storage in Kubernetes
+
+### Run a Spark Application With S3 as External Shuffle Storage and Dynamic Allocation
+
+Add configuration like the following to your Spark application (you need to adjust the values based on your environment):
+
+```
+spark.shuffle.manager=org.apache.spark.shuffle.StarShuffleManager
+spark.shuffle.star.rootDir=s3://my_bucket_name/my_shuffle_folder
+spark.dynamicAllocation.enabled=true
+spark.dynamicAllocation.shuffleTracking.enabled=true
+spark.dynamicAllocation.shuffleTracking.timeout=1
+```
+
+### How to Specify the AWS Region for the S3 Files
+
+Add a Spark config like the following:
+
+```
+spark.hadoop.fs.s3a.endpoint.region=us-west-2
+```
+
diff --git a/external-shuffle-storage/pom.xml b/external-shuffle-storage/pom.xml new file mode 100644 index 0000000000000..1699240b78fe6 --- /dev/null +++ b/external-shuffle-storage/pom.xml @@ -0,0 +1,143 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.12 + 3.3.0-SNAPSHOT + ../pom.xml + + + external-shuffle-storage_2.12 + jar + External Shuffle Storage + http://spark.apache.org/ + + + external-shuffle-storage + none + package + provided + provided + provided + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.commons + commons-math3 + provided + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.scala-lang + scala-library + provided + + + io.netty + netty-all + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + org.apache.httpcomponents + httpclient + + + commons-io + commons-io + + + org.apache.commons + commons-lang3 + + + com.amazonaws + aws-java-sdk-s3 + 1.11.975 + provided + + + org.testng + testng + 6.14.3 + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + + + org.apache.maven.plugins + maven-jar-plugin + + ${jars.target.dir} + + + + +
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java new file mode 100644 index 0000000000000..7007ce8992560 --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java @@ -0,0 +1,46 @@ +/* + * This file is copied from Uber Remote Shuffle Service + * (https://github.com/uber/RemoteShuffleService) and modified. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.starshuffle; + +import io.netty.buffer.ByteBuf; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; + +public class ByteBufUtils { + public static final void writeLengthAndString(ByteBuf buf, String str) { + if (str == null) { + buf.writeInt(-1); + return; + } + + byte[] bytes = str.getBytes(StandardCharsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static final String readLengthAndString(ByteBuf buf) { + int length = buf.readInt(); + if (length == -1) { + return null; + } + + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java new file mode 100644 index 0000000000000..56260336e1b44 --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import org.apache.spark.SparkEnv; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.netty.SparkTransportConf; +import org.apache.spark.network.shuffle.*; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.ShuffleBlockId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; + +/** + * This class fetches shuffle blocks from external storage like S3 + */ +public class StarBlockStoreClient extends BlockStoreClient { + + private static final Logger logger = LoggerFactory.getLogger(StarBlockStoreClient.class); + + // Fetch shuffle blocks from external shuffle storage. + // The shuffle location is encoded in the host argument. In the future, we should enhance + // Spark internal code to support abstraction of shuffle storage location. 
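+    // In this prototype the "host" value is not a real host name: it is the Base64-encoded
+    // StarMapResultFileInfo produced by StarShuffleUtils.createDummyBlockManagerId, carrying the
+    // external file location plus the per-partition lengths used below to compute each block's
+    // offset and size. Each requested block is fetched asynchronously via CompletableFuture.runAsync
+    // on the common ForkJoinPool.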
+ @Override + public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { + for (int i = 0; i < blockIds.length; i++) { + String blockId = blockIds[i]; + CompletableFuture.runAsync(() -> fetchBlock(host, execId, blockId, listener, downloadFileManager)); + } + } + + private void fetchBlock(String host, String execId, String blockIdStr, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { + BlockId blockId = BlockId.apply(blockIdStr); + if (blockId instanceof ShuffleBlockId) { + ShuffleBlockId shuffleBlockId = (ShuffleBlockId)blockId; + StarMapResultFileInfo mapResultFileInfo = StarMapResultFileInfo.deserializeFromString(host); + long offset = 0; + for (int i = 0; i < shuffleBlockId.reduceId(); i++) { + offset += mapResultFileInfo.getPartitionLengths()[i]; + } + long size = mapResultFileInfo.getPartitionLengths()[shuffleBlockId.reduceId()]; + StarShuffleFileManager streamProvider = StarUtils.createShuffleFileManager(SparkEnv.get().conf(), + mapResultFileInfo.getLocation()); + if (downloadFileManager != null) { + try (InputStream inputStream = streamProvider.read(mapResultFileInfo.getLocation(), offset, size)) { + TransportConf transportConf = SparkTransportConf.fromSparkConf( + SparkEnv.get().conf(), "starShuffle", 1, Option.empty()); + DownloadFile downloadFile = downloadFileManager.createTempFile(transportConf); + downloadFileManager.registerTempFileToClean(downloadFile); + DownloadFileWritableChannel downloadFileWritableChannel = downloadFile.openForWriting(); + + int bufferSize = 64 * 1024; + byte[] bytes = new byte[bufferSize]; + int readBytes = 0; + while (readBytes < size) { + int toReadBytes = Math.min((int)size - readBytes, bufferSize); + int n = inputStream.read(bytes, 0, toReadBytes); + if (n == -1) { + throw new RuntimeException(String.format( + "Failed to read file %s for shuffle block %s, hit end with remaining %s bytes", + mapResultFileInfo.getLocation(), + blockId, + size - readBytes)); + } + readBytes += n; + downloadFileWritableChannel.write(ByteBuffer.wrap(bytes, 0, n)); + } + ManagedBuffer managedBuffer = downloadFileWritableChannel.closeAndRead(); + listener.onBlockFetchSuccess(blockIdStr, managedBuffer); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to read file %s for shuffle block %s", + mapResultFileInfo.getLocation(), + blockId), + e); + } + } else { + try (InputStream inputStream = streamProvider.read(mapResultFileInfo.getLocation(), offset, size)) { + ByteBuffer byteBuffer = ByteBuffer.allocate((int)size); + int b = inputStream.read(); + while (b != -1) { + byteBuffer.put((byte)b); + if (byteBuffer.position() == size) { + break; + } + b = inputStream.read(); + } + byteBuffer.flip(); + NioManagedBuffer managedBuffer = new NioManagedBuffer(byteBuffer); + listener.onBlockFetchSuccess(blockIdStr, managedBuffer); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to read file %s for shuffle block %s", + mapResultFileInfo.getLocation(), + blockId), + e); + } + } + logger.info("Fetch blocks: {}, {}", host, execId); + } else { + throw new RuntimeException(String.format( + "%s does not support %s: %s", + this.getClass().getSimpleName(), + blockId.getClass().getSimpleName(), + blockId)); + } + } + + @Override + public void close() throws IOException { + logger.info("Close"); + } +} diff --git 
a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..16e73e8377107 --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.StarBypassMergeSortShuffleHandle; +import org.apache.spark.shuffle.StarOpts; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.storage.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.None$; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; + +/** + * This class is copied from BypassMergeSortShuffleWriter for a quick prototype. + * Will rework it later to avoid code copy. + */ +public class StarBypassMergeSortShuffleWriter extends ShuffleWriter { + + private static final Logger logger = LoggerFactory.getLogger(StarBypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetricsReporter writeMetrics; + private final int shuffleId; + private final long mapId; + private final Serializer serializer; + private final ShuffleExecutorComponents shuffleExecutorComponents; + + /** Array of file writers, one for each partition */ + private DiskBlockObjectWriter[] partitionWriters; + private FileSegment[] partitionWriterSegments; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + private String rootDir; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. 
+ */ + private boolean stopping = false; + + public StarBypassMergeSortShuffleWriter( + BlockManager blockManager, + StarBypassMergeSortShuffleHandle handle, + long mapId, + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics, + ShuffleExecutorComponents shuffleExecutorComponents) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.blockManager = blockManager; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = writeMetrics; + this.serializer = dep.serializer(); + this.shuffleExecutorComponents = shuffleExecutorComponents; + this.rootDir = conf.get(StarOpts.rootDir()); + } + + @Override + public void write(Iterator> records) throws IOException { + assert (partitionWriters == null); + StartFileSegmentWriter startFileSegmentWriter = new StartFileSegmentWriter(rootDir); + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents + .createMapOutputWriter(shuffleId, mapId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = StarUtils.getLengths(partitionWriterSegments); + mapStatus = startFileSegmentWriter.write(mapId, partitionWriterSegments); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } + } + + partitionLengths = StarUtils.getLengths(partitionWriterSegments); + mapStatus = startFileSegmentWriter.write(mapId, partitionWriterSegments); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); + } + throw e; + } + } + + @VisibleForTesting + public long[] getPartitionLengths() { + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; + } + } + return None$.empty(); + } + } + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java new file mode 100644 index 0000000000000..1a61c7aa89e66 --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import org.apache.commons.io.IOUtils; +import org.apache.spark.network.util.LimitedInputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.file.Paths; +import java.util.UUID; + +/** + * This class read/write shuffle file on external storage like local file or network file system. 
+ */ +public class StarLocalFileShuffleFileManager implements StarShuffleFileManager { + private static final Logger logger = LoggerFactory.getLogger( + StarLocalFileShuffleFileManager.class); + + @Override + public String createFile(String root) { + try { + if (root == null || root.isEmpty()) { + File file = File.createTempFile("shuffle", ".data"); + return file.getAbsolutePath(); + } else { + String fileName = String.format("shuffle-%s.data", UUID.randomUUID()); + return Paths.get(root, fileName).toString(); + } + } catch (IOException e) { + throw new RuntimeException("Failed to create shuffle file", e); + } + } + + @Override + public void write(InputStream data, long size, String file) { + logger.info("Writing to shuffle file: {}", file); + try (FileOutputStream outputStream = new FileOutputStream(file)) { + long copiedBytes = IOUtils.copyLarge(data, outputStream); + if (copiedBytes != size) { + throw new RuntimeException(String.format( + "Got corrupted shuffle data when writing to " + + "file %s, expected size: %s, actual written size: %s", + file, copiedBytes)); + } + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to write shuffle file %s", + file)); + } + } + + @Override + public InputStream read(String file, long offset, long size) { + logger.info("Opening shuffle file: {}, offset: {}, size: {}", file, offset, size); + try { + FileInputStream inputStream = new FileInputStream(file); + inputStream.skip(offset); + return new LimitedInputStream(inputStream, size); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to open shuffle file %s, offset: %s, size: %s", + file, offset, size)); + } + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java new file mode 100644 index 0000000000000..c132900c0831d --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import java.util.Arrays; +import java.util.Base64; + +/** + * This class stores map result file (shuffle file) information. 
+ */ +public class StarMapResultFileInfo { + private final String location; + private final long[] partitionLengths; + + public void serialize(ByteBuf buf) { + + ByteBufUtils.writeLengthAndString(buf, location); + buf.writeInt(partitionLengths.length); + for (int i = 0; i < partitionLengths.length; i++) { + buf.writeLong(partitionLengths[i]); + } + } + + public static StarMapResultFileInfo deserialize(ByteBuf buf) { + String location = ByteBufUtils.readLengthAndString(buf); + int size = buf.readInt(); + long[] partitionLengths = new long[size]; + for (int i = 0; i < size; i++) { + partitionLengths[i] = buf.readLong(); + } + return new StarMapResultFileInfo(location, partitionLengths); + } + + /*** + * This serialize method is faster than json serialization. + * @return + */ + public String serializeToString() { + ByteBuf buf = Unpooled.buffer(); + try { + serialize(buf); + byte[] bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + return Base64.getEncoder().encodeToString(bytes); + } finally { + buf.release(); + } + } + + public static StarMapResultFileInfo deserializeFromString(String str) { + byte[] bytes = Base64.getDecoder().decode(str); + ByteBuf buf = Unpooled.wrappedBuffer(bytes); + try { + return deserialize(buf); + } finally { + buf.release(); + } + } + + public StarMapResultFileInfo(String location, long[] partitionLengths) { + this.location = location; + this.partitionLengths = partitionLengths; + } + + public long[] getPartitionLengths() { + return partitionLengths; + } + + public String getLocation() { + return location; + } + + @Override + public String toString() { + return "MapResultFile{" + + "location='" + location + '\'' + + ", partitionLengths=" + Arrays.toString(partitionLengths) + + '}'; + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java new file mode 100644 index 0000000000000..6c7d9e83e7a1f --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.starshuffle; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.client.builder.ExecutorFactory; +import com.amazonaws.event.ProgressEvent; +import com.amazonaws.event.ProgressListener; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; +import com.amazonaws.services.s3.model.GetObjectRequest; +import com.amazonaws.services.s3.model.ObjectMetadata; +import com.amazonaws.services.s3.model.PutObjectRequest; +import com.amazonaws.services.s3.transfer.TransferManager; +import com.amazonaws.services.s3.transfer.TransferManagerBuilder; +import org.apache.hadoop.conf.Configuration; +import org.apache.spark.SparkConf; +import org.apache.spark.deploy.SparkHadoopUtil; +import org.apache.spark.network.util.LimitedInputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.net.URI; +import java.util.UUID; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicLong; + +/** + * This class read/write shuffle file on external storage like S3. + */ +public class StarS3ShuffleFileManager implements StarShuffleFileManager { + private static final Logger logger = LoggerFactory.getLogger(StarS3ShuffleFileManager.class); + + // TODO make following values configurable + public final static int S3_PUT_TIMEOUT_MILLISEC = 180 * 1000; + + // Following constants are copied from: + // https://github.com/apache/hadoop/blob/6c6d1b64d4a7cd5288fcded78043acaf23228f96/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/Constants.java + public static final long DEFAULT_MULTIPART_SIZE = 67108864; // 64M + public static final long DEFAULT_MIN_MULTIPART_THRESHOLD = 134217728; // 128M + public static final String MAX_THREADS = "fs.s3a.threads.max"; + public static final int DEFAULT_MAX_THREADS = 10; + public static final String KEEPALIVE_TIME = "fs.s3a.threads.keepalivetime"; + public static final int DEFAULT_KEEPALIVE_TIME = 60; + + public static final String AWS_REGION = "fs.s3a.endpoint.region"; + public static final String DEFAULT_AWS_REGION = Regions.US_WEST_2.getName(); + + private static TransferManager transferManager; + private static final Object transferManagerLock = new Object(); + + private final String awsRegion; + private final int maxThreads; + private final long keepAliveTime; + + public StarS3ShuffleFileManager(SparkConf conf) { + Configuration hadoopConf = SparkHadoopUtil.get().newConfiguration(conf); + + awsRegion = hadoopConf.get(AWS_REGION, DEFAULT_AWS_REGION); + + int threads = conf.getInt(MAX_THREADS, DEFAULT_MAX_THREADS); + if (threads < 2) { + logger.warn(MAX_THREADS + " must be at least 2: forcing to 2."); + threads = 2; + } + maxThreads = threads; + + keepAliveTime = conf.getLong(KEEPALIVE_TIME, DEFAULT_KEEPALIVE_TIME); + } + + @Override + public String createFile(String root) { + if (!root.endsWith("/")) { + root = root + "/"; + } + String fileName = String.format("shuffle-%s.data", UUID.randomUUID()); + return root + fileName; + } + + @Override + public void write(InputStream data, long size, String file) { + logger.info("Writing to shuffle file: {}", file); + writeS3(data, size, file); + } + + @Override + public InputStream read(String file, long offset, long size) { + logger.info("Opening shuffle file: {}, offset: {}, size: {}", file, offset, size); + return readS3(file, offset, size); + } + + private void writeS3(InputStream inputStream, long size, String s3Url) { + 
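+        // The upload goes through the shared TransferManager singleton. The object metadata sets the
+        // content length up front (the data arrives as a stream rather than a file), a progress
+        // listener logs transfer progress roughly every 10 seconds, and S3_PUT_TIMEOUT_MILLISEC is
+        // applied as both the SDK request timeout and the client execution timeout.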
logger.info("Uploading shuffle file to s3: {}, size: {}", s3Url, size); + + S3BucketAndKey bucketAndKey = S3BucketAndKey.getFromUrl(s3Url); + String bucket = bucketAndKey.getBucket(); + String key = bucketAndKey.getKey(); + + TransferManager transferManager = getTransferManager(); + + ObjectMetadata metadata = new ObjectMetadata(); + metadata.setContentType("application/octet-stream"); + metadata.setContentLength(size); + + PutObjectRequest request = new PutObjectRequest(bucket, + key, + inputStream, + metadata); + + AtomicLong totalTransferredBytes = new AtomicLong(0); + + request.setGeneralProgressListener(new ProgressListener() { + private long lastLogTime = 0; + + @Override + public void progressChanged(ProgressEvent progressEvent) { + long count = progressEvent.getBytesTransferred(); + long total = totalTransferredBytes.addAndGet(count); + long currentTime = System.currentTimeMillis(); + long logInterval = 10000; + if (currentTime - lastLogTime >= logInterval) { + logger.info("S3 upload progress: {}, recent transferred {} bytes, total transferred {}", key, count, total); + lastLogTime = currentTime; + } + } + }); + + // https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/best-practices.html + request.getRequestClientOptions().setReadLimit((int) DEFAULT_MULTIPART_SIZE + 1); + request.setSdkRequestTimeout(S3_PUT_TIMEOUT_MILLISEC); + request.setSdkClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC); + try { + long startTime = System.currentTimeMillis(); + transferManager.upload(request).waitForCompletion(); + long duration = System.currentTimeMillis() - startTime; + double mbs = 0; + if (duration != 0) { + mbs = ((double) size) / (1000 * 1000) / ((double) duration / 1000); + } + logger.info("S3 upload finished: {}, file size: {} bytes, total transferred: {}, throughput: {} mbs", + s3Url, size, totalTransferredBytes.get(), mbs); + } catch (InterruptedException e) { + throw new RuntimeException("Failed to upload to s3: " + key, e); + } + } + + private InputStream readS3(String s3Url, long offset, long size) { + logger.info("Downloading shuffle file from s3: {}, size: {}", s3Url, size); + + S3BucketAndKey bucketAndKey = S3BucketAndKey.getFromUrl(s3Url); + + File downloadTempFile; + try { + downloadTempFile = File.createTempFile("shuffle-download", ".data"); + } catch (IOException e) { + throw new RuntimeException("Failed to create temp file for downloading shuffle file"); + } + + TransferManager transferManager = getTransferManager(); + + GetObjectRequest getObjectRequest = new GetObjectRequest(bucketAndKey.getBucket(), bucketAndKey.getKey()) + .withRange(offset, offset + size); + + AtomicLong totalTransferredBytes = new AtomicLong(0); + + getObjectRequest.setGeneralProgressListener(new ProgressListener() { + private long lastLogTime = 0; + + @Override + public void progressChanged(ProgressEvent progressEvent) { + long count = progressEvent.getBytesTransferred(); + long total = totalTransferredBytes.addAndGet(count); + long currentTime = System.currentTimeMillis(); + long logInterval = 10000; + if (currentTime - lastLogTime >= logInterval) { + logger.info("S3 download progress: {}, recent transferred {} bytes, total transferred {}", s3Url, count, total); + lastLogTime = currentTime; + } + } + }); + + try { + long startTime = System.currentTimeMillis(); + transferManager.download(getObjectRequest, downloadTempFile).waitForCompletion(); + long duration = System.currentTimeMillis() - startTime; + double mbs = 0; + if (duration != 0) { + mbs = ((double) size) / (1000 * 1000) / ((double) 
duration / 1000); + } + logger.info("S3 download finished: {}, file size: {} bytes, total transferred: {}, throughput: {} mbs", + s3Url, size, totalTransferredBytes.get(), mbs); + } catch (InterruptedException e) { + throw new RuntimeException(String.format( + "Failed to download shuffle file %s", s3Url)); + } finally { + // TODO + transferManager.shutdownNow(); + } + + // TODO delete downloadTempFile + + try { + return new LimitedInputStream(new FileInputStream(downloadTempFile), size); + } catch (FileNotFoundException e) { + throw new RuntimeException(String.format( + "Failed to open downloaded shuffle file %s (from %s)", downloadTempFile, s3Url)); + } + } + + private TransferManager getTransferManager() { + synchronized (transferManagerLock) { + if (transferManager != null) { + return transferManager; + } + transferManager = createTransferManager(awsRegion, maxThreads, keepAliveTime); + return transferManager; + } + } + + private static TransferManager createTransferManager(String region, int maxThreads, long keepAliveTime) { + ClientConfiguration clientConfiguration = new ClientConfiguration(); + clientConfiguration.setConnectionTimeout(S3_PUT_TIMEOUT_MILLISEC); + clientConfiguration.setRequestTimeout(S3_PUT_TIMEOUT_MILLISEC); + clientConfiguration.setSocketTimeout(S3_PUT_TIMEOUT_MILLISEC); + clientConfiguration.setClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC); + + ThreadFactory threadFactory = new ThreadFactory() { + private int threadCount = 1; + public Thread newThread(Runnable r) { + Thread thread = new Thread(r); + thread.setName("s3-shuffle-transfer-manager-worker-" + this.threadCount++); + return thread; + } + }; + ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor( + maxThreads, Integer.MAX_VALUE, + keepAliveTime, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), + threadFactory); + ExecutorFactory executorFactory = new ExecutorFactory() { + @Override + public ExecutorService newExecutor() { + return threadPoolExecutor; + } + }; + + AmazonS3 s3Client = AmazonS3ClientBuilder.standard() + .withRegion(region) + .withClientConfiguration(clientConfiguration) + .build(); + + return TransferManagerBuilder.standard() + .withS3Client(s3Client) + .withMinimumUploadPartSize(DEFAULT_MULTIPART_SIZE) + .withMultipartUploadThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD) + .withMultipartCopyPartSize(DEFAULT_MULTIPART_SIZE) + .withMultipartCopyThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD) + .withExecutorFactory(executorFactory) + .build(); + } + + public static void shutdownTransferManager() { + synchronized (transferManagerLock) { + if (transferManager == null) { + return; + } + transferManager.shutdownNow(true); + transferManager = null; + } + } + + public static class S3BucketAndKey { + private String bucket; + private String key; + + public static S3BucketAndKey getFromUrl(String s3Url) { + URI url = URI.create(s3Url); + String bucket = url.getHost(); + String key = url.getPath(); + if (key.startsWith("/")) { + key = key.substring(1); + } + if (key.isEmpty()) { + throw new RuntimeException(String.format( + "Could not get object key in s3 url: %s", s3Url)); + } + return new S3BucketAndKey(bucket, key); + } + + public S3BucketAndKey(String bucket, String key) { + this.bucket = bucket; + this.key = key; + } + + public String getBucket() { + return bucket; + } + + public String getKey() { + return key; + } + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java 
b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java new file mode 100644 index 0000000000000..65c7f849635bc --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import java.io.InputStream; + +/** + * This is interface to read/write shuffle file on external storage. + */ +public interface StarShuffleFileManager { + String createFile(String root); + + void write(InputStream data, long size, String file); + + InputStream read(String file, long offset, long size); +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java new file mode 100644 index 0000000000000..86ad58b98af6e --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.starshuffle; + +import org.apache.spark.SparkConf; +import org.apache.spark.storage.FileSegment; + +public class StarUtils { + + public static long[] getLengths(FileSegment[] fileSegments) { + long[] lengths = new long[fileSegments.length]; + for (int i = 0; i < lengths.length; i++) { + lengths[i] = fileSegments[i].length(); + } + return lengths; + } + + public static StarShuffleFileManager createShuffleFileManager(SparkConf conf, String path) { + if (path == null || path.isEmpty() || path.startsWith("/")) { + return new StarLocalFileShuffleFileManager(); + } else if (path.toLowerCase().startsWith("s3")) { + return new StarS3ShuffleFileManager(conf); + } else { + throw new RuntimeException(String.format( + "Unsupported path for StarShuffleFileManager: %s", path)); + } + } +} diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java new file mode 100644 index 0000000000000..d10f7364bdd9d --- /dev/null +++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.starshuffle; + +import java.io.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.spark.SparkEnv; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.shuffle.StarShuffleUtils; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.storage.FileSegment; + +/** + * This class writes shuffle segments into the result file. 
+ */ +public class StartFileSegmentWriter { + private String rootDir; + + public StartFileSegmentWriter(String rootDir) { + this.rootDir = rootDir; + } + + public MapStatus write(long mapId, FileSegment[] fileSegments) { + List allStreams = Arrays.stream(fileSegments).map(t -> { + try { + FileInputStream fileStream = new FileInputStream(t.file()); + fileStream.skip(t.offset()); + return new LimitedInputStream(fileStream, t.length()); + } catch (IOException e) { + throw new RuntimeException(String.format( + "Failed to read shuffle temp file %s", t.file()), + e); + } + }).collect(Collectors.toList()); + try (SequenceInputStream sequenceInputStream = new SequenceInputStream( + Collections.enumeration(allStreams))) { + StarShuffleFileManager shuffleFileManager = StarUtils.createShuffleFileManager( + SparkEnv.get().conf(), rootDir); + String resultFile = shuffleFileManager.createFile(rootDir); + long size = Arrays.stream(fileSegments).mapToLong(t->t.length()).sum(); + shuffleFileManager.write(sequenceInputStream, size, resultFile); + long[] partitionLengths = new long[fileSegments.length]; + for (int i = 0; i < partitionLengths.length; i++) { + partitionLengths[i] = fileSegments[i].length(); + } + BlockManagerId blockManagerId = createMapTaskDummyBlockManagerId( + partitionLengths, resultFile); + MapStatus mapStatus = MapStatus$.MODULE$.apply( + blockManagerId, partitionLengths, mapId); + return mapStatus; + } catch (IOException e) { + throw new RuntimeException("Failed to close shuffle temp files", e); + } + } + + private BlockManagerId createMapTaskDummyBlockManagerId(long[] partitionLengths, String file) { + return StarShuffleUtils.createDummyBlockManagerId(file, partitionLengths); + } +} diff --git a/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties b/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties new file mode 100644 index 0000000000000..db5d9e512204e --- /dev/null +++ b/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala new file mode 100644 index 0000000000000..f2bcf56429c85 --- /dev/null +++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry} + +object StarOpts { + val rootDir: ConfigEntry[String] = + ConfigBuilder("spark.shuffle.star.rootDir") + .doc("Root directory for star shuffle files") + .stringConf + .createWithDefault("") +} diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala new file mode 100644 index 0000000000000..cd68806b6b478 --- /dev/null +++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.starshuffle.{StarBlockStoreClient, StarBypassMergeSortShuffleWriter, StarS3ShuffleFileManager} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, StarShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.{ExternalSorter, OpenHashSet} + +/** + * This class is a Shuffle Manager implementation to store shuffle data on external storage + * like S3. + * Most code is copied from SortShuffleManager for quick prototype here. We will collect feedback + * from the community and work further to remove the code duplication. + */ +class StarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import StarShuffleManager._ + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + /** + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. + */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** + * Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new StarBypassMergeSortShuffleHandle[K, V]( + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + new StarBlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent( + handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) } + val env = SparkEnv.get + handle match { + case bypassMergeSortHandle: StarBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new StarBypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. 
*/ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapTaskId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId) + } + } + true + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + + // TODO use a better way to shutdown TransferManager in StarS3ShuffleFileManager + StarS3ShuffleFileManager.shutdownTransferManager() + } +} + + +private[spark] object StarShuffleManager extends Logging { + + /** + * The local property key for continuous shuffle block fetching feature. + */ + val FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY = + "__fetch_continuous_blocks_in_batch_enabled" + + /** + * Helper method for determining whether a shuffle reader should fetch the continuous blocks + * in batch. + */ + def canUseBatchFetch(startPartition: Int, endPartition: Int, context: TaskContext): Boolean = { + val fetchMultiPartitions = endPartition - startPartition > 1 + fetchMultiPartitions && + context.getLocalProperty(FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY) == "true" + } + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX) + .toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. 
+ */ +private[spark] class StarBypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) { +} + +private[spark] class StarBlockStoreShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + // (BlockId, Long, Int): similar like FetchBlockInfo (ShuffleBlockId, EstimatedSize, MapIndex) + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val starBlockStoreClient = new StarBlockStoreClient() + val wrappedStreams = new StarShuffleBlockFetcherIterator( + context, + starBlockStoreClient, + blockManager, + mapOutputTracker, + blocksByAddress, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + fetchContinuousBlocksInBatch).toCompletionIterator + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. 
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } +} diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala new file mode 100644 index 0000000000000..edbbf5149afd3 --- /dev/null +++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.util.UUID + +import org.apache.spark.internal.Logging +import org.apache.spark.starshuffle.StarMapResultFileInfo +import org.apache.spark.storage.BlockManagerId + +object StarShuffleUtils extends Logging { + + def createDummyBlockManagerId(fileLocation: String, + partitionLengths: Array[Long]): BlockManagerId = { + val fileInfo = new StarMapResultFileInfo(fileLocation, partitionLengths) + val dummyHost = fileInfo.serializeToString() + val dummyPort = 9 + val dummyExecId = "starshuffle-" + UUID.randomUUID().toString + BlockManagerId(dummyExecId, dummyHost, dummyPort, None) + } + +} diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala new file mode 100644 index 0000000000000..00d5f11b7431e --- /dev/null +++ b/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala @@ -0,0 +1,1428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{IOException, InputStream} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.zip.CheckedInputStream +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.roaringbitmap.RoaringBitmap + +/** + * This class fetches shuffle data from external storage like S3. + * Most code is copied from ShuffleBlockFetcherIterator for quick prototype. Will work + * further to remove code duplication after getting community feedback. 
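+ * The external location of a map output is typically advertised to reducers through a dummy
+ * BlockManagerId whose host field carries the serialized StarMapResultFileInfo, created via
+ * StarShuffleUtils.createDummyBlockManagerId, e.g. (illustrative path and sizes only):
+ *
+ *   StarShuffleUtils.createDummyBlockManagerId(
+ *     "s3://bucket/shuffle-root/shuffle_0_map_0", Array(1024L, 2048L))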
+ */ +private[spark] +final class StarShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + // (BlockId, Long, Int) in blocksByAddress is similar like: + // FetchBlockInfo (ShuffleBlockId, EstimatedSize, MapIndex) + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) + extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { + + import ShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** + * Total number of blocks to fetch. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + @volatile private[this] var currentResult: SuccessFetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if + * retry times has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry + * at most once for those corrupted blocks. 
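+   * (A block found corrupt a second time results in a FetchFailedException so that the
+   * previous map stage can be retried.)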
+ */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new StarShuffleFetchCompletionListener(this) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + // Release the current buffer if necessary + if (currentResult != null) { + currentResult.buf.release() + } + currentResult = null + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile( + blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + releaseCurrentResultBuffer() + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + 
results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) + deferredBlocks.clear() + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + StarShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2, + address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + StarShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo(s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + results.put(FallbackOnPushMergedFailureResult( + block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. 
Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, this) + } else { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if (blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert(blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of 
the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks ${numHostLocalBlocks} " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks ${numRemoteBlocks} ") + logInfo(s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap { infos => infos.map(info => (info._1, info._3)) } + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if (curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. 
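+          // (When doBatchFetch is enabled, only the accumulated request size triggers a split
+          // here; the per-address block limit is enforced after merging in createFetchRequests.)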
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks( + localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, + buf.size(), buf, false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put(SuccessFetchResult(blockId, mapIndex, blockManagerId, buf.size(), buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. 
+ */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]): + Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { case (bmId, blockInfos) => + blockInfos.forall { case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. 
+ val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert ((0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]): + Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. 
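+    // Deferred requests, push-merged results and blocks found corrupt for the first time reset
+    // `result` to null, so the loop below keeps taking entries off the queue until a usable
+    // result is obtained.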
+ while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + // It is a host local block or a local shuffle chunk + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + + val in = try { + var bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError("Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => logError("Failed to create input stream from local block", e) + } + buf.release() + throwFetchFailedException(blockId, mapIndex, address, e) + } + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. 
+ input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, mapIndex, address, e, Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) = + numBlocksInFlightPerAddress(address) - request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(_, _, _, _) => + result = null + case PushMergedLocalMetaFetchResult(_, _, _, _, _) => + result = null + case PushMergedRemoteMetaFailedFetchResult(_, _, _, _) => + result = null + case PushMergedRemoteMetaFetchResult(_, _, _, _, _, _) => + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + currentResult = result.asInstanceOf[SuccessFetchResult] + (currentResult.blockId, + new StarBufferReleasingInputStream( + input, + this, + currentResult.blockId, + currentResult.mapIndex, + currentResult.address, + detectCorrupt && streamCompressedOrEncrypted, + currentResult.isNetworkReqDone, + Option(checkedIn))) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked + * when checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the + * checksum of the block. Then, it will raise a synchronized RPC call along with the + * checksum to ask the server(where the corrupted block is fetched from) to diagnose the + * cause of corruption and return it. 
+ * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address the address where the corrupted block is fetched from. + * @param blockId the blockId of the corrupted block. + * @return The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + val startTimeNs = System.nanoTime() + assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + } + + def toCompletionIterator: Iterator[(BlockId, InputStream)] = { + CompletionIterator[(BlockId, InputStream), this.type](this, + onCompleteCallback.onComplete(context)) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while (isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. 
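+    // Requests are sent only while both the bytes-in-flight and requests-in-flight limits allow
+    // it; a request whose address already has the maximum number of blocks in flight is deferred
+    // instead of being sent.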
+ while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + sendRequest(request) + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError(address, shuffleId, mapId, mapIndex, startReduceId, + msg, e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** + * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator + */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. + * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates + * fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr, + originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. 
+ fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert(originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. + * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates + * fallback. + * + * @return set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and + * also detects stream corruption if streamCompressedOrEncrypted is true + */ +private class StarBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: StarShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. 
+ if (isNetworkReqDone) { + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = delegate.reset() + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only + * used by the `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { checkedIn => + iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data the ShuffleBlockFetcherIterator to process + */ +private class StarShuffleFetchCompletionListener(var data: StarShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] +object StarShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. + * If true, unless there's no in-flight fetch requests, all the pending shuffle + * fetch requests will be deferred until the flag is unset (whenever there's a + * complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the + * same `mapId` can be merged into one block batch. The block batch is specified by a range + * of reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. + * For example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be + * merged into (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, + * shuffle_0_0_2, shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks blocks to be merged if possible. May contains already merged blocks. 
+ * @param doBatchFetch whether to merge blocks. + * @return the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: Seq[FetchBlockInfo], + doBatchFetch: Boolean): Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + (removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result.toSeq + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId block id + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. + * @param mapIndex the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo( + blockId: BlockId, + size: Long, + mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas true if this request is for requesting push-merged meta information; + * false if it is for regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** + * Result of a fetch from a remote block. 
+   */
+  private[storage] sealed trait FetchResult
+
+  /**
+   * Result of a successful fetch from a remote block.
+   * @param blockId block id
+   * @param mapIndex the mapIndex for this block, which indicates the index in the map stage.
+   * @param address BlockManager that the block was fetched from.
+   * @param size estimated size of the block. Note that this is NOT the exact bytes.
+   *             Size of remote block is used to calculate bytesInFlight.
+   * @param buf `ManagedBuffer` for the content.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
+   */
+  private[storage] case class SuccessFetchResult(
+      blockId: BlockId,
+      mapIndex: Int,
+      address: BlockManagerId,
+      size: Long,
+      buf: ManagedBuffer,
+      isNetworkReqDone: Boolean) extends FetchResult {
+    require(buf != null)
+    require(size >= 0)
+  }
+
+  /**
+   * Result of an unsuccessful fetch from a remote block.
+   * @param blockId block id
+   * @param mapIndex the mapIndex for this block, which indicates the index in the map stage
+   * @param address BlockManager that the block was attempted to be fetched from
+   * @param e the failure exception
+   */
+  private[storage] case class FailureFetchResult(
+      blockId: BlockId,
+      mapIndex: Int,
+      address: BlockManagerId,
+      e: Throwable)
+    extends FetchResult
+
+  /**
+   * Result of a fetch request that should be deferred for some reason, e.g., Netty OOM
+   */
+  private[storage]
+  case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of either of these:
+   * 1) Remote shuffle chunk.
+   * 2) Local push-merged block.
+   *
+   * Instead of treating this as a [[FailureFetchResult]], we fall back to fetching the
+   * original blocks.
+   *
+   * @param blockId block id
+   * @param address BlockManager that the push-merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
+   * @param reduceId reduce id.
+   * @param blockSize size of each push-merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFetchResult(
+      shuffleId: Int,
+      shuffleMergeId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+      shuffleId: Int,
+      shuffleMergeId: Int,
+      reduceId: Int,
+      address: BlockManagerId) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a push-merged-local block.
+   *
+   * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffleMergeId is used to uniquely identify merging process
+   *                       of shuffle by an indeterminate stage attempt.
+   * @param reduceId reduce id.
+   * @param bitmaps bitmaps for every chunk.
+   * @param localDirs local directories where the push-merged shuffle files are stored
+   */
+  private[storage] case class PushMergedLocalMetaFetchResult(
+      shuffleId: Int,
+      shuffleMergeId: Int,
+      reduceId: Int,
+      bitmaps: Array[RoaringBitmap],
+      localDirs: Array[String]) extends FetchResult
+}
diff --git a/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java b/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java
new file mode 100644
index 0000000000000..2e7f6c3269086
--- /dev/null
+++ b/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ByteBufUtilsTest {
+
+  @Test
+  public void writeLengthAndString() {
+    ByteBuf buf = PooledByteBufAllocator.DEFAULT.buffer(1);
+    ByteBufUtils.writeLengthAndString(buf, "hello world");
+
+    String str = ByteBufUtils.readLengthAndString(buf);
+    // JUnit's assertEquals takes (expected, actual)
+    Assert.assertEquals("hello world", str);
+
+    buf.release();
+  }
+}
diff --git a/external-shuffle-storage/src/test/resources/log4j.properties b/external-shuffle-storage/src/test/resources/log4j.properties
new file mode 100644
index 0000000000000..d1cfaf521ab08
--- /dev/null
+++ b/external-shuffle-storage/src/test/resources/log4j.properties
@@ -0,0 +1,33 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# Set everything to be logged to the file target/unit-tests.log +test.appender=file +log4j.rootCategory=INFO, ${test.appender} +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +# Tests that launch java subprocesses can set the "test.appender" system property to +# "console" to avoid having the child process's logs overwrite the unit test's +# log file. +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%t: %m%n +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.sparkproject.jetty=WARN diff --git a/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala b/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala new file mode 100644 index 0000000000000..e7e23eb0a4421 --- /dev/null +++ b/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.shuffle
+
+import java.util.UUID
+
+import org.junit.Test
+import org.scalatest.Assertions._
+
+import org.apache.spark._
+
+class StarShuffleManagerTest {
+
+  @Test
+  def foldByKey(): Unit = {
+    val conf = newSparkConf()
+    runWithSparkConf(conf)
+  }
+
+  @Test
+  def foldByKey_zeroBuffering(): Unit = {
+    val conf = newSparkConf()
+    conf.set("spark.reducer.maxSizeInFlight", "0")
+    conf.set("spark.network.maxRemoteBlockSizeFetchToMem", "0")
+    runWithSparkConf(conf)
+  }
+
+  // Runs a small foldByKey job through StarShuffleManager and verifies the aggregated
+  // values and the key range in the shuffled result.
+  private def runWithSparkConf(conf: SparkConf): Unit = {
+    val sc = new SparkContext(conf)
+
+    try {
+      val numValues = 10000
+      val numMaps = 3
+      val numPartitions = 5
+
+      val rdd = sc.parallelize(0 until numValues, numMaps)
+        .map(t => ((t / 2) -> (t * 2).longValue()))
+        .foldByKey(0, numPartitions)((v1, v2) => v1 + v2)
+      val result = rdd.collect()
+
+      assert(result.size === numValues / 2)
+
+      for (i <- 0 until result.size) {
+        val key = result(i)._1
+        val value = result(i)._2
+        assert(key * 2 * 2 + (key * 2 + 1) * 2 === value)
+      }
+
+      val keys = result.map(_._1).distinct.sorted
+      assert(keys.length === numValues / 2)
+      assert(keys(0) === 0)
+      assert(keys.last === (numValues - 1) / 2)
+    } finally {
+      sc.stop()
+    }
+  }
+
+  def newSparkConf(): SparkConf = new SparkConf()
+    .setAppName("testApp")
+    .setMaster("local[2]")
+    .set("spark.ui.enabled", "false")
+    .set("spark.driver.allowMultipleContexts", "true")
+    .set("spark.app.id", "app-" + UUID.randomUUID())
+    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    .set("spark.shuffle.manager", "org.apache.spark.shuffle.StarShuffleManager")
+}
diff --git a/pom.xml b/pom.xml
index aefb5377e6b7b..251ff62ac7b87 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3397,6 +3397,13 @@
+    <profile>
+      <id>external-shuffle-storage</id>
+      <modules>
+        <module>external-shuffle-storage</module>
+      </modules>
+    </profile>
+
     <profile>
       <id>test-java-home</id>
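
The scaladoc on `mergeContinuousShuffleBlockIdsIfNeeded` above describes how continuous shuffle blocks from the same map output are collapsed into a single batch id. The snippet below is a minimal, self-contained sketch of that merging idea and is not part of the patch: `Block` and `Batch` are simplified stand-ins for Spark's `ShuffleBlockId` and `ShuffleBlockBatchId`, and the sketch ignores block sizes and already-batched input, which the real code also handles.

```
// Illustrative sketch only, not the patch's implementation.
object BatchMergeSketch {
  // One shuffle block, identified by (shuffleId, mapId, reduceId).
  final case class Block(shuffleId: Int, mapId: Long, reduceId: Int)
  // A batch of continuous blocks from one map output, covering reduce ids
  // [startReduceId, endReduceId).
  final case class Batch(shuffleId: Int, mapId: Long, startReduceId: Int, endReduceId: Int)

  // Walk the blocks in order and extend the current batch while the map id stays
  // the same and the reduce ids stay continuous; otherwise start a new batch.
  def merge(blocks: Seq[Block]): Seq[Batch] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[Batch]
    blocks.foreach { b =>
      out.lastOption match {
        case Some(last) if last.shuffleId == b.shuffleId && last.mapId == b.mapId &&
            last.endReduceId == b.reduceId =>
          out(out.length - 1) = last.copy(endReduceId = b.reduceId + 1)
        case _ =>
          out += Batch(b.shuffleId, b.mapId, b.reduceId, b.reduceId + 1)
      }
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the scaladoc example: shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0
    // merge into shuffle_0_0_0_2 and shuffle_0_1_0_1.
    val merged = merge(Seq(Block(0, 0L, 0), Block(0, 0L, 1), Block(0, 1L, 0)))
    assert(merged == Seq(Batch(0, 0L, 0, 2), Batch(0, 1L, 0, 1)))
    println(merged)
  }
}
```

Running the sketch prints the two batches above, matching the worked example in the scaladoc; the real implementation additionally sums block sizes and can re-open the previously emitted batch when a later run of blocks shares its map id.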