Commit c0e93b3

otterc, Min Shen, and zhouyejoe authored and committed
[SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion
### What changes were proposed in this pull request?

This is the shuffle writer side change where executors can push data to remote shuffle services. This is needed for push-based shuffle; see the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

Summary of changes:
- Adds support for executors to push shuffle blocks after map tasks complete writing shuffle data.
- Introduces a timeout specifically for creating connections to remote shuffle services.

### Why are the changes needed?

These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).

The main reason to create a separate connection creation timeout is that the existing `connectionTimeoutMs` is overloaded: it serves as both the connection creation timeout and the connection idle timeout. The connection creation timeout should be much lower than the idle timeout. The default for `connectionTimeoutMs` is 120s, which is quite high for merely establishing connections. If a shuffle server node is bad, connection creation will fail within a few seconds; an overloaded shuffle server, on the other hand, may take much longer to respond to a request, and its channel can legitimately stay idle far longer. Another reason is that with push-based shuffle, an executor may be fetching shuffle data and pushing shuffle data (for the next stage) simultaneously. Both tasks share the same connections with the shuffle service. If there is a bad shuffle server node and the connection creation timeout is very high, both tasks end up waiting a long time, eventually impacting performance.

### Does this PR introduce any user-facing change?

Yes. This PR introduces client-side configs for push-based shuffle. If push-based shuffle is turned off, users will not see any change.

### How was this patch tested?

Added unit tests. The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). We have already verified the functionality and the improved performance as documented in the SPIP doc.

Lead-authored-by: Min Shen <mshen<at>linkedin.com>
Co-authored-by: Chandni Singh <chsingh<at>linkedin.com>
Co-authored-by: Ye Zhou <yezhou<at>linkedin.com>

Closes #30312 from otterc/SPARK-32917.

Lead-authored-by: Chandni Singh <[email protected]>
Co-authored-by: Chandni Singh <[email protected]>
Co-authored-by: Min Shen <[email protected]>
Co-authored-by: Ye Zhou <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
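To make the new knobs concrete, here is a minimal sketch of how a job might set them. This is illustrative only: the `spark.shuffle.` prefix on the creation timeout is inferred from how `TransportConf#getConfKey` scopes keys per module, and enabling push-based shuffle end to end involves additional configs tracked in SPARK-30602, not just these.

import org.apache.spark.SparkConf

// Sketch of the client-side knobs this commit introduces.
val conf = new SparkConf()
  // Fail fast when establishing connections, independent of the idle timeout.
  // The full key name here is an assumption based on TransportConf#getConfKey.
  .set("spark.shuffle.io.connectionCreationTimeout", "30s")
  // Size of the block pusher pool; defaults to the executor's core count.
  .set("spark.shuffle.push.numPushThreads", "8")
  // Blocks above this size are not pushed; they are fetched the original way.
  .set("spark.shuffle.push.maxBlockSizeToPush", "1m")
  // Upper bound on a batch of blocks grouped into one push request.
  .set("spark.shuffle.push.maxBlockBatchSize", "3m")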
1 parent 8ea1367 commit c0e93b3

File tree

12 files changed (+896, -12 lines)


common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 4 additions & 3 deletions
@@ -254,7 +254,7 @@ TransportClient createClient(InetSocketAddress address)
       // Disable Nagle's Algorithm since we don't want packets to wait
       .option(ChannelOption.TCP_NODELAY, true)
       .option(ChannelOption.SO_KEEPALIVE, true)
-      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
+      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs())
       .option(ChannelOption.ALLOCATOR, pooledAllocator);
 
     if (conf.receiveBuf() > 0) {
@@ -280,9 +280,10 @@ public void initChannel(SocketChannel ch) {
     // Connect to the remote server
     long preConnect = System.nanoTime();
     ChannelFuture cf = bootstrap.connect(address);
-    if (!cf.await(conf.connectionTimeoutMs())) {
+    if (!cf.await(conf.connectionCreationTimeoutMs())) {
       throw new IOException(
-        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+        String.format("Connecting to %s timed out (%s ms)",
+          address, conf.connectionCreationTimeoutMs()));
     } else if (cf.cause() != null) {
       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
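For readers unfamiliar with the two timeouts in play, the following sketch (not `TransportClientFactory`'s actual API) shows the behavior this hunk changes: connection establishment is now bounded by the smaller, dedicated creation timeout, while the existing 120s `connectionTimeoutMs` continues to govern how long an established channel may sit idle.

import java.io.IOException
import java.net.InetSocketAddress
import io.netty.bootstrap.Bootstrap
import io.netty.channel.Channel

// Minimal sketch of connecting with a dedicated creation timeout.
def connectWithCreationTimeout(
    bootstrap: Bootstrap,
    address: InetSocketAddress,
    creationTimeoutMs: Long): Channel = {
  val cf = bootstrap.connect(address)
  // Fail fast on an unreachable node instead of waiting out the idle timeout.
  if (!cf.await(creationTimeoutMs)) {
    throw new IOException(s"Connecting to $address timed out ($creationTimeoutMs ms)")
  }
  if (cf.cause() != null) {
    throw new IOException(s"Failed to connect to $address", cf.cause())
  }
  cf.channel()
}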

common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 12 additions & 1 deletion
@@ -19,6 +19,7 @@
 
 import java.util.Locale;
 import java.util.Properties;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.primitives.Ints;
 import io.netty.util.NettyRuntime;
@@ -31,6 +32,7 @@ public class TransportConf {
   private final String SPARK_NETWORK_IO_MODE_KEY;
   private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
   private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
+  private final String SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY;
   private final String SPARK_NETWORK_IO_BACKLOG_KEY;
   private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
   private final String SPARK_NETWORK_IO_ACCEPTORTHREADS_KEY;
@@ -59,6 +61,7 @@ public TransportConf(String module, ConfigProvider conf) {
     SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
     SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
     SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
+    SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY = getConfKey("io.connectionCreationTimeout");
     SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
     SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer");
     SPARK_NETWORK_IO_ACCEPTORTHREADS_KEY = getConfKey("io.acceptorThreads");
@@ -108,7 +111,7 @@ public boolean preferDirectBufs() {
     return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
   }
 
-  /** Connect timeout in milliseconds. Default 120 secs. */
+  /** Connection idle timeout in milliseconds. Default 120 secs. */
   public int connectionTimeoutMs() {
     long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
       conf.get("spark.network.timeout", "120s"));
@@ -139,6 +142,14 @@ public int streamReadTimeoutMs() {
     return (int) defaultTimeoutMs;
   }
 
+  /** Connect creation timeout in milliseconds. Default 30 secs. */
+  public int connectionCreationTimeoutMs() {
+    long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs());
+    long defaultTimeoutMs = JavaUtils.timeStringAsSec(
+      conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY, connectionTimeoutS + "s")) * 1000;
+    return (int) defaultTimeoutMs;
+  }
+
   /** Number of concurrent connections between two nodes for fetching data. */
   public int numConnectionsPerPeer() {
     return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
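One subtlety in the hunk above: the new javadoc states a 30s default, but as committed the code falls back to the full `connectionTimeoutMs` value when `io.connectionCreationTimeout` is unset. A minimal sketch of that defaulting logic, with `parseTimeSecs` as a simplified stand-in for `JavaUtils.timeStringAsSec`:

// Simplified stand-in for JavaUtils.timeStringAsSec; handles only "Ns" strings.
def parseTimeSecs(s: String): Long = s.stripSuffix("s").trim.toLong

// When the key is unset, the creation timeout inherits the idle timeout, so
// existing deployments keep their current behavior until they opt in.
def connectionCreationTimeoutMs(
    configured: Option[String],   // value of io.connectionCreationTimeout, if any
    connectionTimeoutMs: Int): Int = {
  val fallbackSecs = connectionTimeoutMs / 1000L
  val secs = configured.map(parseTimeSecs).getOrElse(fallbackSecs)
  (secs * 1000).toInt
}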

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 2 additions & 3 deletions
@@ -31,7 +31,6 @@
 import scala.Tuple2;
 import scala.collection.Iterator;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -178,8 +177,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     }
   }
 
-  @VisibleForTesting
-  long[] getPartitionLengths() {
+  @Override
+  public long[] getPartitionLengths() {
     return partitionLengths;
   }
 

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 6 additions & 1 deletion
@@ -87,6 +87,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
+  @Nullable private long[] partitionLengths;
   private long peakMemoryUsedBytes = 0;
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@@ -218,7 +219,6 @@ void closeAndWriteOutput() throws IOException {
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
-    final long[] partitionLengths;
     try {
       partitionLengths = mergeSpills(spills);
     } finally {
@@ -528,4 +528,9 @@ public void close() throws IOException {
       channel.close();
     }
   }
+
+  @Override
+  public long[] getPartitionLengths() {
+    return partitionLengths;
+  }
 }
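Both hunks above, together with the `BypassMergeSortShuffleWriter` change, serve one contract: after a successful map task, the pusher needs each writer's per-partition lengths to decide what to push. A sketch of that contract follows; the exact signature of Spark's `ShuffleWriter` abstract class is assumed here, since the class itself is not shown in this diff excerpt.

// Hypothetical rendering of the writer contract; names mirror the diff, but
// this class definition is an assumption, not part of the excerpt above.
abstract class SketchShuffleWriter[K, V] {
  def write(records: Iterator[Product2[K, V]]): Unit
  // Added for push-based shuffle: byte lengths indexed by reduce partition id.
  def getPartitionLengths(): Array[Long]
}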

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,7 @@ import org.apache.spark.metrics.source.JVMCPUSource
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.scheduler._
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
@@ -307,6 +307,7 @@ private[spark] class Executor(
       case NonFatal(e) =>
         logWarning("Unable to stop heartbeater", e)
     }
+    ShuffleBlockPusher.stop()
     threadPool.shutdown()
 
     // Notify plugins that executor is shutting down so they can terminate cleanly
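The `stop()` call added above implies a process-wide pusher resource that outlives individual tasks and must be shut down with the executor. A hedged sketch of that lifecycle pattern; `ShuffleBlockPusher`'s real internals are not part of this diff and are assumed here.

import java.util.concurrent.{ExecutorService, Executors}

// Assumed pattern: a companion object owning a shared pusher pool that the
// executor shuts down during its own stop sequence.
object SketchBlockPusher {
  private val pool: ExecutorService = Executors.newCachedThreadPool()
  def stop(): Unit = pool.shutdown()
}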

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 29 additions & 0 deletions
@@ -2134,4 +2134,33 @@ package object config {
       .version("3.1.0")
       .doubleConf
       .createWithDefault(5)
+
+  private[spark] val SHUFFLE_NUM_PUSH_THREADS =
+    ConfigBuilder("spark.shuffle.push.numPushThreads")
+      .doc("Specify the number of threads in the block pusher pool. These threads assist " +
+        "in creating connections and pushing blocks to remote shuffle services. By default, the " +
+        "threadpool size is equal to the number of spark executor cores.")
+      .version("3.2.0")
+      .intConf
+      .createOptional
+
+  private[spark] val SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH =
+    ConfigBuilder("spark.shuffle.push.maxBlockSizeToPush")
+      .doc("The max size of an individual block to push to the remote shuffle services. Blocks " +
+        "larger than this threshold are not pushed to be merged remotely. These shuffle blocks " +
+        "will be fetched by the executors in the original manner.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("1m")
+
+  private[spark] val SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH =
+    ConfigBuilder("spark.shuffle.push.maxBlockBatchSize")
+      .doc("The max size of a batch of shuffle blocks to be grouped into a single push request.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      // Default is 3m because it is greater than 2m which is the default value for
+      // TransportConf#memoryMapBytes. If this defaults to 2m as well it is very likely that each
+      // batch of block will be loaded in memory with memory mapping, which has higher overhead
+      // with small MB sized chunk of data.
+      .createWithDefaultString("3m")
 }
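As a rough illustration of how `maxBlockSizeToPush` interacts with the partition lengths the writers now expose (a simplification, not `ShuffleBlockPusher`'s actual logic):

// Illustrative filter only. Empty partitions have nothing to push; blocks
// above the threshold are skipped and fetched by reducers the original way.
val maxBlockSizeToPush: Long = 1024 * 1024  // the "1m" default
val partitionLengths: Array[Long] = Array(0L, 512 * 1024, 4L * 1024 * 1024)
val pushableReduceIds = partitionLengths.zipWithIndex.collect {
  case (len, reduceId) if len > 0 && len <= maxBlockSizeToPush => reduceId
}
// pushableReduceIds contains only reduce id 1 (the 512 KiB block).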
