Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import scala.Tuple2;

import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
Expand Down Expand Up @@ -94,6 +96,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
return nextChunk;
}

@Override
public ManagedBuffer openStream(String streamChunkId) {
Tuple2<Long, Integer> streamIdAndChunkId = parseStreamChunkId(streamChunkId);
return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2);
}

public static String genStreamChunkId(long streamId, int chunkId) {
return String.format("%d_%d", streamId, chunkId);
}

public static Tuple2<Long, Integer> parseStreamChunkId(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified when open stream for fetching block.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return new Tuple2<>(streamId, chunkIndex);
}

@Override
public void connectionTerminated(Channel channel) {
// Close all streams which have been associated with the channel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
Expand Down Expand Up @@ -86,14 +87,16 @@ public void fetchBlocks(
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener) {
BlockFetchingListener listener,
File[] shuffleFiles) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
(blockIds1, listener1) -> {
TransportClient client = clientFactory.createClient(host, port);
new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start();
new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
shuffleFiles).start();
};

int maxRetries = conf.maxIORetries();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,28 @@

package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;

/**
* Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
Expand All @@ -48,6 +57,8 @@ public class OneForOneBlockFetcher {
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
private TransportConf transportConf = null;
private File[] shuffleFiles = null;

private StreamHandle streamHandle = null;

Expand All @@ -56,12 +67,20 @@ public OneForOneBlockFetcher(
String appId,
String execId,
String[] blockIds,
BlockFetchingListener listener) {
BlockFetchingListener listener,
TransportConf transportConf,
File[] shuffleFiles) {
this.client = client;
this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
if (shuffleFiles != null) {
this.shuffleFiles = shuffleFiles;
assert this.shuffleFiles.length == blockIds.length:
"Number of shuffle files should equal to blocks";
}
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
Expand Down Expand Up @@ -100,7 +119,12 @@ public void onSuccess(ByteBuffer response) {
// Immediately request all chunks -- we expect that the total size of the request is
// reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
for (int i = 0; i < streamHandle.numChunks; i++) {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
if (shuffleFiles != null) {
client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
new DownloadCallback(shuffleFiles[i], i));
} else {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
}
} catch (Exception e) {
logger.error("Failed while starting block fetches after success", e);
Expand All @@ -126,4 +150,38 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
}
}
}

private class DownloadCallback implements StreamCallback {

private WritableByteChannel channel = null;
private File targetFile = null;
private int chunkIndex;

public DownloadCallback(File targetFile, int chunkIndex) throws IOException {
this.targetFile = targetFile;
this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.chunkIndex = chunkIndex;
}

@Override
public void onData(String streamId, ByteBuffer buf) throws IOException {
channel.write(buf);
}

@Override
public void onComplete(String streamId) throws IOException {
channel.close();
ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
targetFile.length());
listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
}

@Override
public void onFailure(String streamId, Throwable cause) throws IOException {
channel.close();
// On receipt of a failure, fail every block from chunkIndex onwards.
String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
failRemainingBlocks(remainingBlockIds, cause);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;

import java.io.Closeable;
import java.io.File;

/** Provides an interface for reading shuffle files, either from an Executor or external service. */
public abstract class ShuffleClient implements Closeable {
Expand All @@ -40,5 +41,6 @@ public abstract void fetchBlocks(
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener);
BlockFetchingListener listener,
File[] shuffleFiles);
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) {

String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
fetcher.start();
blockFetchLatch.await();
checkSecurityException(exception.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
}
}
}
});
}, null);

if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class OneForOneBlockFetcherSuite {

private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);

@Test
public void testFetchOne() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
Expand Down Expand Up @@ -126,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap<String, ManagedBu
BlockFetchingListener listener = mock(BlockFetchingListener.class);
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);

// Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
doAnswer(invocationOnMock -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.nio.channels.ClosedByInterruptException;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.BitSet;
Expand Down Expand Up @@ -184,6 +185,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
break;
}
}
} catch (ClosedByInterruptException e) {
// This called by user to kill a task (e.g: speculative task).
logger.error("error while calling spill() on " + c, e);
throw new RuntimeException(e.getMessage());
} catch (IOException e) {
logger.error("error while calling spill() on " + c, e);
throw new OutOfMemoryError("error while calling spill() on " + c + " : "
Expand All @@ -201,6 +206,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
Utils.bytesToString(released), consumer);
got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
}
} catch (ClosedByInterruptException e) {
// This called by user to kill a task (e.g: speculative task).
logger.error("error while calling spill() on " + consumer, e);
throw new RuntimeException(e.getMessage());
} catch (IOException e) {
logger.error("error while calling spill() on " + consumer, e);
throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
Expand Down
47 changes: 23 additions & 24 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1801,40 +1801,39 @@ class SparkContext(config: SparkConf) extends Logging {
* an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
def addJarFile(file: File): String = {
try {
if (!file.exists()) {
throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found")
}
if (file.isDirectory) {
throw new IllegalArgumentException(
s"Directory ${file.getAbsoluteFile} is not allowed for addJar")
}
env.rpcEnv.fileServer.addJar(file)
} catch {
case NonFatal(e) =>
logError(s"Failed to add $path to Spark environment", e)
null
}
}

if (path == null) {
logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
val key = if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
key = env.rpcEnv.fileServer.addJar(new File(path))
addJarFile(new File(path))
} else {
val uri = new URI(path)
// SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies
Utils.validateURL(uri)
key = uri.getScheme match {
uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
try {
val file = new File(uri.getPath)
if (!file.exists()) {
throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found")
}
if (file.isDirectory) {
throw new IllegalArgumentException(
s"Directory ${file.getAbsoluteFile} is not allowed for addJar")
}
env.rpcEnv.fileServer.addJar(new File(uri.getPath))
} catch {
case NonFatal(e) =>
logError(s"Failed to add $path to Spark environment", e)
null
}
case null | "file" => addJarFile(new File(uri.getPath))
// A JAR file which exists locally on every worker node
case "local" =>
"file:" + uri.getPath
case _ =>
path
case "local" => "file:" + uri.getPath
case _ => path
}
}
if (key != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,10 @@ package object config {
.bytesConf(ByteUnit.BYTE)
.createWithDefault(100 * 1024 * 1024)

private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
.doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
"above this threshold. This is to avoid a giant request takes too much memory.")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("200m")
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.network

import java.io.Closeable
import java.io.{Closeable, File}
import java.nio.ByteBuffer

import scala.concurrent.{Future, Promise}
Expand Down Expand Up @@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit

/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
Expand Down Expand Up @@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
ret.flip()
result.success(new NioManagedBuffer(ret))
}
})
}, shuffleFiles = null)
ThreadUtils.awaitResult(result.future, Duration.Inf)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.network.netty

import java.io.File
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService(
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
transportConf, shuffleFiles).start()
}
}

Expand Down
Loading