diff --git a/assembly/pom.xml b/assembly/pom.xml
index 74c2f44121fca..08e453077e87b 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -268,5 +268,17 @@
+
+
+    <profile>
+      <id>external-shuffle-storage</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>external-shuffle-storage_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/external-shuffle-storage/README.md b/external-shuffle-storage/README.md
new file mode 100644
index 0000000000000..92b578b66e3f4
--- /dev/null
+++ b/external-shuffle-storage/README.md
@@ -0,0 +1,55 @@
+# External Shuffle Storage
+
+This module provides support for storing shuffle files on external shuffle storage such as S3. It helps Dynamic
+Allocation on Kubernetes: the Spark driver can release idle executors without worrying about losing
+shuffle data, because the shuffle data is stored on external shuffle storage rather than on the
+executors.
+
+This module implements a new Shuffle Manager named StarShuffleManager, and copies a lot of code
+from Spark's SortShuffleManager. This is a quick prototype. We want to use it as an example to discuss
+with the Spark community and get feedback. We will work with the community to remove the code duplication
+later and make StarShuffleManager more integrated with Spark code.
+
+## How to Build a Spark Distribution with the StarShuffleManager JAR File
+
+Follow the [Building Spark](https://spark.apache.org/docs/latest/building-spark.html) instructions,
+adding the extra `-Pexternal-shuffle-storage` profile to generate the new shuffle implementation JAR file.
+
+The following is an example command using `dev/make-distribution.sh` from the Spark repository root directory:
+
+```
+./dev/make-distribution.sh --name spark-with-external-shuffle-storage --pip --tgz -Phive -Phive-thriftserver -Pkubernetes -Phadoop-3.2 -Phadoop-cloud -Dhadoop.version=3.2.0 -Pexternal-shuffle-storage
+```
+
+If you want to build a Spark Docker image, you can extract the Spark distribution tgz file and run a command like the following:
+
+```
+./bin/docker-image-tool.sh -t spark-with-external-shuffle-storage build
+```
+
+The distribution build creates an `external-shuffle-storage_xxx.jar` file for StarShuffleManager
+under the `jars` directory of the generated Spark distribution. You can now use this Spark
+distribution to run your Spark application with external shuffle storage.
+
+## How to Run a Spark Application with External Shuffle Storage in Kubernetes
+
+### Run a Spark Application with S3 as External Shuffle Storage and Dynamic Allocation
+
+Add configuration like the following to your Spark application (you need to adjust the values based on your environment):
+
+```
+spark.shuffle.manager=org.apache.spark.shuffle.StarShuffleManager
+spark.shuffle.star.rootDir=s3://my_bucket_name/my_shuffle_folder
+spark.dynamicAllocation.enabled=true
+spark.dynamicAllocation.shuffleTracking.enabled=true
+spark.dynamicAllocation.shuffleTracking.timeout=1
+```
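+
+For example, the same settings can be passed to `spark-submit` with `--conf` flags. The sketch below
+assumes placeholder values for the Kubernetes master, container image, main class and application jar:
+
+```
+./bin/spark-submit \
+  --master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
+  --deploy-mode cluster \
+  --conf spark.kubernetes.container.image=<your-spark-image> \
+  --conf spark.shuffle.manager=org.apache.spark.shuffle.StarShuffleManager \
+  --conf spark.shuffle.star.rootDir=s3://my_bucket_name/my_shuffle_folder \
+  --conf spark.dynamicAllocation.enabled=true \
+  --conf spark.dynamicAllocation.shuffleTracking.enabled=true \
+  --conf spark.dynamicAllocation.shuffleTracking.timeout=1 \
+  --class <your-main-class> \
+  local:///path/to/your-application.jar
+```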
+
+### How to Specify the AWS Region for the S3 Files
+
+Add a Spark config like the following:
+
+```
+spark.hadoop.fs.s3a.endpoint.region=us-west-2
+```
+
diff --git a/external-shuffle-storage/pom.xml b/external-shuffle-storage/pom.xml
new file mode 100644
index 0000000000000..1699240b78fe6
--- /dev/null
+++ b/external-shuffle-storage/pom.xml
@@ -0,0 +1,143 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.12
+ 3.3.0-SNAPSHOT
+ ../pom.xml
+
+
+ external-shuffle-storage_2.12
+ jar
+ External Shuffle Storage
+ http://spark.apache.org/
+
+
+ external-shuffle-storage
+ none
+ package
+ provided
+ provided
+ provided
+
+
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${project.version}
+ provided
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${project.version}
+ test-jar
+ test
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${project.version}
+ provided
+
+
+ org.apache.commons
+ commons-math3
+ provided
+
+
+ org.scalacheck
+ scalacheck_${scala.binary.version}
+ test
+
+
+ org.scala-lang
+ scala-library
+ provided
+
+
+ io.netty
+ netty-all
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ org.apache.httpcomponents
+ httpclient
+
+
+ commons-io
+ commons-io
+
+
+ org.apache.commons
+ commons-lang3
+
+
+ com.amazonaws
+ aws-java-sdk-s3
+ 1.11.975
+ provided
+
+
+ org.testng
+ testng
+ 6.14.3
+ test
+
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-deploy-plugin
+
+ true
+
+
+
+ org.apache.maven.plugins
+ maven-install-plugin
+
+ true
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+ ${jars.target.dir}
+
+
+
+
+
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java
new file mode 100644
index 0000000000000..7007ce8992560
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/ByteBufUtils.java
@@ -0,0 +1,46 @@
+/*
+ * This file is copied from Uber Remote Shuffle Service
+ * (https://github.com/uber/RemoteShuffleService) and modified.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import io.netty.buffer.ByteBuf;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+
+public class ByteBufUtils {
+ public static final void writeLengthAndString(ByteBuf buf, String str) {
+ if (str == null) {
+ buf.writeInt(-1);
+ return;
+ }
+
+ byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static final String readLengthAndString(ByteBuf buf) {
+ int length = buf.readInt();
+ if (length == -1) {
+ return null;
+ }
+
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, StandardCharsets.UTF_8);
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java
new file mode 100644
index 0000000000000..56260336e1b44
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBlockStoreClient.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.shuffle.*;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.ShuffleBlockId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Option;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * This class fetches shuffle blocks from external storage like S3
+ */
+public class StarBlockStoreClient extends BlockStoreClient {
+
+ private static final Logger logger = LoggerFactory.getLogger(StarBlockStoreClient.class);
+
+ // Fetch shuffle blocks from external shuffle storage.
+ // The shuffle file location is encoded in the host argument. In the future, we should enhance
+ // Spark's internal code to support an abstraction of the shuffle storage location.
+ @Override
+ public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, DownloadFileManager downloadFileManager) {
+ for (int i = 0; i < blockIds.length; i++) {
+ String blockId = blockIds[i];
+ CompletableFuture.runAsync(() -> {
+ try {
+ fetchBlock(host, execId, blockId, listener, downloadFileManager);
+ } catch (Throwable t) {
+ // Surface fetch failures to the listener instead of losing them inside the async task.
+ listener.onBlockFetchFailure(blockId, t);
+ }
+ });
+ }
+ }
+
+ private void fetchBlock(String host, String execId, String blockIdStr, BlockFetchingListener listener, DownloadFileManager downloadFileManager) {
+ BlockId blockId = BlockId.apply(blockIdStr);
+ if (blockId instanceof ShuffleBlockId) {
+ ShuffleBlockId shuffleBlockId = (ShuffleBlockId)blockId;
+ StarMapResultFileInfo mapResultFileInfo = StarMapResultFileInfo.deserializeFromString(host);
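+ // Partition data is laid out contiguously in the map output file, so the byte offset of
+ // this reduce partition is the sum of the lengths of all preceding partitions.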
+ long offset = 0;
+ for (int i = 0; i < shuffleBlockId.reduceId(); i++) {
+ offset += mapResultFileInfo.getPartitionLengths()[i];
+ }
+ long size = mapResultFileInfo.getPartitionLengths()[shuffleBlockId.reduceId()];
+ StarShuffleFileManager streamProvider = StarUtils.createShuffleFileManager(SparkEnv.get().conf(),
+ mapResultFileInfo.getLocation());
+ if (downloadFileManager != null) {
+ try (InputStream inputStream = streamProvider.read(mapResultFileInfo.getLocation(), offset, size)) {
+ TransportConf transportConf = SparkTransportConf.fromSparkConf(
+ SparkEnv.get().conf(), "starShuffle", 1, Option.empty());
+ DownloadFile downloadFile = downloadFileManager.createTempFile(transportConf);
+ downloadFileManager.registerTempFileToClean(downloadFile);
+ DownloadFileWritableChannel downloadFileWritableChannel = downloadFile.openForWriting();
+
+ int bufferSize = 64 * 1024;
+ byte[] bytes = new byte[bufferSize];
+ int readBytes = 0;
+ while (readBytes < size) {
+ int toReadBytes = Math.min((int)size - readBytes, bufferSize);
+ int n = inputStream.read(bytes, 0, toReadBytes);
+ if (n == -1) {
+ throw new RuntimeException(String.format(
+ "Failed to read file %s for shuffle block %s, hit end with remaining %s bytes",
+ mapResultFileInfo.getLocation(),
+ blockId,
+ size - readBytes));
+ }
+ readBytes += n;
+ downloadFileWritableChannel.write(ByteBuffer.wrap(bytes, 0, n));
+ }
+ ManagedBuffer managedBuffer = downloadFileWritableChannel.closeAndRead();
+ listener.onBlockFetchSuccess(blockIdStr, managedBuffer);
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "Failed to read file %s for shuffle block %s",
+ mapResultFileInfo.getLocation(),
+ blockId),
+ e);
+ }
+ } else {
+ try (InputStream inputStream = streamProvider.read(mapResultFileInfo.getLocation(), offset, size)) {
+ ByteBuffer byteBuffer = ByteBuffer.allocate((int)size);
+ int b = inputStream.read();
+ while (b != -1) {
+ byteBuffer.put((byte)b);
+ if (byteBuffer.position() == size) {
+ break;
+ }
+ b = inputStream.read();
+ }
+ byteBuffer.flip();
+ NioManagedBuffer managedBuffer = new NioManagedBuffer(byteBuffer);
+ listener.onBlockFetchSuccess(blockIdStr, managedBuffer);
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "Failed to read file %s for shuffle block %s",
+ mapResultFileInfo.getLocation(),
+ blockId),
+ e);
+ }
+ }
+ logger.info("Fetch blocks: {}, {}", host, execId);
+ } else {
+ throw new RuntimeException(String.format(
+ "%s does not support %s: %s",
+ this.getClass().getSimpleName(),
+ blockId.getClass().getSimpleName(),
+ blockId));
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ logger.info("Close");
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000000..16e73e8377107
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarBypassMergeSortShuffleWriter.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.StarBypassMergeSortShuffleHandle;
+import org.apache.spark.shuffle.StarOpts;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.storage.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.None$;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * This class is copied from BypassMergeSortShuffleWriter for a quick prototype.
+ * We will rework it later to avoid the code duplication.
+ */
+public class StarBypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private static final Logger logger = LoggerFactory.getLogger(StarBypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetricsReporter writeMetrics;
+ private final int shuffleId;
+ private final long mapId;
+ private final Serializer serializer;
+ private final ShuffleExecutorComponents shuffleExecutorComponents;
+
+ /** Array of file writers, one for each partition */
+ private DiskBlockObjectWriter[] partitionWriters;
+ private FileSegment[] partitionWriterSegments;
+ @Nullable private MapStatus mapStatus;
+ private long[] partitionLengths;
+
+ private String rootDir;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
+
+ public StarBypassMergeSortShuffleWriter(
+ BlockManager blockManager,
+ StarBypassMergeSortShuffleHandle<K, V> handle,
+ long mapId,
+ SparkConf conf,
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleExecutorComponents shuffleExecutorComponents) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.blockManager = blockManager;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.mapId = mapId;
+ this.shuffleId = dep.shuffleId();
+ this.partitioner = dep.partitioner();
+ this.numPartitions = partitioner.numPartitions();
+ this.writeMetrics = writeMetrics;
+ this.serializer = dep.serializer();
+ this.shuffleExecutorComponents = shuffleExecutorComponents;
+ this.rootDir = conf.get(StarOpts.rootDir());
+ }
+
+ @Override
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ StartFileSegmentWriter startFileSegmentWriter = new StartFileSegmentWriter(rootDir);
+ ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents
+ .createMapOutputWriter(shuffleId, mapId, numPartitions);
+ try {
+ if (!records.hasNext()) {
+ partitionLengths = StarUtils.getLengths(partitionWriterSegments);
+ mapStatus = startFileSegmentWriter.write(mapId, partitionWriterSegments);
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
+ partitionWriterSegments = new FileSegment[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (int i = 0; i < numPartitions; i++) {
+ try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+ partitionWriterSegments[i] = writer.commitAndGet();
+ }
+ }
+
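+ // Concatenate the per-partition segments into a single file on external shuffle storage
+ // (see StartFileSegmentWriter) and build a map status that points at that file.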
+ partitionLengths = StarUtils.getLengths(partitionWriterSegments);
+ mapStatus = startFileSegmentWriter.write(mapId, partitionWriterSegments);
+ } catch (Exception e) {
+ try {
+ mapOutputWriter.abort(e);
+ } catch (Exception e2) {
+ logger.error("Failed to abort the writer after failing to write map output.", e2);
+ e.addSuppressed(e2);
+ }
+ throw e;
+ }
+ }
+
+ @VisibleForTesting
+ public long[] getPartitionLengths() {
+ return partitionLengths;
+ }
+
+ @Override
+ public Option<MapStatus> stop(boolean success) {
+ if (stopping) {
+ return None$.empty();
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ if (partitionWriters != null) {
+ try {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ File file = writer.revertPartialWritesAndClose();
+ if (!file.delete()) {
+ logger.error("Error while deleting file {}", file.getAbsolutePath());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ return None$.empty();
+ }
+ }
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java
new file mode 100644
index 0000000000000..1a61c7aa89e66
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarLocalFileShuffleFileManager.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.nio.file.Paths;
+import java.util.UUID;
+
+/**
+ * This class reads/writes shuffle files on external storage such as a local file system or a network file system.
+ */
+public class StarLocalFileShuffleFileManager implements StarShuffleFileManager {
+ private static final Logger logger = LoggerFactory.getLogger(
+ StarLocalFileShuffleFileManager.class);
+
+ @Override
+ public String createFile(String root) {
+ try {
+ if (root == null || root.isEmpty()) {
+ File file = File.createTempFile("shuffle", ".data");
+ return file.getAbsolutePath();
+ } else {
+ String fileName = String.format("shuffle-%s.data", UUID.randomUUID());
+ return Paths.get(root, fileName).toString();
+ }
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create shuffle file", e);
+ }
+ }
+
+ @Override
+ public void write(InputStream data, long size, String file) {
+ logger.info("Writing to shuffle file: {}", file);
+ try (FileOutputStream outputStream = new FileOutputStream(file)) {
+ long copiedBytes = IOUtils.copyLarge(data, outputStream);
+ if (copiedBytes != size) {
+ throw new RuntimeException(String.format(
+ "Got corrupted shuffle data when writing to " +
+ "file %s, expected size: %s, actual written size: %s",
+ file, size, copiedBytes));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "Failed to write shuffle file %s",
+ file), e);
+ }
+ }
+
+ @Override
+ public InputStream read(String file, long offset, long size) {
+ logger.info("Opening shuffle file: {}, offset: {}, size: {}", file, offset, size);
+ try {
+ FileInputStream inputStream = new FileInputStream(file);
+ inputStream.skip(offset);
+ return new LimitedInputStream(inputStream, size);
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "Failed to open shuffle file %s, offset: %s, size: %s",
+ file, offset, size), e);
+ }
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java
new file mode 100644
index 0000000000000..c132900c0831d
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarMapResultFileInfo.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import java.util.Arrays;
+import java.util.Base64;
+
+/**
+ * This class stores map result file (shuffle file) information.
+ */
+public class StarMapResultFileInfo {
+ private final String location;
+ private final long[] partitionLengths;
+
+ public void serialize(ByteBuf buf) {
+
+ ByteBufUtils.writeLengthAndString(buf, location);
+ buf.writeInt(partitionLengths.length);
+ for (int i = 0; i < partitionLengths.length; i++) {
+ buf.writeLong(partitionLengths[i]);
+ }
+ }
+
+ public static StarMapResultFileInfo deserialize(ByteBuf buf) {
+ String location = ByteBufUtils.readLengthAndString(buf);
+ int size = buf.readInt();
+ long[] partitionLengths = new long[size];
+ for (int i = 0; i < size; i++) {
+ partitionLengths[i] = buf.readLong();
+ }
+ return new StarMapResultFileInfo(location, partitionLengths);
+ }
+
+ /**
+ * This serialize method is faster than JSON serialization.
+ * @return the Base64-encoded string form of this file info
+ */
+ public String serializeToString() {
+ ByteBuf buf = Unpooled.buffer();
+ try {
+ serialize(buf);
+ byte[] bytes = new byte[buf.readableBytes()];
+ buf.readBytes(bytes);
+ return Base64.getEncoder().encodeToString(bytes);
+ } finally {
+ buf.release();
+ }
+ }
+
+ public static StarMapResultFileInfo deserializeFromString(String str) {
+ byte[] bytes = Base64.getDecoder().decode(str);
+ ByteBuf buf = Unpooled.wrappedBuffer(bytes);
+ try {
+ return deserialize(buf);
+ } finally {
+ buf.release();
+ }
+ }
+
+ public StarMapResultFileInfo(String location, long[] partitionLengths) {
+ this.location = location;
+ this.partitionLengths = partitionLengths;
+ }
+
+ public long[] getPartitionLengths() {
+ return partitionLengths;
+ }
+
+ public String getLocation() {
+ return location;
+ }
+
+ @Override
+ public String toString() {
+ return "MapResultFile{" +
+ "location='" + location + '\'' +
+ ", partitionLengths=" + Arrays.toString(partitionLengths) +
+ '}';
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java
new file mode 100644
index 0000000000000..6c7d9e83e7a1f
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarS3ShuffleFileManager.java
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import com.amazonaws.ClientConfiguration;
+import com.amazonaws.client.builder.ExecutorFactory;
+import com.amazonaws.event.ProgressEvent;
+import com.amazonaws.event.ProgressListener;
+import com.amazonaws.regions.Regions;
+import com.amazonaws.services.s3.AmazonS3;
+import com.amazonaws.services.s3.AmazonS3ClientBuilder;
+import com.amazonaws.services.s3.model.GetObjectRequest;
+import com.amazonaws.services.s3.model.ObjectMetadata;
+import com.amazonaws.services.s3.model.PutObjectRequest;
+import com.amazonaws.services.s3.transfer.TransferManager;
+import com.amazonaws.services.s3.transfer.TransferManagerBuilder;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.spark.SparkConf;
+import org.apache.spark.deploy.SparkHadoopUtil;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.net.URI;
+import java.util.UUID;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * This class reads/writes shuffle files on external storage such as S3.
+ */
+public class StarS3ShuffleFileManager implements StarShuffleFileManager {
+ private static final Logger logger = LoggerFactory.getLogger(StarS3ShuffleFileManager.class);
+
+ // TODO make following values configurable
+ public final static int S3_PUT_TIMEOUT_MILLISEC = 180 * 1000;
+
+ // Following constants are copied from:
+ // https://github.com/apache/hadoop/blob/6c6d1b64d4a7cd5288fcded78043acaf23228f96/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/Constants.java
+ public static final long DEFAULT_MULTIPART_SIZE = 67108864; // 64M
+ public static final long DEFAULT_MIN_MULTIPART_THRESHOLD = 134217728; // 128M
+ public static final String MAX_THREADS = "fs.s3a.threads.max";
+ public static final int DEFAULT_MAX_THREADS = 10;
+ public static final String KEEPALIVE_TIME = "fs.s3a.threads.keepalivetime";
+ public static final int DEFAULT_KEEPALIVE_TIME = 60;
+
+ public static final String AWS_REGION = "fs.s3a.endpoint.region";
+ public static final String DEFAULT_AWS_REGION = Regions.US_WEST_2.getName();
+
+ private static TransferManager transferManager;
+ private static final Object transferManagerLock = new Object();
+
+ private final String awsRegion;
+ private final int maxThreads;
+ private final long keepAliveTime;
+
+ public StarS3ShuffleFileManager(SparkConf conf) {
+ Configuration hadoopConf = SparkHadoopUtil.get().newConfiguration(conf);
+
+ awsRegion = hadoopConf.get(AWS_REGION, DEFAULT_AWS_REGION);
+
+ int threads = hadoopConf.getInt(MAX_THREADS, DEFAULT_MAX_THREADS);
+ if (threads < 2) {
+ logger.warn(MAX_THREADS + " must be at least 2: forcing to 2.");
+ threads = 2;
+ }
+ maxThreads = threads;
+
+ keepAliveTime = hadoopConf.getLong(KEEPALIVE_TIME, DEFAULT_KEEPALIVE_TIME);
+ }
+
+ @Override
+ public String createFile(String root) {
+ if (!root.endsWith("/")) {
+ root = root + "/";
+ }
+ String fileName = String.format("shuffle-%s.data", UUID.randomUUID());
+ return root + fileName;
+ }
+
+ @Override
+ public void write(InputStream data, long size, String file) {
+ logger.info("Writing to shuffle file: {}", file);
+ writeS3(data, size, file);
+ }
+
+ @Override
+ public InputStream read(String file, long offset, long size) {
+ logger.info("Opening shuffle file: {}, offset: {}, size: {}", file, offset, size);
+ return readS3(file, offset, size);
+ }
+
+ private void writeS3(InputStream inputStream, long size, String s3Url) {
+ logger.info("Uploading shuffle file to s3: {}, size: {}", s3Url, size);
+
+ S3BucketAndKey bucketAndKey = S3BucketAndKey.getFromUrl(s3Url);
+ String bucket = bucketAndKey.getBucket();
+ String key = bucketAndKey.getKey();
+
+ TransferManager transferManager = getTransferManager();
+
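+ // Set the content length up front so the SDK can upload the raw InputStream without
+ // buffering the entire shuffle file in memory first.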
+ ObjectMetadata metadata = new ObjectMetadata();
+ metadata.setContentType("application/octet-stream");
+ metadata.setContentLength(size);
+
+ PutObjectRequest request = new PutObjectRequest(bucket,
+ key,
+ inputStream,
+ metadata);
+
+ AtomicLong totalTransferredBytes = new AtomicLong(0);
+
+ request.setGeneralProgressListener(new ProgressListener() {
+ private long lastLogTime = 0;
+
+ @Override
+ public void progressChanged(ProgressEvent progressEvent) {
+ long count = progressEvent.getBytesTransferred();
+ long total = totalTransferredBytes.addAndGet(count);
+ long currentTime = System.currentTimeMillis();
+ long logInterval = 10000;
+ if (currentTime - lastLogTime >= logInterval) {
+ logger.info("S3 upload progress: {}, recent transferred {} bytes, total transferred {}", key, count, total);
+ lastLogTime = currentTime;
+ }
+ }
+ });
+
+ // https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/best-practices.html
+ request.getRequestClientOptions().setReadLimit((int) DEFAULT_MULTIPART_SIZE + 1);
+ request.setSdkRequestTimeout(S3_PUT_TIMEOUT_MILLISEC);
+ request.setSdkClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC);
+ try {
+ long startTime = System.currentTimeMillis();
+ transferManager.upload(request).waitForCompletion();
+ long duration = System.currentTimeMillis() - startTime;
+ double mbs = 0;
+ if (duration != 0) {
+ mbs = ((double) size) / (1000 * 1000) / ((double) duration / 1000);
+ }
+ logger.info("S3 upload finished: {}, file size: {} bytes, total transferred: {}, throughput: {} mbs",
+ s3Url, size, totalTransferredBytes.get(), mbs);
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Failed to upload to s3: " + key, e);
+ }
+ }
+
+ private InputStream readS3(String s3Url, long offset, long size) {
+ logger.info("Downloading shuffle file from s3: {}, size: {}", s3Url, size);
+
+ S3BucketAndKey bucketAndKey = S3BucketAndKey.getFromUrl(s3Url);
+
+ File downloadTempFile;
+ try {
+ downloadTempFile = File.createTempFile("shuffle-download", ".data");
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create temp file for downloading shuffle file");
+ }
+
+ TransferManager transferManager = getTransferManager();
+
+ GetObjectRequest getObjectRequest = new GetObjectRequest(bucketAndKey.getBucket(), bucketAndKey.getKey())
+ .withRange(offset, offset + size);
+
+ AtomicLong totalTransferredBytes = new AtomicLong(0);
+
+ getObjectRequest.setGeneralProgressListener(new ProgressListener() {
+ private long lastLogTime = 0;
+
+ @Override
+ public void progressChanged(ProgressEvent progressEvent) {
+ long count = progressEvent.getBytesTransferred();
+ long total = totalTransferredBytes.addAndGet(count);
+ long currentTime = System.currentTimeMillis();
+ long logInterval = 10000;
+ if (currentTime - lastLogTime >= logInterval) {
+ logger.info("S3 download progress: {}, recent transferred {} bytes, total transferred {}", s3Url, count, total);
+ lastLogTime = currentTime;
+ }
+ }
+ });
+
+ try {
+ long startTime = System.currentTimeMillis();
+ transferManager.download(getObjectRequest, downloadTempFile).waitForCompletion();
+ long duration = System.currentTimeMillis() - startTime;
+ double mbs = 0;
+ if (duration != 0) {
+ mbs = ((double) size) / (1000 * 1000) / ((double) duration / 1000);
+ }
+ logger.info("S3 download finished: {}, file size: {} bytes, total transferred: {}, throughput: {} mbs",
+ s3Url, size, totalTransferredBytes.get(), mbs);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(String.format(
+ "Failed to download shuffle file %s", s3Url));
+ } finally {
+ // TODO
+ transferManager.shutdownNow();
+ }
+
+ // TODO delete downloadTempFile
+
+ try {
+ return new LimitedInputStream(new FileInputStream(downloadTempFile), size);
+ } catch (FileNotFoundException e) {
+ throw new RuntimeException(String.format(
+ "Failed to open downloaded shuffle file %s (from %s)", downloadTempFile, s3Url));
+ }
+ }
+
+ private TransferManager getTransferManager() {
+ synchronized (transferManagerLock) {
+ if (transferManager != null) {
+ return transferManager;
+ }
+ transferManager = createTransferManager(awsRegion, maxThreads, keepAliveTime);
+ return transferManager;
+ }
+ }
+
+ private static TransferManager createTransferManager(String region, int maxThreads, long keepAliveTime) {
+ ClientConfiguration clientConfiguration = new ClientConfiguration();
+ clientConfiguration.setConnectionTimeout(S3_PUT_TIMEOUT_MILLISEC);
+ clientConfiguration.setRequestTimeout(S3_PUT_TIMEOUT_MILLISEC);
+ clientConfiguration.setSocketTimeout(S3_PUT_TIMEOUT_MILLISEC);
+ clientConfiguration.setClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC);
+
+ ThreadFactory threadFactory = new ThreadFactory() {
+ private int threadCount = 1;
+ public Thread newThread(Runnable r) {
+ Thread thread = new Thread(r);
+ thread.setName("s3-shuffle-transfer-manager-worker-" + this.threadCount++);
+ return thread;
+ }
+ };
+ ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
+ maxThreads, Integer.MAX_VALUE,
+ keepAliveTime, TimeUnit.SECONDS,
+ new LinkedBlockingQueue<>(),
+ threadFactory);
+ ExecutorFactory executorFactory = new ExecutorFactory() {
+ @Override
+ public ExecutorService newExecutor() {
+ return threadPoolExecutor;
+ }
+ };
+
+ AmazonS3 s3Client = AmazonS3ClientBuilder.standard()
+ .withRegion(region)
+ .withClientConfiguration(clientConfiguration)
+ .build();
+
+ return TransferManagerBuilder.standard()
+ .withS3Client(s3Client)
+ .withMinimumUploadPartSize(DEFAULT_MULTIPART_SIZE)
+ .withMultipartUploadThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
+ .withMultipartCopyPartSize(DEFAULT_MULTIPART_SIZE)
+ .withMultipartCopyThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
+ .withExecutorFactory(executorFactory)
+ .build();
+ }
+
+ public static void shutdownTransferManager() {
+ synchronized (transferManagerLock) {
+ if (transferManager == null) {
+ return;
+ }
+ transferManager.shutdownNow(true);
+ transferManager = null;
+ }
+ }
+
+ public static class S3BucketAndKey {
+ private String bucket;
+ private String key;
+
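+ // For example, "s3://my_bucket/my_folder/shuffle-123.data" parses into bucket
+ // "my_bucket" and key "my_folder/shuffle-123.data".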
+ public static S3BucketAndKey getFromUrl(String s3Url) {
+ URI url = URI.create(s3Url);
+ String bucket = url.getHost();
+ String key = url.getPath();
+ if (key.startsWith("/")) {
+ key = key.substring(1);
+ }
+ if (key.isEmpty()) {
+ throw new RuntimeException(String.format(
+ "Could not get object key in s3 url: %s", s3Url));
+ }
+ return new S3BucketAndKey(bucket, key);
+ }
+
+ public S3BucketAndKey(String bucket, String key) {
+ this.bucket = bucket;
+ this.key = key;
+ }
+
+ public String getBucket() {
+ return bucket;
+ }
+
+ public String getKey() {
+ return key;
+ }
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java
new file mode 100644
index 0000000000000..65c7f849635bc
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarShuffleFileManager.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import java.io.InputStream;
+
+/**
+ * This is the interface for reading/writing shuffle files on external storage.
+ */
+public interface StarShuffleFileManager {
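+ /** Returns the full path (or URL) of a new shuffle file under the given root location. */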
+ String createFile(String root);
+
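+ /** Writes {@code size} bytes from the given input stream to the shuffle file. */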
+ void write(InputStream data, long size, String file);
+
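+ /** Opens the shuffle file and returns a stream over {@code size} bytes starting at {@code offset}. */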
+ InputStream read(String file, long offset, long size);
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java
new file mode 100644
index 0000000000000..86ad58b98af6e
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StarUtils.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.storage.FileSegment;
+
+public class StarUtils {
+
+ public static long[] getLengths(FileSegment[] fileSegments) {
+ long[] lengths = new long[fileSegments.length];
+ for (int i = 0; i < lengths.length; i++) {
+ lengths[i] = fileSegments[i].length();
+ }
+ return lengths;
+ }
+
+ public static StarShuffleFileManager createShuffleFileManager(SparkConf conf, String path) {
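+ // Absolute local paths (or an empty path) use the local file system implementation,
+ // while paths starting with "s3" (e.g. "s3://bucket/dir") use the S3 implementation.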
+ if (path == null || path.isEmpty() || path.startsWith("/")) {
+ return new StarLocalFileShuffleFileManager();
+ } else if (path.toLowerCase().startsWith("s3")) {
+ return new StarS3ShuffleFileManager(conf);
+ } else {
+ throw new RuntimeException(String.format(
+ "Unsupported path for StarShuffleFileManager: %s", path));
+ }
+ }
+}
diff --git a/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java
new file mode 100644
index 0000000000000..d10f7364bdd9d
--- /dev/null
+++ b/external-shuffle-storage/src/main/java/org/apache/spark/starshuffle/StartFileSegmentWriter.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import java.io.*;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.StarShuffleUtils;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.FileSegment;
+
+/**
+ * This class concatenates the per-partition shuffle segments and writes them into a single result file on external shuffle storage.
+ */
+public class StartFileSegmentWriter {
+ private String rootDir;
+
+ public StartFileSegmentWriter(String rootDir) {
+ this.rootDir = rootDir;
+ }
+
+ public MapStatus write(long mapId, FileSegment[] fileSegments) {
+ List<InputStream> allStreams = Arrays.stream(fileSegments).map(t -> {
+ try {
+ FileInputStream fileStream = new FileInputStream(t.file());
+ fileStream.skip(t.offset());
+ return new LimitedInputStream(fileStream, t.length());
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "Failed to read shuffle temp file %s", t.file()),
+ e);
+ }
+ }).collect(Collectors.toList());
+ try (SequenceInputStream sequenceInputStream = new SequenceInputStream(
+ Collections.enumeration(allStreams))) {
+ StarShuffleFileManager shuffleFileManager = StarUtils.createShuffleFileManager(
+ SparkEnv.get().conf(), rootDir);
+ String resultFile = shuffleFileManager.createFile(rootDir);
+ long size = Arrays.stream(fileSegments).mapToLong(t->t.length()).sum();
+ shuffleFileManager.write(sequenceInputStream, size, resultFile);
+ long[] partitionLengths = new long[fileSegments.length];
+ for (int i = 0; i < partitionLengths.length; i++) {
+ partitionLengths[i] = fileSegments[i].length();
+ }
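+ // Encode the shuffle file location into a dummy BlockManagerId so that the reducer side
+ // (StarBlockStoreClient) can recover the file path and partition offsets from the map status.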
+ BlockManagerId blockManagerId = createMapTaskDummyBlockManagerId(
+ partitionLengths, resultFile);
+ MapStatus mapStatus = MapStatus$.MODULE$.apply(
+ blockManagerId, partitionLengths, mapId);
+ return mapStatus;
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to close shuffle temp files", e);
+ }
+ }
+
+ private BlockManagerId createMapTaskDummyBlockManagerId(long[] partitionLengths, String file) {
+ return StarShuffleUtils.createDummyBlockManagerId(file, partitionLengths);
+ }
+}
diff --git a/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties b/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties
new file mode 100644
index 0000000000000..db5d9e512204e
--- /dev/null
+++ b/external-shuffle-storage/src/main/resources/log4j-external-shuffle-storage.properties
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Set everything to be logged to the console
+log4j.rootCategory=INFO, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala
new file mode 100644
index 0000000000000..f2bcf56429c85
--- /dev/null
+++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarOpts.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry}
+
+object StarOpts {
+ val rootDir: ConfigEntry[String] =
+ ConfigBuilder("spark.shuffle.star.rootDir")
+ .doc("Root directory for star shuffle files")
+ .stringConf
+ .createWithDefault("")
+}
diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala
new file mode 100644
index 0000000000000..cd68806b6b478
--- /dev/null
+++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleManager.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConverters._
+import org.apache.spark._
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.starshuffle.{StarBlockStoreClient, StarBypassMergeSortShuffleWriter, StarS3ShuffleFileManager}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, StarShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.{ExternalSorter, OpenHashSet}
+
+/**
+ * This class is a Shuffle Manager implementation to store shuffle data on external storage
+ * like S3.
+ * Most code is copied from SortShuffleManager as a quick prototype. We will collect feedback
+ * from the community and work further to remove the code duplication.
+ */
+class StarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+ import StarShuffleManager._
+
+ if (!conf.getBoolean("spark.shuffle.spill", true)) {
+ logWarning(
+ "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
+ " Shuffle will continue to spill to disk when necessary.")
+ }
+
+ /**
+ * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles.
+ */
+ private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]()
+
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
+ override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+
+ /**
+ * Obtains a [[ShuffleHandle]] to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
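+    // Unlike SortShuffleManager, this prototype always uses the bypass-merge-sort style handle;
+    // the serialized (tungsten-sort) and base sort shuffle paths are not implemented here.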
+ new StarBypassMergeSortShuffleHandle[K, V](
+ shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ }
+
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
+ new StarBlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
+ shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
+ handle.shuffleId, _ => new OpenHashSet[Long](16))
+ mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
+ val env = SparkEnv.get
+ handle match {
+ case bypassMergeSortHandle: StarBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+ new StarBypassMergeSortShuffleWriter(
+ env.blockManager,
+ bypassMergeSortHandle,
+ mapId,
+ env.conf,
+ metrics,
+ shuffleExecutorComponents)
+ }
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds =>
+ mapTaskIds.iterator.foreach { mapTaskId =>
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId)
+ }
+ }
+ true
+ }
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ shuffleBlockResolver.stop()
+
+ // TODO use a better way to shutdown TransferManager in StarS3ShuffleFileManager
+ StarS3ShuffleFileManager.shutdownTransferManager()
+ }
+}
+
+
+private[spark] object StarShuffleManager extends Logging {
+
+ /**
+ * The local property key for continuous shuffle block fetching feature.
+ */
+ val FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY =
+ "__fetch_continuous_blocks_in_batch_enabled"
+
+ /**
+ * Helper method for determining whether a shuffle reader should fetch the continuous blocks
+ * in batch.
+ */
+ def canUseBatchFetch(startPartition: Int, endPartition: Int, context: TaskContext): Boolean = {
+ val fetchMultiPartitions = endPartition - startPartition > 1
+ fetchMultiPartitions &&
+ context.getLocalProperty(FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY) == "true"
+ }
+
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+ val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
+ .toMap
+ executorComponents.initializeExecutor(
+ conf.getAppId,
+ SparkEnv.get.executorId,
+ extraConfigs.asJava)
+ executorComponents
+ }
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * serialized shuffle.
+ */
+private[spark] class SerializedShuffleHandle[K, V](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, dependency) {
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * bypass merge sort shuffle path.
+ */
+private[spark] class StarBypassMergeSortShuffleHandle[K, V](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, dependency) {
+}
+
+private[spark] class StarBlockStoreShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ // (BlockId, Long, Int): similar to FetchBlockInfo (ShuffleBlockId, EstimatedSize, MapIndex)
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ context: TaskContext,
+ readMetrics: ShuffleReadMetricsReporter,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ shouldBatchFetch: Boolean = false)
+ extends ShuffleReader[K, C] with Logging {
+
+ private val dep = handle.dependency
+
+ private def fetchContinuousBlocksInBatch: Boolean = {
+ val conf = SparkEnv.get.conf
+ val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+ val compressed = conf.get(config.SHUFFLE_COMPRESS)
+ val codecConcatenation = if (compressed) {
+ CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+ } else {
+ true
+ }
+ val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+
+ val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+ (!compressed || codecConcatenation) && !useOldFetchProtocol
+ if (shouldBatchFetch && !doBatchFetch) {
+ logDebug("The feature tag of continuous shuffle block fetching is set to true, but " +
+ "we can not enable the feature because other conditions are not satisfied. " +
+ s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
+ s"$useOldFetchProtocol.")
+ }
+ doBatchFetch
+ }
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val starBlockStoreClient = new StarBlockStoreClient()
+ val wrappedStreams = new StarShuffleBlockFetcherIterator(
+ context,
+ starBlockStoreClient,
+ blockManager,
+ mapOutputTracker,
+ blocksByAddress,
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
+ SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
+ readMetrics,
+ fetchContinuousBlocksInBatch).toCompletionIterator
+
+ val serializerInstance = dep.serializer.newInstance()
+
+ // Create a key/value iterator for each stream
+ val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+
+ // Update the context task metrics for each record read.
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+ val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ if (dep.mapSideCombine) {
+ // We are reading values that are already combined
+ val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ } else {
+ // We don't know the value type, but also don't care -- the dependency *should*
+ // have made sure it's compatible with this aggregator, which will convert the value
+ // type to the combined type C.
+ val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+ }
+ } else {
+ interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+ }
+
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener[Unit](_ => {
+ sorter.stop()
+ })
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
+ case None =>
+ aggregatedIter
+ }
+
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation as aggregator
+ // or(and) sorter may have consumed previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
+ }
+ }
+}
diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala
new file mode 100644
index 0000000000000..edbbf5149afd3
--- /dev/null
+++ b/external-shuffle-storage/src/main/scala/org/apache/spark/shuffle/StarShuffleUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.util.UUID
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.starshuffle.StarMapResultFileInfo
+import org.apache.spark.storage.BlockManagerId
+
+object StarShuffleUtils extends Logging {
+
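+ /**
+ * Builds a placeholder BlockManagerId whose fields carry the external shuffle output location
+ * instead of a real network address: the serialized [[StarMapResultFileInfo]] (file location
+ * plus per-partition lengths) is stored in the host field, the port is a fixed dummy value,
+ * and the executor id is a random "starshuffle-" prefixed UUID, so the shuffle file location
+ * can travel through the existing BlockManagerId-based map output plumbing.
+ */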
+ def createDummyBlockManagerId(fileLocation: String,
+ partitionLengths: Array[Long]): BlockManagerId = {
+ val fileInfo = new StarMapResultFileInfo(fileLocation, partitionLengths)
+ val dummyHost = fileInfo.serializeToString()
+ val dummyPort = 9
+ val dummyExecId = "starshuffle-" + UUID.randomUUID().toString
+ BlockManagerId(dummyExecId, dummyHost, dummyPort, None)
+ }
+
+}
diff --git a/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala b/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala
new file mode 100644
index 0000000000000..00d5f11b7431e
--- /dev/null
+++ b/external-shuffle-storage/src/main/scala/org/apache/spark/storage/StarShuffleBlockFetcherIterator.scala
@@ -0,0 +1,1428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{IOException, InputStream}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+import java.util.zip.CheckedInputStream
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.util.{Failure, Success}
+
+import io.netty.util.internal.OutOfDirectMemoryError
+import org.apache.commons.io.IOUtils
+import org.roaringbitmap.RoaringBitmap
+import org.apache.spark.{MapOutputTracker, TaskContext}
+import org.apache.spark.errors.SparkCoreErrors
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
+import org.apache.spark.network.util.{NettyUtils, TransportConf}
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
+
+/**
+ * This class fetches shuffle data from external storage like S3.
+ * Most code is copied from ShuffleBlockFetcherIterator for a quick prototype. We will work
+ * further to remove code duplication after getting community feedback.
+ */
+private[spark]
+final class StarShuffleBlockFetcherIterator(
+ context: TaskContext,
+ shuffleClient: BlockStoreClient,
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker,
+ // (BlockId, Long, Int) in blocksByAddress is similar to:
+ // FetchBlockInfo (ShuffleBlockId, EstimatedSize, MapIndex)
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ streamWrapper: (BlockId, InputStream) => InputStream,
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
+ val maxReqSizeShuffleToMem: Long,
+ maxAttemptsOnNettyOOM: Int,
+ detectCorrupt: Boolean,
+ detectCorruptUseExtraMemory: Boolean,
+ checksumEnabled: Boolean,
+ checksumAlgorithm: String,
+ shuffleMetrics: ShuffleReadMetricsReporter,
+ doBatchFetch: Boolean)
+ extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
+
+ import ShuffleBlockFetcherIterator._
+
+ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
+
+ /**
+ * Total number of blocks to fetch.
+ */
+ private[this] var numBlocksToFetch = 0
+
+ /**
+ * The number of blocks processed by the caller. The iterator is exhausted when
+ * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+ */
+ private[this] var numBlocksProcessed = 0
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ /** Host local blocks to fetch, excluding zero-sized blocks. */
+ private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
+
+ /**
+ * A queue to hold our results. This turns the asynchronous model provided by
+ * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
+ */
+ private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ @volatile private[this] var currentResult: SuccessFetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
+ private[this] val fetchRequests = new Queue[FetchRequest]
+
+ /**
+ * Queue of fetch requests which could not be issued the first time they were dequeued. These
+ * requests are tried again when the fetch constraints are satisfied.
+ */
+ private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()
+
+ /** Current bytes in flight from our requests */
+ private[this] var bytesInFlight = 0L
+
+ /** Current number of requests in flight */
+ private[this] var reqsInFlight = 0
+
+ /** Current number of blocks in flight per host:port */
+ private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+
+ /**
+ * Counts the retry attempts for blocks that failed due to Netty OOM. A block stops being
+ * retried once its retry count exceeds [[maxAttemptsOnNettyOOM]].
+ */
+ private[this] val blockOOMRetryCounts = new HashMap[String, Int]
+
+ /**
+ * The blocks that can't be decompressed successfully. Used to guarantee that we retry
+ * at most once for each of those corrupted blocks.
+ */
+ private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @GuardedBy("this")
+ private[this] var isZombie = false
+
+ /**
+ * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+ * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
+ */
+ @GuardedBy("this")
+ private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
+
+ private[this] val onCompleteCallback = new StarShuffleFetchCompletionListener(this)
+
+ initialize()
+
+ // Decrements the buffer reference count.
+ // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+ private[storage] def releaseCurrentResultBuffer(): Unit = {
+ // Release the current buffer if necessary
+ if (currentResult != null) {
+ currentResult.buf.release()
+ }
+ currentResult = null
+ }
+
+ override def createTempFile(transportConf: TransportConf): DownloadFile = {
+ // we never need to do any encryption or decryption here, regardless of configs, because that
+ // is handled at another layer in the code. When encryption is enabled, shuffle data is written
+ // to disk encrypted in the first place, and sent over the network still encrypted.
+ new SimpleDownloadFile(
+ blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf)
+ }
+
+ override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized {
+ if (isZombie) {
+ false
+ } else {
+ shuffleFilesSet += file
+ true
+ }
+ }
+
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[storage] def cleanup(): Unit = {
+ synchronized {
+ isZombie = true
+ }
+ releaseCurrentResultBuffer()
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ result match {
+ case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) =>
+ if (address != blockManager.blockManagerId) {
+ if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+ shuffleMetrics.incLocalBlocksFetched(1)
+ shuffleMetrics.incLocalBytesRead(buf.size)
+ } else {
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+ shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+ }
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ }
+ buf.release()
+ case _ =>
+ }
+ }
+ shuffleFilesSet.foreach { file =>
+ if (!file.delete()) {
+ logWarning("Failed to cleanup shuffle fetch temp file " + file.path())
+ }
+ }
+ }
+
+ private[this] def sendRequest(req: FetchRequest): Unit = {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+ bytesInFlight += req.size
+ reqsInFlight += 1
+
+ // so we can look up the block info of each blockID
+ val infoMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
+ }.toMap
+ val remainingBlocks = new HashSet[String]() ++= infoMap.keys
+ val deferredBlocks = new ArrayBuffer[String]()
+ val blockIds = req.blocks.map(_.blockId.toString)
+ val address = req.address
+
+ @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
+ if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
+ val blocks = deferredBlocks.map { blockId =>
+ val (size, mapIndex) = infoMap(blockId)
+ FetchBlockInfo(BlockId(blockId), size, mapIndex)
+ }
+ results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
+ deferredBlocks.clear()
+ }
+ }
+
+ val blockFetchingListener = new BlockFetchingListener {
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ StarShuffleBlockFetcherIterator.this.synchronized {
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ remainingBlocks -= blockId
+ blockOOMRetryCounts.remove(blockId)
+ results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
+ address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
+ logDebug("remainingBlocks: " + remainingBlocks)
+ enqueueDeferredFetchRequestIfNecessary()
+ }
+ }
+ logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}")
+ }
+
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+ StarShuffleBlockFetcherIterator.this.synchronized {
+ logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ e match {
+ // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among
+ // tasks) to true as early as possible. The pending fetch requests won't be sent
+ // afterwards until the flag is set to false on:
+ // 1) the Netty free memory >= maxReqSizeShuffleToMem
+ // - we'll check this whenever a fetch request succeeds.
+ // 2) the number of in-flight requests becomes 0
+ // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked.
+ // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag
+ // only takes effect for the shuffle due to the implementation simplicity concern.
+ // And we'll buffer the consecutive block failures caused by the OOM error until there
+ // are no remaining blocks in the current request. Then, we'll package these blocks into
+ // the same fetch request for a later retry. In this way, instead of creating one fetch
+ // request per block, it helps reduce the number of concurrent connections and the data
+ // load pressure on the remote server.
+ // Note that catching the OOM and reacting to it is only a workaround for
+ // handling the Netty OOM issue, and not the best way towards memory management.
+ // We can get rid of it when we find a way to manage Netty's memory precisely.
+ case _: OutOfDirectMemoryError
+ if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM =>
+ if (!isZombie) {
+ val failureTimes = blockOOMRetryCounts(blockId)
+ blockOOMRetryCounts(blockId) += 1
+ if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
+ // The fetcher can fail remaining blocks in batch for the same error. So we only
+ // log the warning once to avoid flooding the logs.
+ logInfo(s"Block $blockId has failed $failureTimes times " +
+ s"due to Netty OOM, will retry")
+ }
+ remainingBlocks -= blockId
+ deferredBlocks += blockId
+ enqueueDeferredFetchRequestIfNecessary()
+ }
+
+ case _ =>
+ val block = BlockId(blockId)
+ if (block.isShuffleChunk) {
+ remainingBlocks -= blockId
+ results.put(FallbackOnPushMergedFailureResult(
+ block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
+ } else {
+ results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
+ }
+ }
+ }
+ }
+ }
+
+ // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
+ // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
+ // the data and write it to file directly.
+ if (req.size > maxReqSizeShuffleToMem) {
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, this)
+ } else {
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, null)
+ }
+ }
+
+ /**
+ * This is called from initialize and also from the fallback which is triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+ pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
+ logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
+
+ // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+ // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+ // in order to limit the amount of data in flight.
+ val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ var localBlockBytes = 0L
+ var hostLocalBlockBytes = 0L
+ var numHostLocalBlocks = 0
+ var pushMergedLocalBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
+
+ val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+ val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
+ for ((address, blockInfos) <- blocksByAddress) {
+ checkBlockSizes(blockInfos)
+ if (localExecIds.contains(address.executorId)) {
+ val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+ blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
+ numBlocksToFetch += mergedBlockInfos.size
+ localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
+ localBlockBytes += mergedBlockInfos.map(_.size).sum
+ } else if (blockManager.hostLocalDirManager.isDefined &&
+ address.host == blockManager.blockManagerId.host) {
+ val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+ blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
+ numBlocksToFetch += mergedBlockInfos.size
+ val blocksForAddress =
+ mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
+ hostLocalBlocksByExecutor += address -> blocksForAddress
+ numHostLocalBlocks += blocksForAddress.size
+ hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
+ } else {
+ val (_, timeCost) = Utils.timeTakenMs[Unit] {
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ logDebug(s"Collected remote fetch requests for $address in $timeCost ms")
+ }
+ }
+ val (remoteBlockBytes, numRemoteBlocks) =
+ collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
+ val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+ pushMergedLocalBlockBytes
+ val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+ assert(blocksToFetchCurrentIteration == localBlocks.size +
+ numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+ s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " +
+ s"of the number of local blocks ${localBlocks.size} + " +
+ s"the number of host-local blocks ${numHostLocalBlocks} " +
+ s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+ s"+ the number of remote blocks ${numRemoteBlocks} ")
+ logInfo(s"Getting $blocksToFetchCurrentIteration " +
+ s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+ s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+ s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+ s"host-local and ${pushMergedLocalBlocks.size} " +
+ s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+ s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+ s"remote blocks")
+ this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+ .flatMap { infos => infos.map(info => (info._1, info._3)) }
+ collectedRemoteRequests
+ }
+
+ private def createFetchRequest(
+ blocks: Seq[FetchBlockInfo],
+ address: BlockManagerId,
+ forMergedMetas: Boolean): FetchRequest = {
+ logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address "
+ + s"with ${blocks.size} blocks")
+ FetchRequest(address, blocks, forMergedMetas)
+ }
+
+ private def createFetchRequests(
+ curBlocks: Seq[FetchBlockInfo],
+ address: BlockManagerId,
+ isLast: Boolean,
+ collectedRemoteRequests: ArrayBuffer[FetchRequest],
+ enableBatchFetch: Boolean,
+ forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+ val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch)
+ numBlocksToFetch += mergedBlocks.size
+ val retBlocks = new ArrayBuffer[FetchBlockInfo]
+ if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
+ collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas)
+ } else {
+ mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
+ if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
+ collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas)
+ } else {
+ // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
+ // to `curBlocks`.
+ retBlocks ++= blocks
+ numBlocksToFetch -= blocks.size
+ }
+ }
+ }
+ retBlocks
+ }
+
+ private def collectFetchRequests(
+ address: BlockManagerId,
+ blockInfos: Seq[(BlockId, Long, Int)],
+ collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[FetchBlockInfo]()
+
+ while (iterator.hasNext) {
+ val (blockId, size, mapIndex) = iterator.next()
+ curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+ curRequestSize += size
+ blockId match {
+ // Either all blocks are push-merged blocks, shuffle chunks, or original blocks.
+ // Based on these types, we decide to do batch fetch and create FetchRequests with
+ // forMergedMetas set.
+ case ShuffleBlockChunkId(_, _, _, _) =>
+ if (curRequestSize >= targetRemoteRequestSize ||
+ curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = false)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
+ case ShuffleMergedBlockId(_, _, _) =>
+ if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+ }
+ case _ =>
+ // For batch fetch, the number of blocks in flight should be counted in terms of merged blocks.
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+ if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+ curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
+ collectedRemoteRequests, doBatchFetch)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
+ }
+ }
+ // Add in the final request
+ if (curBlocks.nonEmpty) {
+ val (enableBatchFetch, forMergedMetas) = {
+ curBlocks.head.blockId match {
+ case ShuffleBlockChunkId(_, _, _, _) => (false, false)
+ case ShuffleMergedBlockId(_, _, _) => (false, true)
+ case _ => (doBatchFetch, false)
+ }
+ }
+ createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests,
+ enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas)
+ }
+ }
+
+ private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = {
+ if (blockSize < 0) {
+ throw BlockException(blockId, "Negative block size " + blockSize)
+ } else if (blockSize == 0) {
+ throw BlockException(blockId, "Zero-sized blocks should be excluded.")
+ }
+ }
+
+ private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = {
+ blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) }
+ }
+
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
+ private[this] def fetchLocalBlocks(
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = {
+ logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val (blockId, mapIndex) = iter.next()
+ try {
+ val buf = blockManager.getLocalBlockData(blockId)
+ shuffleMetrics.incLocalBlocksFetched(1)
+ shuffleMetrics.incLocalBytesRead(buf.size)
+ buf.retain()
+ results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId,
+ buf.size(), buf, false))
+ } catch {
+ // If we see an exception, stop immediately.
+ case e: Exception =>
+ e match {
+ // ClosedByInterruptException is an expected exception when a task is killed,
+ // so don't log the exception stack trace to avoid confusing users.
+ // See: SPARK-28340
+ case ce: ClosedByInterruptException =>
+ logError("Error occurred while fetching local blocks, " + ce.getMessage)
+ case ex: Exception => logError("Error occurred while fetching local blocks", ex)
+ }
+ results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e))
+ return
+ }
+ }
+ }
+
+ private[this] def fetchHostLocalBlock(
+ blockId: BlockId,
+ mapIndex: Int,
+ localDirs: Array[String],
+ blockManagerId: BlockManagerId): Boolean = {
+ try {
+ val buf = blockManager.getHostLocalShuffleData(blockId, localDirs)
+ buf.retain()
+ results.put(SuccessFetchResult(blockId, mapIndex, blockManagerId, buf.size(), buf,
+ isNetworkReqDone = false))
+ true
+ } catch {
+ case e: Exception =>
+ // If we see an exception, stop immediately.
+ logError(s"Error occurred while fetching local blocks", e)
+ results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e))
+ false
+ }
+ }
+
+ /**
+ * Fetch the host-local blocks while we are fetching remote blocks. This is ok because
+ * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
+ private[this] def fetchHostLocalBlocks(
+ hostLocalDirManager: HostLocalDirManager,
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
+ Unit = {
+ val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs
+ val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = {
+ val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) =>
+ cachedDirsByExec.contains(hostLocalBmId.executorId)
+ }
+ (hasCache.toMap, noCache.toMap)
+ }
+
+ if (hostLocalBlocksWithMissingDirs.nonEmpty) {
+ logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " +
+ s"${hostLocalBlocksWithMissingDirs.mkString(", ")}")
+
+ // If the external shuffle service is enabled, we'll fetch the local directories for
+ // multiple executors in one request from the external shuffle service, which is located
+ // on the same host as the executors. Otherwise, we'll fetch the local directories from
+ // those executors directly, one by one. There won't be many such fetch requests, since in
+ // practice a single host rarely runs a large number of executors at the same time.
+ val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) {
+ val host = blockManager.blockManagerId.host
+ val port = blockManager.externalShuffleServicePort
+ Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray))
+ } else {
+ hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq
+ }
+
+ dirFetchRequests.foreach { case (host, port, bmIds) =>
+ hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) {
+ case Success(dirsByExecId) =>
+ fetchMultipleHostLocalBlocks(
+ hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap,
+ dirsByExecId,
+ cached = false)
+
+ case Failure(throwable) =>
+ logError("Error occurred while fetching host local blocks", throwable)
+ val bmId = bmIds.head
+ val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId)
+ val (blockId, _, mapIndex) = blockInfoSeq.head
+ results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable))
+ }
+ }
+ }
+
+ if (hostLocalBlocksWithCachedDirs.nonEmpty) {
+ logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " +
+ s"${hostLocalBlocksWithCachedDirs.mkString(", ")}")
+ fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true)
+ }
+ }
+
+ private def fetchMultipleHostLocalBlocks(
+ bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
+ localDirsByExecId: Map[String, Array[String]],
+ cached: Boolean): Unit = {
+ // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put
+ // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the
+ // remaining blocks.
+ val allFetchSucceeded = bmIdToBlocks.forall { case (bmId, blockInfos) =>
+ blockInfos.forall { case (blockId, _, mapIndex) =>
+ fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId)
+ }
+ }
+ if (allFetchSucceeded) {
+ logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " +
+ s"(${if (cached) "with" else "without"} cached executors' dir) " +
+ s"in ${Utils.getUsedTimeNs(startTimeNs)}")
+ }
+ }
+
+ private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both success case and failure case) to cleanup.
+ context.addTaskCompletionListener(onCompleteCallback)
+ // Local blocks to fetch, excluding zero-sized blocks.
+ val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+ val hostLocalBlocksByExecutor =
+ mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+ val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+ // Partition blocks by the different fetch modes: local, host-local, push-merged-local and
+ // remote blocks.
+ val remoteRequests = partitionBlocksByFetchMode(
+ blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+ assert ((0 == reqsInFlight) == (0 == bytesInFlight),
+ "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
+ ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ fetchUpToMaxBytes()
+
+ val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
+ val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest
+ logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" +
+ (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else ""))
+
+ // Get Local Blocks
+ fetchLocalBlocks(localBlocks)
+ logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
+ // Get host local blocks if any
+ fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+ }
+
+ private def fetchAllHostLocalBlocks(
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
+ Unit = {
+ if (hostLocalBlocksByExecutor.nonEmpty) {
+ blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor))
+ }
+ }
+
+ override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+
+ /**
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
+ * underlying each InputStream will be freed by the cleanup() method registered with the
+ * TaskCompletionListener. However, callers should close() these InputStreams
+ * as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
+ */
+ override def next(): (BlockId, InputStream) = {
+ if (!hasNext) {
+ throw SparkCoreErrors.noSuchElementError()
+ }
+
+ numBlocksProcessed += 1
+
+ var result: FetchResult = null
+ var input: InputStream = null
+ // This is only initialized when shuffle checksum is enabled.
+ var checkedIn: CheckedInputStream = null
+ var streamCompressedOrEncrypted: Boolean = false
+ // Take the next fetched result and try to decompress it to detect data corruption,
+ // then fetch it one more time if it's corrupt, and throw a FetchFailedException if the
+ // second fetch is also corrupt, so the previous stage could be retried.
+ // For a local shuffle block, throw a FetchFailedException on the first IOException.
+ while (result == null) {
+ val startFetchWait = System.nanoTime()
+ result = results.take()
+ val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
+ shuffleMetrics.incFetchWaitTime(fetchWaitTime)
+
+ result match {
+ case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
+ if (address != blockManager.blockManagerId) {
+ if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+ // It is a host local block or a local shuffle chunk
+ shuffleMetrics.incLocalBlocksFetched(1)
+ shuffleMetrics.incLocalBytesRead(buf.size)
+ } else {
+ numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+ shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+ }
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ bytesInFlight -= size
+ }
+ }
+ if (isNetworkReqDone) {
+ reqsInFlight -= 1
+ resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
+ logDebug("Number of requests in flight " + reqsInFlight)
+ }
+
+ if (buf.size == 0) {
+ // We will never legitimately receive a zero-size block. All blocks with zero records
+ // have zero size and all zero-size blocks have no records (and hence should never
+ // have been requested in the first place). This statement relies on behaviors of the
+ // shuffle writers, which are guaranteed by the following test cases:
+ //
+ // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
+ // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+ // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
+ //
+ // There is not an explicit test for SortShuffleWriter, but the underlying APIs that it
+ // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter,
+ // which returns a zero-size from commitAndGet() in case no records were written
+ // since the last call).
+ val msg = s"Received a zero-size buffer for block $blockId from $address " +
+ s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
+ throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
+ }
+
+ val in = try {
+ var bufIn = buf.createInputStream()
+ if (checksumEnabled) {
+ val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+ checkedIn = new CheckedInputStream(bufIn, checksum)
+ checkedIn
+ } else {
+ bufIn
+ }
+ } catch {
+ // The exception could only be thrown by a local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ e match {
+ case ce: ClosedByInterruptException =>
+ logError("Failed to create input stream from local block, " +
+ ce.getMessage)
+ case e: IOException => logError("Failed to create input stream from local block", e)
+ }
+ buf.release()
+ throwFetchFailedException(blockId, mapIndex, address, e)
+ }
+ if (in != null) {
+ try {
+ input = streamWrapper(blockId, in)
+ // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+ // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+ // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+ // the corruption is later, we'll still detect the corruption later in the stream.
+ streamCompressedOrEncrypted = !input.eq(in)
+ if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+ // TODO: manage the memory used here, and spill it into disk in case of OOM.
+ input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+ }
+ } catch {
+ case e: IOException =>
+ // When shuffle checksum is enabled, for a block that is corrupted twice,
+ // we'd calculate the checksum of the block by consuming the remaining data
+ // in the buf. So, we should release the buf later.
+ if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
+ buf.release()
+ }
+
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+ throwFetchFailedException(blockId, mapIndex, address, e)
+ } else if (corruptedBlocks.contains(blockId)) {
+ // It's the second time this block is detected corrupted
+ if (checksumEnabled) {
+ // Diagnose the cause of data corruption if shuffle checksum is enabled
+ val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId)
+ buf.release()
+ logError(diagnosisResponse)
+ throwFetchFailedException(
+ blockId, mapIndex, address, e, Some(diagnosisResponse))
+ } else {
+ throwFetchFailedException(blockId, mapIndex, address, e)
+ }
+ } else {
+ // It's the first time this block is detected corrupted
+ logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+ corruptedBlocks += blockId
+ fetchRequests += FetchRequest(
+ address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+ result = null
+ }
+ } finally {
+ // TODO: release the buf here to free memory earlier
+ if (input == null) {
+ // Close the underlying stream if there was an issue in wrapping the stream using
+ // streamWrapper
+ in.close()
+ }
+ }
+ }
+
+ case FailureFetchResult(blockId, mapIndex, address, e) =>
+ var errorMsg: String = null
+ if (e.isInstanceOf[OutOfDirectMemoryError]) {
+ errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " +
+ s"retries due to Netty OOM"
+ logError(errorMsg)
+ }
+ throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
+
+ case DeferFetchRequestResult(request) =>
+ val address = request.address
+ numBlocksInFlightPerAddress(address) =
+ numBlocksInFlightPerAddress(address) - request.blocks.size
+ bytesInFlight -= request.size
+ reqsInFlight -= 1
+ logDebug("Number of requests in flight " + reqsInFlight)
+ val defReqQueue =
+ deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
+ defReqQueue.enqueue(request)
+ result = null
+
+ case FallbackOnPushMergedFailureResult(_, _, _, _) =>
+ result = null
+ case PushMergedLocalMetaFetchResult(_, _, _, _, _) =>
+ result = null
+ case PushMergedRemoteMetaFailedFetchResult(_, _, _, _) =>
+ result = null
+ case PushMergedRemoteMetaFetchResult(_, _, _, _, _, _) =>
+ result = null
+ }
+
+ // Send fetch requests up to maxBytesInFlight
+ fetchUpToMaxBytes()
+ }
+
+ currentResult = result.asInstanceOf[SuccessFetchResult]
+ (currentResult.blockId,
+ new StarBufferReleasingInputStream(
+ input,
+ this,
+ currentResult.blockId,
+ currentResult.mapIndex,
+ currentResult.address,
+ detectCorrupt && streamCompressedOrEncrypted,
+ currentResult.isNetworkReqDone,
+ Option(checkedIn)))
+ }
+
+ /**
+ * Get the suspected corruption cause for the corrupted block. It should only be invoked
+ * when checksum is enabled and corruption was detected at least once.
+ *
+ * This will first consume the rest of the stream of the corrupted block to calculate the
+ * checksum of the block. Then, it will make a synchronous RPC call along with the
+ * checksum to ask the server (where the corrupted block was fetched from) to diagnose the
+ * cause of corruption and return it.
+ *
+ * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the
+ * corruption cause since corruption diagnosis is only a best effort.
+ *
+ * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum.
+ * @param address the address where the corrupted block is fetched from.
+ * @param blockId the blockId of the corrupted block.
+ * @return The corruption diagnosis response for different causes.
+ */
+ private[storage] def diagnoseCorruption(
+ checkedIn: CheckedInputStream,
+ address: BlockManagerId,
+ blockId: BlockId): String = {
+ logInfo("Start corruption diagnosis.")
+ val startTimeNs = System.nanoTime()
+ assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId")
+ val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId]
+ val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER)
+ // consume the remaining data to calculate the checksum
+ var cause: Cause = null
+ try {
+ while (checkedIn.read(buffer) != -1) {}
+ val checksum = checkedIn.getChecksum.getValue
+ cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId,
+ shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum,
+ checksumAlgorithm)
+ } catch {
+ case e: Exception =>
+ logWarning("Unable to diagnose the corruption cause of the corrupted block", e)
+ cause = Cause.UNKNOWN_ISSUE
+ }
+ val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+ val diagnosisResponse = cause match {
+ case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM =>
+ s"Block $blockId is corrupted but corruption diagnosis failed due to " +
+ s"unsupported checksum algorithm: $checksumAlgorithm"
+
+ case Cause.CHECKSUM_VERIFY_PASS =>
+ s"Block $blockId is corrupted but checksum verification passed"
+
+ case Cause.UNKNOWN_ISSUE =>
+ s"Block $blockId is corrupted but the cause is unknown"
+
+ case otherCause =>
+ s"Block $blockId is corrupted due to $otherCause"
+ }
+ logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse")
+ diagnosisResponse
+ }
+
+ def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
+ CompletionIterator[(BlockId, InputStream), this.type](this,
+ onCompleteCallback.onComplete(context))
+ }
+
+ private def fetchUpToMaxBytes(): Unit = {
+ if (isNettyOOMOnShuffle.get()) {
+ if (reqsInFlight > 0) {
+ // Return immediately if Netty is still OOMed and there're ongoing fetch requests
+ return
+ } else {
+ resetNettyOOMFlagIfPossible(0)
+ }
+ }
+
+ // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
+ // immediately, defer the request until the next time it can be processed.
+
+ // Process any outstanding deferred fetch requests if possible.
+ if (deferredFetchRequests.nonEmpty) {
+ for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+ while (isRemoteBlockFetchable(defReqQueue) &&
+ !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+ val request = defReqQueue.dequeue()
+ logDebug(s"Processing deferred fetch request for $remoteAddress with "
+ + s"${request.blocks.length} blocks")
+ send(remoteAddress, request)
+ if (defReqQueue.isEmpty) {
+ deferredFetchRequests -= remoteAddress
+ }
+ }
+ }
+ }
+
+ // Process any regular fetch requests if possible.
+ while (isRemoteBlockFetchable(fetchRequests)) {
+ val request = fetchRequests.dequeue()
+ val remoteAddress = request.address
+ if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+ logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
+ val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
+ defReqQueue.enqueue(request)
+ deferredFetchRequests(remoteAddress) = defReqQueue
+ } else {
+ send(remoteAddress, request)
+ }
+ }
+
+ def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
+ sendRequest(request)
+ numBlocksInFlightPerAddress(remoteAddress) =
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+ }
+
+ def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
+ fetchReqQueue.nonEmpty &&
+ (bytesInFlight == 0 ||
+ (reqsInFlight + 1 <= maxReqsInFlight &&
+ bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
+ }
+
+ // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
+ // given remote address.
+ def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
+ maxBlocksInFlightPerAddress
+ }
+ }
+
+ private[storage] def throwFetchFailedException(
+ blockId: BlockId,
+ mapIndex: Int,
+ address: BlockManagerId,
+ e: Throwable,
+ message: Option[String] = None) = {
+ val msg = message.getOrElse(e.getMessage)
+ blockId match {
+ case ShuffleBlockId(shufId, mapId, reduceId) =>
+ throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e)
+ case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
+ throw SparkCoreErrors.fetchFailedError(address, shuffleId, mapId, mapIndex, startReduceId,
+ msg, e)
+ case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e)
+ }
+ }
+
+ /**
+ * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+ */
+ private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+ results.put(result)
+ }
+
+ private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = {
+ numBlocksToFetch -= blocksFetched
+ }
+
+ /**
+ * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+ * failure related to a push-merged block or shuffle chunk.
+ * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+ * fallback.
+ */
+ private[storage] def fallbackFetch(
+ originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+ val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+ val originalHostLocalBlocksByExecutor =
+ mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+ val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+ val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr,
+ originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(originalRemoteReqs)
+ logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged")
+ // fetch all the fallback blocks that are local.
+ fetchLocalBlocks(originalLocalBlocks)
+ // Merged local blocks should be empty during fallback
+ assert(originalMergedLocalBlocks.isEmpty,
+ "There should be zero push-merged blocks during fallback")
+ // Some of the fallback local blocks could be host local blocks
+ fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor)
+ }
+
+ /**
+ * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as
+ * the current chunk that had a fetch failure.
+ * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+ * fallback.
+ *
+ * @return set of all the removed shuffle chunk Ids.
+ */
+ private[storage] def removePendingChunks(
+ failedBlockId: ShuffleBlockChunkId,
+ address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+ val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+ def sameShuffleReducePartition(block: BlockId): Boolean = {
+ val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+ chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+ }
+
+ def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+ val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+ fetchRequestsToRemove ++= queue.dequeueAll { req =>
+ val firstBlock = req.blocks.head
+ firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+ sameShuffleReducePartition(firstBlock.blockId)
+ }
+ fetchRequestsToRemove.foreach { _ =>
+ removedChunkIds ++=
+ fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]))
+ }
+ }
+
+ filterRequests(fetchRequests)
+ deferredFetchRequests.get(address).foreach { defRequests =>
+ filterRequests(defRequests)
+ if (defRequests.isEmpty) deferredFetchRequests.remove(address)
+ }
+ removedChunkIds
+ }
+}
+
+/**
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and
+ * also detects stream corruption if streamCompressedOrEncrypted is true
+ */
+private class StarBufferReleasingInputStream(
+ // This is visible for testing
+ private[storage] val delegate: InputStream,
+ private val iterator: StarShuffleBlockFetcherIterator,
+ private val blockId: BlockId,
+ private val mapIndex: Int,
+ private val address: BlockManagerId,
+ private val detectCorruption: Boolean,
+ private val isNetworkReqDone: Boolean,
+ private val checkedInOpt: Option[CheckedInputStream])
+ extends InputStream {
+ private[this] var closed = false
+
+ override def read(): Int =
+ tryOrFetchFailedException(delegate.read())
+
+ override def close(): Unit = {
+ if (!closed) {
+ try {
+ delegate.close()
+ iterator.releaseCurrentResultBuffer()
+ } finally {
+ // Unset the flag when a remote request has finished and free direct memory is sufficient.
+ if (isNetworkReqDone) {
+ ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem)
+ }
+ closed = true
+ }
+ }
+ }
+
+ override def available(): Int = delegate.available()
+
+ override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+ override def skip(n: Long): Long =
+ tryOrFetchFailedException(delegate.skip(n))
+
+ override def markSupported(): Boolean = delegate.markSupported()
+
+ override def read(b: Array[Byte]): Int =
+ tryOrFetchFailedException(delegate.read(b))
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int =
+ tryOrFetchFailedException(delegate.read(b, off, len))
+
+ override def reset(): Unit = delegate.reset()
+
+ /**
+ * Execute a block of code that returns a value. If an IOException is thrown and
+ * detectCorruption is true, close this stream quietly and re-throw it as a
+ * FetchFailedException. This method is currently only used by the `read` and `skip`
+ * methods inside `StarBufferReleasingInputStream`.
+ */
+ private def tryOrFetchFailedException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException if detectCorruption =>
+ val diagnosisResponse = checkedInOpt.map { checkedIn =>
+ iterator.diagnoseCorruption(checkedIn, address, blockId)
+ }
+ IOUtils.closeQuietly(this)
+ // We'd never retry the block whatever the cause is since the block has been
+ // partially consumed by downstream RDDs.
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse)
+ }
+ }
+}
+
+/**
+ * A listener to be called at the completion of the StarShuffleBlockFetcherIterator.
+ * @param data the StarShuffleBlockFetcherIterator to process
+ */
+private class StarShuffleFetchCompletionListener(var data: StarShuffleBlockFetcherIterator)
+ extends TaskCompletionListener {
+
+ override def onTaskCompletion(context: TaskContext): Unit = {
+ if (data != null) {
+ data.cleanup()
+ // Null out the referent here to make sure we don't keep a reference to this
+ // StarShuffleBlockFetcherIterator, after we're done reading from it, to let it be
+ // collected during GC. Otherwise we can hold metadata on block locations (blocksByAddress).
+ data = null
+ }
+ }
+
+ // Just an alias for onTaskCompletion to avoid confusion
+ def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context)
+}
+
+private[storage]
+object StarShuffleBlockFetcherIterator {
+
+ /**
+ * A flag which indicates whether a Netty OOM error has been raised during shuffle.
+ * If true, all pending shuffle fetch requests will be deferred until the flag is unset
+ * (which happens whenever a fetch request completes), unless there are no in-flight
+ * fetch requests.
+ */
+ val isNettyOOMOnShuffle = new AtomicBoolean(false)
+
+ def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = {
+ if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) {
+ isNettyOOMOnShuffle.compareAndSet(true, false)
+ }
+ }
+
+ /**
+ * This function is used to merge blocks when doBatchFetch is true. Blocks which have the
+ * same `mapId` can be merged into one block batch. The block batch is specified by a range
+ * of reduceId, which implies the continuous shuffle blocks that we can fetch in a batch.
+ * For example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be
+ * merged into (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2,
+ * shuffle_0_0_2, shuffle_0_0_3) can be merged into (shuffle_0_0_0_4).
+ *
+ * @param blocks blocks to be merged if possible. May contain already-merged blocks.
+ * @param doBatchFetch whether to merge blocks.
+ * @return the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true.
+ */
+ def mergeContinuousShuffleBlockIdsIfNeeded(
+ blocks: Seq[FetchBlockInfo],
+ doBatchFetch: Boolean): Seq[FetchBlockInfo] = {
+ val result = if (doBatchFetch) {
+ val curBlocks = new ArrayBuffer[FetchBlockInfo]
+ val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]
+
+ def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
+ val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId]
+
+ // The last merged block may come from the input, and we can merge more blocks
+ // into it, if the map id is the same.
+ def shouldMergeIntoPreviousBatchBlockId =
+ mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId
+
+ val (startReduceId, size) =
+ if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) {
+ // Remove the previous batch block id as we will add a new one to replace it.
+ val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1)
+ (removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId,
+ removed.size + toBeMerged.map(_.size).sum)
+ } else {
+ (startBlockId.reduceId, toBeMerged.map(_.size).sum)
+ }
+
+ FetchBlockInfo(
+ ShuffleBlockBatchId(
+ startBlockId.shuffleId,
+ startBlockId.mapId,
+ startReduceId,
+ toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1),
+ size,
+ toBeMerged.head.mapIndex)
+ }
+
+ val iter = blocks.iterator
+ while (iter.hasNext) {
+ val info = iter.next()
+ // It's possible that the input block id is already a batch ID. For example, we merge some
+ // blocks, and then make fetch requests with the merged blocks according to "max blocks per
+ // request". The last fetch request may be too small, and we give up and put the remaining
+ // merged blocks back to the input list.
+ if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) {
+ mergedBlockInfo += info
+ } else {
+ if (curBlocks.isEmpty) {
+ curBlocks += info
+ } else {
+ val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId]
+ val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId
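+ // A different mapId means the current run of contiguous blocks ends here, so flush
+ // the accumulated run as a single batch before starting a new one for this block.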
+ if (curBlockId.mapId != currentMapId) {
+ mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
+ curBlocks.clear()
+ }
+ curBlocks += info
+ }
+ }
+ }
+ if (curBlocks.nonEmpty) {
+ mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
+ }
+ mergedBlockInfo
+ } else {
+ blocks
+ }
+ result.toSeq
+ }
+
+ /**
+ * The block information to fetch used in FetchRequest.
+ * @param blockId block id
+ * @param size estimated size of the block. Note that this is NOT the exact number of bytes.
+ * The size of a remote block is used to calculate bytesInFlight.
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage.
+ */
+ private[storage] case class FetchBlockInfo(
+ blockId: BlockId,
+ size: Long,
+ mapIndex: Int)
+
+ /**
+ * A request to fetch blocks from a remote BlockManager.
+ * @param address remote BlockManager to fetch from.
+ * @param blocks Sequence of the information for blocks to fetch from the same address.
+ * @param forMergedMetas true if this request is for requesting push-merged meta information;
+ * false if it is for regular blocks or shuffle chunks.
+ */
+ case class FetchRequest(
+ address: BlockManagerId,
+ blocks: Seq[FetchBlockInfo],
+ forMergedMetas: Boolean = false) {
+ val size = blocks.map(_.size).sum
+ }
+
+ /**
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult
+
+ /**
+ * Result of a successful fetch of a remote block.
+ * @param blockId block id
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage.
+ * @param address BlockManager that the block was fetched from.
+ * @param size estimated size of the block. Note that this is NOT the exact number of bytes.
+ * The size of a remote block is used to calculate bytesInFlight.
+ * @param buf `ManagedBuffer` for the content.
+ * @param isNetworkReqDone whether this is the last network request for this host in this fetch request.
+ */
+ private[storage] case class SuccessFetchResult(
+ blockId: BlockId,
+ mapIndex: Int,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer,
+ isNetworkReqDone: Boolean) extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
+ }
+
+ /**
+ * Result of an unsuccessful fetch of a remote block.
+ * @param blockId block id
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage
+ * @param address BlockManager that the block was attempted to be fetched from
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(
+ blockId: BlockId,
+ mapIndex: Int,
+ address: BlockManagerId,
+ e: Throwable)
+ extends FetchResult
+
+ /**
+ * Result of a fetch request that should be deferred for some reason, e.g., a Netty OOM error
+ */
+ private[storage]
+ case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+ /**
+ * Result of an unsuccessful fetch of either of these:
+ * 1) Remote shuffle chunk.
+ * 2) Local push-merged block.
+ *
+ * Instead of treating this as a [[FailureFetchResult]], we fall back to fetching the original blocks.
+ *
+ * @param blockId block id
+ * @param address BlockManager that the push-merged block was attempted to be fetched from
+ * @param size size of the block, used to update bytesInFlight.
+ * @param isNetworkReqDone whether this is the last network request for this host in this
+ * fetch request. Used to update reqsInFlight.
+ */
+ private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ isNetworkReqDone: Boolean) extends FetchResult
+
+ /**
+ * Result of a successful fetch of meta information for a remote push-merged block.
+ *
+ * @param shuffleId shuffle id.
+ * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging process
+ * of the shuffle by an indeterminate stage attempt.
+ * @param reduceId reduce id.
+ * @param blockSize size of each push-merged block.
+ * @param bitmaps bitmaps for every chunk.
+ * @param address BlockManager that the meta was fetched from.
+ */
+ private[storage] case class PushMergedRemoteMetaFetchResult(
+ shuffleId: Int,
+ shuffleMergeId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ bitmaps: Array[RoaringBitmap],
+ address: BlockManagerId) extends FetchResult
+
+ /**
+ * Result of a failure while fetching the meta information for a remote push-merged block.
+ *
+ * @param shuffleId shuffle id.
+ * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging process
+ * of the shuffle by an indeterminate stage attempt.
+ * @param reduceId reduce id.
+ * @param address BlockManager that the meta was fetched from.
+ */
+ private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+ shuffleId: Int,
+ shuffleMergeId: Int,
+ reduceId: Int,
+ address: BlockManagerId) extends FetchResult
+
+ /**
+ * Result of a successful fetch of meta information for a push-merged-local block.
+ *
+ * @param shuffleId shuffle id.
+ * @param shuffleMergeId shuffleMergeId is used to uniquely identify the merging process
+ * of the shuffle by an indeterminate stage attempt.
+ * @param reduceId reduce id.
+ * @param bitmaps bitmaps for every chunk.
+ * @param localDirs local directories where the push-merged shuffle files are stored.
+ */
+ private[storage] case class PushMergedLocalMetaFetchResult(
+ shuffleId: Int,
+ shuffleMergeId: Int,
+ reduceId: Int,
+ bitmaps: Array[RoaringBitmap],
+ localDirs: Array[String]) extends FetchResult
+}
diff --git a/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java b/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java
new file mode 100644
index 0000000000000..2e7f6c3269086
--- /dev/null
+++ b/external-shuffle-storage/src/test/java/org/apache/spark/starshuffle/ByteBufUtilsTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.starshuffle;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ByteBufUtilsTest {
+
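+ // Round-trips a string through ByteBufUtils' length-prefixed encoding and checks
+ // that the decoded value matches the original.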
+ @Test
+ public void writeLengthAndString() {
+ ByteBuf buf = PooledByteBufAllocator.DEFAULT.buffer(1);
+ ByteBufUtils.writeLengthAndString(buf, "hello world");
+
+ String str = ByteBufUtils.readLengthAndString(buf);
+ Assert.assertEquals(str, "hello world");
+
+ buf.release();
+ }
+}
diff --git a/external-shuffle-storage/src/test/resources/log4j.properties b/external-shuffle-storage/src/test/resources/log4j.properties
new file mode 100644
index 0000000000000..d1cfaf521ab08
--- /dev/null
+++ b/external-shuffle-storage/src/test/resources/log4j.properties
@@ -0,0 +1,33 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Set everything to be logged to the file target/unit-tests.log
+test.appender=file
+log4j.rootCategory=INFO, ${test.appender}
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+# Tests that launch java subprocesses can set the "test.appender" system property to
+# "console" to avoid having the child process's logs overwrite the unit test's
+# log file.
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%t: %m%n
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.sparkproject.jetty=WARN
diff --git a/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala b/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala
new file mode 100644
index 0000000000000..e7e23eb0a4421
--- /dev/null
+++ b/external-shuffle-storage/src/test/scala/org/apache/spark/shuffle/StarShuffleManagerTest.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.util.UUID
+
+import org.junit.Test
+import org.scalatest.Assertions._
+
+import org.apache.spark._
+
+class StarShuffleManagerTest {
+
+ @Test
+ def foldByKey(): Unit = {
+ val conf = newSparkConf()
+ runWithSparkConf(conf)
+ }
+
+ @Test
+ def foldByKey_zeroBuffering(): Unit = {
+ val conf = newSparkConf()
+ conf.set("spark.reducer.maxSizeInFlight", "0")
+ conf.set("spark.network.maxRemoteBlockSizeFetchToMem", "0")
+ runWithSparkConf(conf)
+ }
+
+ private def runWithSparkConf(conf: SparkConf): Unit = {
+ val sc = new SparkContext(conf)
+
+ try {
+ val numValues = 10000
+ val numMaps = 3
+ val numPartitions = 5
+
+ val rdd = sc.parallelize(0 until numValues, numMaps)
+ .map(t => ((t / 2) -> (t * 2).longValue()))
+ .foldByKey(0, numPartitions)((v1, v2) => v1 + v2)
+ val result = rdd.collect()
+
+ assert(result.size === numValues / 2)
+
+ for (i <- 0 until result.size) {
+ val key = result(i)._1
+ val value = result(i)._2
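+ // For key k the grouped inputs are 2k and 2k+1, each mapped to t * 2, so the folded
+ // sum is 4k + (4k + 2) = key * 2 * 2 + (key * 2 + 1) * 2.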
+ assert(key * 2 * 2 + (key * 2 + 1) * 2 === value)
+ }
+
+ val keys = result.map(_._1).distinct.sorted
+ assert(keys.length === numValues / 2)
+ assert(keys(0) === 0)
+ assert(keys.last === (numValues - 1) / 2)
+ } finally {
+ sc.stop()
+ }
+ }
+
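+ // Builds a local-mode SparkConf that routes shuffle through StarShuffleManager,
+ // the shuffle manager provided by this module.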
+ def newSparkConf(): SparkConf = new SparkConf()
+ .setAppName("testApp")
+ .setMaster("local[2]")
+ .set("spark.ui.enabled", "false")
+ .set("spark.driver.allowMultipleContexts", "true")
+ .set("spark.app.id", "app-" + UUID.randomUUID())
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.StarShuffleManager")
+}
diff --git a/pom.xml b/pom.xml
index aefb5377e6b7b..251ff62ac7b87 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3397,6 +3397,13 @@
+ <profile>
+ <id>external-shuffle-storage</id>
+ <modules>
+ <module>external-shuffle-storage</module>
+ </modules>
+ </profile>
+
test-java-home